├── BALoss ├── requirements.txt ├── cpp │ ├── wrap_py2.cpp │ ├── MBD_distance_2d.h │ ├── wrap_py3.cpp │ ├── util.h │ ├── MBD_distance.cpp │ ├── util.cpp │ └── MBD_distance_2d.cpp ├── demo.py ├── setup.py └── loss.py ├── images ├── results.png └── pipeline.png └── README.md /BALoss/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.19.1 2 | -------------------------------------------------------------------------------- /images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onvungocminh/MBD_BAL/HEAD/images/results.png -------------------------------------------------------------------------------- /images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onvungocminh/MBD_BAL/HEAD/images/pipeline.png -------------------------------------------------------------------------------- /BALoss/cpp/wrap_py2.cpp: -------------------------------------------------------------------------------- 1 | #include "MBD_distance.cpp" 2 | 3 | PyMODINIT_FUNC 4 | initMBD(void) { 5 | (void) Py_InitModule("MBD", Methods); 6 | import_array(); 7 | } 8 | -------------------------------------------------------------------------------- /BALoss/cpp/MBD_distance_2d.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | using namespace std; 4 | 5 | struct Point2D 6 | { 7 | // float distance; 8 | int w; 9 | int h; 10 | }; 11 | 12 | void MBD_cut(const unsigned char * img, const unsigned char * seeds, const unsigned char * destination, unsigned char * distance, 13 | int height, int width); -------------------------------------------------------------------------------- /BALoss/cpp/wrap_py3.cpp: -------------------------------------------------------------------------------- 1 | #include "MBD_distance.cpp" 2 | 3 | 4 | static 
struct PyModuleDef cMBDDis = 5 | { 6 | PyModuleDef_HEAD_INIT, 7 | "MBD", /* name of module */ 8 | "", /* module documentation, may be NULL */ 9 | -1, /* size of per-interpreter state of the module, or -1 if the module keeps state in global variables. */ 10 | Methods 11 | }; 12 | 13 | 14 | PyMODINIT_FUNC PyInit_MBD(void) { 15 | import_array(); 16 | return PyModule_Create(&cMBDDis); 17 | } 18 | -------------------------------------------------------------------------------- /BALoss/cpp/util.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | using namespace std; 4 | 5 | // for 2d images 6 | template 7 | T get_pixel(const T * data, int height, int width, int h, int w); 8 | 9 | template 10 | std::vector get_pixel_vector(const T * data, int height, int width, int channel, int h, int w); 11 | 12 | template 13 | void set_pixel(T * data, int height, int width, int h, int w, T value); 14 | 15 | // for 3d images 16 | template 17 | T get_pixel(const T * data, int depth, int height, int width, int d, int h, int w); 18 | 19 | template 20 | std::vector get_pixel_vector(const T * data, int depth, int height, int width, int channel, int d, int h, int w); 21 | 22 | template 23 | void set_pixel(T * data, int depth, int height, int width, int d, int h, int w, T value); -------------------------------------------------------------------------------- /BALoss/demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import cv2 6 | import MBD 7 | from loss import * 8 | import scipy.io 9 | from PIL import Image 10 | import pdb 11 | 12 | 13 | img_file = 'img.png' 14 | img = Image.open(img_file) 15 | img = np.array(img, dtype=np.float32) 16 | seed_file = 'seed.mat' 17 | seed = scipy.io.loadmat(seed_file)["seed"] 18 | boundary_file = 'boundary.mat' 19 | boundary = 
scipy.io.loadmat(boundary_file)["boundary"] 20 | contour_file = 'contour.png' 21 | contour = Image.open(contour_file) 22 | contour = contour.convert('L') 23 | contour = np.array(contour)/255. 24 | contour = contour.astype(np.uint8) 25 | 26 | 27 | img = torch.from_numpy(img.copy()).float() 28 | img = img.unsqueeze(axis=0) 29 | contour = torch.from_numpy(np.array([contour])).float() 30 | seed = torch.from_numpy(np.array([seed])).squeeze(axis=0).float() 31 | boundary = torch.from_numpy(np.array([boundary])).squeeze(axis=0).float() 32 | 33 | img, contour, seed, boundary = img.cuda(), contour.cuda(), seed.cuda(), boundary.cuda() 34 | 35 | BAloss = DAHU_loss(img, seed, boundary, contour) 36 | print(BAloss) -------------------------------------------------------------------------------- /BALoss/setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import setuptools 4 | from distutils.core import setup 5 | from distutils.extension import Extension 6 | 7 | package_name = 'MBD' 8 | module_name = 'MBD' 9 | version = sys.version[0] 10 | wrap_source = './cpp/wrap_py{0:}.cpp'.format(version) 11 | module1 = Extension(module_name, 12 | include_dirs = [np.get_include(), './cpp'], 13 | sources = ['./cpp/util.cpp', 14 | './cpp/MBD_distance_2d.cpp', 15 | './cpp/MBD_distance.cpp', 16 | wrap_source]) 17 | 18 | # Get the summary 19 | description = 'An open-source toolkit to calculate BAloss' 20 | 21 | 22 | setup( 23 | name = package_name, 24 | version = "0.1.7", 25 | author ='Minh', 26 | author_email = 'vungocminh@gmail.com', 27 | description = description, 28 | packages = setuptools.find_packages(), 29 | ext_modules = [module1], 30 | python_requires = '>=3.6', 31 | ) 32 | 33 | # to build, run python setup.py build or python setup.py build_ext --inplace 34 | # to install, run python setup.py install 35 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Introducing the Boundary-Aware loss for deep image segmentation 2 | 3 | Official GitHub repository for the BMVC 2021 paper. 4 | 5 | Authors: Minh Ôn Vũ Ngọc, Yizi Chen et al. 6 | 7 | ## Abstract 8 | 9 | Most contemporary supervised image segmentation methods do not preserve the initial topology of the given input (like the closedness of the contours). One can generally remark that edge points have been inserted or removed when the binary prediction and the ground truth are compared. This can be critical when accurate localization of multiple interconnected objects is required. In this paper, we present a new loss function, called the **Boundary-Aware loss** (BALoss), based on the **Minimum Barrier Distance** (MBD) cut algorithm. It is able to locate what we call the **leakage pixels** and to encode the boundary information coming from the given ground truth. Thanks to this adapted loss, we are able to significantly refine the quality of the predicted boundaries during the learning procedure. Furthermore, our loss function is differentiable and can be applied to any kind of neural network used in image processing. We apply this loss function to the standard U-Net and DC U-Net on Electron Microscopy datasets, which are well known to be challenging due to their high noise level and to the close or even connected objects covering the image space. Our segmentation performance, in terms of Variation of Information (VOI) and Adapted Rand Index (ARI), is very promising and leads to ~15% better VOI scores and ~5% better ARI scores than the state of the art.
10 | 11 | ## Our Proposed pipeline 12 | 13 | ![Results](./images/pipeline.png) 14 | 15 | ## Visualize results 16 | 17 | ![Results](./images/results.png) 18 | 19 | ## How to use this loss 20 | 21 | Compile the loss: 22 | ```bash 23 | python setup.py build 24 | python setup.py install 25 | ``` 26 | 27 | A short demo is also available in BALoss/demo.py: 28 | ```bash 29 | python demo.py 30 | ``` 31 | 32 | ## Some examples are given in this folder: 33 | https://drive.google.com/drive/folders/1ue2Eq_UGzY-21v7gUHVlhFIppn-orJFw?usp=sharing 34 | -------------------------------------------------------------------------------- /BALoss/cpp/MBD_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "numpy/arrayobject.h" 4 | #include "MBD_distance_2d.h" 5 | #include 6 | using namespace std; 7 | 8 | 9 | static PyObject * 10 | MBD_cut_wrapper(PyObject *self, PyObject *args) 11 | { 12 | PyObject *I=NULL, *Seed=NULL, *Destination=NULL; 13 | PyArrayObject *arr_I=NULL, *arr_Seed=NULL, *arr_Destination=NULL; 14 | 15 | if (!PyArg_ParseTuple(args, "OOO", &I, &Seed, &Destination)) return NULL; 16 | 17 | arr_I = (PyArrayObject*)PyArray_FROM_OTF(I, NPY_UINT8, NPY_IN_ARRAY); 18 | if (arr_I == NULL) return NULL; 19 | 20 | arr_Seed = (PyArrayObject*)PyArray_FROM_OTF(Seed, NPY_UINT8, NPY_IN_ARRAY); 21 | if (arr_Seed == NULL) return NULL; 22 | 23 | arr_Destination = (PyArrayObject*)PyArray_FROM_OTF(Destination, NPY_UINT8, NPY_IN_ARRAY); 24 | if (arr_Destination == NULL) return NULL; 25 | 26 | int nd = PyArray_NDIM(arr_I); //number of dimensions 27 | npy_intp * shape = PyArray_DIMS(arr_I); // npy_intp array of length nd showing length in each dim. 
28 | npy_intp * shape_seed = PyArray_DIMS(arr_Seed); 29 | npy_intp * shape_destination = PyArray_DIMS(arr_Destination); 30 | 31 | // cout<<"input shape "; 32 | // for(int i=0; idata, (const unsigned char *)arr_Seed->data, 53 | (const unsigned char *)arr_Destination->data, (unsigned char *) shortestpath->data, shape[0], shape[1]); 54 | 55 | Py_DECREF(arr_I); 56 | Py_DECREF(arr_Seed); 57 | Py_DECREF(arr_Destination); 58 | 59 | //Py_INCREF(distance); 60 | return PyArray_Return(shortestpath); 61 | } 62 | 63 | static PyMethodDef Methods[] = { 64 | {"MBD_cut", MBD_cut_wrapper, METH_VARARGS, "computing 2d MBD cut"}, 65 | // {NULL, NULL, 0, NULL} 66 | }; 67 | -------------------------------------------------------------------------------- /BALoss/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import cv2 6 | import MBD 7 | 8 | def cross_entropy_loss2d(inputs, targets, cuda=True, balance=1.1): 9 | """ 10 | :param inputs: inputs is a 4 dimensional data nx1xhxw 11 | :param targets: targets is a 3 dimensional data nx1xhxw 12 | :return: 13 | """ 14 | n, c, h, w = inputs.size() 15 | weights = np.zeros((n, c, h, w)) 16 | for i in range(n): 17 | t = targets[i, :, :].cpu().data.numpy() 18 | pos = (t == 1).sum() 19 | neg = (t == 0).sum() 20 | valid = neg + pos 21 | weights[i, t == 1] = neg * 1. 
/ valid 22 | weights[i, t == 0] = pos * balance / valid 23 | weights = torch.Tensor(weights) 24 | if cuda: 25 | weights = weights.cuda() 26 | loss = nn.BCEWithLogitsLoss(weights, reduction='mean')(inputs, targets) 27 | return loss 28 | 29 | 30 | def DAHU_loss(inputs, seed, boundary, labels): 31 | # target_image: the boundary total 32 | EPM = inputs.squeeze() 33 | labels = labels.squeeze() 34 | 35 | seed = seed.squeeze(axis = 0).squeeze(axis = 0) 36 | boundary = boundary.squeeze(axis=0).squeeze(axis = 0) 37 | 38 | d, _, _ = seed.shape 39 | loss_total = 0 40 | for i in range(d): 41 | loss = 0 42 | EPM_s, seed_s, boundary_s = EPM, seed[i,:,:], boundary[i,:,:] 43 | max_range = torch.max(seed_s).type(torch.uint8) 44 | seed_s = np.array(seed_s.detach().cpu().numpy()).astype(np.uint8) 45 | EPM_s = torch.sigmoid(EPM_s) 46 | EPM_s_clone = EPM_s.clone() 47 | EPM_s_clone = np.array((EPM_s_clone*255).detach().cpu().numpy()).astype(np.uint8) 48 | 49 | if max_range>1: 50 | 51 | inside = np.array(seed_s == 1).astype(np.uint8)*255 52 | outside = np.array(seed_s > 1).astype(np.uint8)*255 53 | destination = np.array(seed_s).astype(np.uint8) 54 | contour = MBD.MBD_cut(EPM_s_clone,inside, outside) 55 | kernel = np.ones((7,7),np.uint8) 56 | contour = cv2.dilate(contour,kernel,iterations = 1) 57 | contour = (contour/255).astype(np.uint8) 58 | contour = torch.from_numpy(np.array([contour])).cuda() 59 | 60 | EPM_contour = (EPM_s * contour).unsqueeze(axis=0) 61 | label_s = (labels * contour).unsqueeze(axis=0) 62 | BCE_loss = cross_entropy_loss2d(EPM_contour, label_s, True, 1.1)*10 63 | loss = loss + BCE_loss 64 | 65 | else: 66 | loss = torch.tensor(0).type(torch.float64) 67 | loss_total = loss_total + loss # Accumulate sum 68 | loss_total = loss_total / d 69 | 70 | return loss_total -------------------------------------------------------------------------------- /BALoss/cpp/util.cpp: -------------------------------------------------------------------------------- 1 | #include "util.h" 2 | 
#include 3 | // for 2d images 4 | template 5 | T get_pixel(const T * data, int height, int width, int h, int w) 6 | { 7 | return data[h * width + w]; 8 | } 9 | 10 | template 11 | std::vector get_pixel_vector(const T * data, int height, int width, int channel, int h, int w) 12 | { 13 | std::vector pixel_vector(channel); 14 | for (int c = 0; c < channel; c++){ 15 | pixel_vector[c]= data[h * width * channel + w * channel + c]; 16 | } 17 | return pixel_vector; 18 | } 19 | 20 | template 21 | void set_pixel(T * data, int height, int width, int h, int w, T value) 22 | { 23 | data[h * width + w] = value; 24 | } 25 | 26 | template 27 | float get_pixel(const float * data, int height, int width, int h, int w); 28 | 29 | template 30 | int get_pixel(const int * data, int height, int width, int h, int w); 31 | 32 | template 33 | std::vector get_pixel_vector(const float * data, int height, int width, int channel, int h, int w); 34 | 35 | template 36 | unsigned char get_pixel(const unsigned char * data, int height, int width, int h, int w); 37 | 38 | 39 | template 40 | void set_pixel(float * data, int height, int width, int h, int w, float value); 41 | 42 | template 43 | void set_pixel(int * data, int height, int width, int h, int w, int value); 44 | 45 | template 46 | void set_pixel(unsigned char * data, int height, int width, int h, int w, unsigned char value); 47 | 48 | // for 3d images 49 | template 50 | T get_pixel(const T * data, int depth, int height, int width, int d, int h, int w) 51 | { 52 | return data[(d*height + h) * width + w]; 53 | } 54 | 55 | template 56 | std::vector get_pixel_vector(const T * data, int depth, int height, int width, int channel, int d, int h, int w) 57 | { 58 | std::vector pixel_vector(channel); 59 | for (int c = 0; c < channel; c++){ 60 | pixel_vector[c]= data[d*height*width*channel + h * width * channel + w * channel + c]; 61 | } 62 | return pixel_vector; 63 | } 64 | 65 | template 66 | void set_pixel(T * data, int depth, int height, int width, 
int d, int h, int w, T value) 67 | { 68 | data[(d*height + h) * width + w] = value; 69 | } 70 | 71 | template 72 | float get_pixel(const float * data, int depth, int height, int width, int d, int h, int w); 73 | 74 | template 75 | std::vector get_pixel_vector(const float * data, int depth, int height, int width, int channel, int d, int h, int w); 76 | 77 | template 78 | int get_pixel(const int * data, int depth, int height, int width, int d, int h, int w); 79 | 80 | template 81 | unsigned char get_pixel(const unsigned char * data, 82 | int depth, int height, int width, 83 | int d, int h, int w); 84 | 85 | 86 | template 87 | void set_pixel(float * data, int depth, int height, int width, int d, int h, int w, float value); 88 | 89 | template 90 | void set_pixel(int * data, int depth, int height, int width, int d, int h, int w, int value); 91 | 92 | template 93 | void set_pixel(unsigned char * data, 94 | int depth, int height, int width, 95 | int d, int h, int w, unsigned char value); -------------------------------------------------------------------------------- /BALoss/cpp/MBD_distance_2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "util.h" 6 | #include "MBD_distance_2d.h" 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | using namespace std; 13 | 14 | 15 | 16 | void MBD_cut(const unsigned char * img, const unsigned char * seeds, const unsigned char * destination, unsigned char * distance, 17 | int height, int width) 18 | { 19 | // Inside 20 | int * state = new int[height * width]; 21 | unsigned char * min_image = new unsigned char[height * width]; 22 | unsigned char * max_image = new unsigned char[height * width]; 23 | unsigned char * distance_in = new unsigned char[height * width]; 24 | unsigned char * distance_out = new unsigned char[height * width]; 25 | unsigned char * segment = new unsigned char[height * width]; 26 | unsigned char * 
dejavu = new unsigned char[height * width]; 27 | unsigned char * post_segment = new unsigned char[height * width]; 28 | 29 | vector > Q(256); 30 | queue Q_segment; 31 | 32 | // point state: 0--acceptd, 1--temporary, 2--far away 33 | // get initial accepted set and far away set 34 | 35 | unsigned char init_dis; 36 | int init_state; 37 | 38 | for(int h = 0; h < height; h++) 39 | { 40 | for (int w = 0; w < width; w++) 41 | { 42 | Point2D p; 43 | p.h = h; 44 | p.w = w; 45 | unsigned char seed_type = get_pixel(seeds, height, width, h, w); 46 | unsigned char img_value = get_pixel(img, height, width, h, w); 47 | 48 | if(seed_type > 100){ 49 | init_dis = 0; 50 | init_state = 1; 51 | Q[init_dis].push(p); 52 | Q_segment.push(p); 53 | set_pixel(post_segment, height, width, h, w, 255); 54 | set_pixel(distance, height, width, h, w, 0); 55 | set_pixel(dejavu, height, width, h, w, 1); 56 | set_pixel(distance_in, height, width, h, w, init_dis); 57 | set_pixel(state, height, width, h, w, init_state); 58 | set_pixel(min_image, height, width, h, w, img_value); 59 | set_pixel(max_image, height, width, h, w, img_value); 60 | } 61 | else{ 62 | set_pixel(dejavu, height, width, h, w, 0); 63 | set_pixel(distance, height, width, h, w, 0); 64 | set_pixel(post_segment, height, width, h, w, 0); 65 | init_dis = 255; 66 | init_state = 0; 67 | set_pixel(distance_in, height, width, h, w, init_dis); 68 | set_pixel(state, height, width, h, w, init_state); 69 | set_pixel(min_image, height, width, h, w, img_value); 70 | set_pixel(max_image, height, width, h, w, img_value); 71 | } 72 | } 73 | } 74 | 75 | 76 | int dh[4] = { 1 ,-1 , 0, 0}; 77 | int dw[4] = { 0 , 0 , 1,-1}; 78 | 79 | // Proceed the propagation from the marker to all pixels in the image 80 | for (int lvl = 0; lvl < 256; lvl++) 81 | { 82 | 83 | while (!Q[lvl].empty()) 84 | { 85 | Point2D p = Q[lvl].front(); 86 | Q[lvl].pop(); 87 | 88 | int state_value = get_pixel(state, height, width, p.h, p.w); 89 | if (state_value == 2) 90 | continue; 91 
| 92 | set_pixel(state, height, width, p.h, p.w, 2); 93 | 94 | 95 | for (int n1 = 0 ; n1 < 4 ; n1++) 96 | { 97 | int tmp_h = p.h + dh[n1]; 98 | int tmp_w = p.w + dw[n1]; 99 | 100 | if (tmp_h >= 0 and tmp_h < height and tmp_w >= 0 and tmp_w < width) 101 | { 102 | Point2D r; 103 | r.h = tmp_h; 104 | r.w = tmp_w; 105 | 106 | unsigned char temp_r = get_pixel(distance_in, height, width, r.h, r.w); 107 | unsigned char temp_p = get_pixel(distance_in, height, width, p.h, p.w); 108 | 109 | state_value = get_pixel(state, height, width, r.h, r.w); 110 | 111 | if (state_value == 1 && temp_r> temp_p) 112 | { 113 | unsigned char min_image_value = get_pixel(min_image, height, width, p.h, p.w); 114 | unsigned char max_image_value = get_pixel(max_image, height, width, p.h, p.w); 115 | 116 | set_pixel(min_image, height, width, r.h, r.w, min_image_value); 117 | set_pixel(max_image, height, width, r.h, r.w, max_image_value); 118 | 119 | unsigned char image_value = get_pixel(img, height, width, r.h, r.w); 120 | 121 | if (image_value < min_image_value) 122 | set_pixel(min_image, height, width, r.h, r.w, image_value); 123 | if (image_value > max_image_value) 124 | set_pixel(max_image, height, width, r.h, r.w, image_value); 125 | 126 | temp_r = get_pixel(distance_in, height, width, r.h, r.w); 127 | 128 | min_image_value = get_pixel(min_image, height, width, r.h, r.w); 129 | max_image_value = get_pixel(max_image, height, width, r.h, r.w); 130 | 131 | unsigned char temp_dis = max_image_value - min_image_value; 132 | 133 | if (temp_r > temp_dis) 134 | { 135 | set_pixel(distance_in, height, width, r.h, r.w, temp_dis); 136 | Q[temp_dis].push(r); 137 | } 138 | } 139 | 140 | else if (state_value == 0) 141 | { 142 | unsigned char min_image_value = get_pixel(min_image, height, width, p.h, p.w); 143 | unsigned char max_image_value = get_pixel(max_image, height, width, p.h, p.w); 144 | set_pixel(min_image, height, width, r.h, r.w, min_image_value); 145 | set_pixel(max_image, height, width, r.h, r.w, 
max_image_value); 146 | 147 | unsigned char image_value = get_pixel(img, height, width, r.h, r.w); 148 | 149 | if (image_value < min_image_value) 150 | set_pixel(min_image, height, width, r.h, r.w, image_value); 151 | if (image_value > max_image_value) 152 | set_pixel(max_image, height, width, r.h, r.w, image_value); 153 | 154 | min_image_value = get_pixel(min_image, height, width, r.h, r.w); 155 | max_image_value = get_pixel(max_image, height, width, r.h, r.w); 156 | 157 | unsigned char temp_dis = max_image_value - min_image_value; 158 | 159 | set_pixel(distance_in, height, width, r.h, r.w, temp_dis); 160 | Q[temp_dis].push(r); 161 | set_pixel(state, height, width, r.h, r.w, 1); 162 | } 163 | else 164 | continue; 165 | 166 | } 167 | } 168 | } 169 | } 170 | 171 | 172 | // Outside 173 | 174 | vector > Q1(256); 175 | 176 | // point state: 0--acceptd, 1--temporary, 2--far away 177 | // get initial accepted set and far away set 178 | 179 | // unsigned char init_dis; 180 | // int init_state; 181 | 182 | for(int h = 0; h < height; h++) 183 | { 184 | for (int w = 0; w < width; w++) 185 | { 186 | Point2D p; 187 | p.h = h; 188 | p.w = w; 189 | unsigned char seed_type = get_pixel(destination, height, width, h, w); 190 | unsigned char img_value = get_pixel(img, height, width, h, w); 191 | 192 | if(seed_type > 100){ 193 | init_dis = 0; 194 | init_state = 1; 195 | Q1[init_dis].push(p); 196 | set_pixel(distance_out, height, width, h, w, init_dis); 197 | set_pixel(state, height, width, h, w, init_state); 198 | set_pixel(min_image, height, width, h, w, img_value); 199 | set_pixel(max_image, height, width, h, w, img_value); 200 | } 201 | else{ 202 | init_dis = 255; 203 | init_state = 0; 204 | set_pixel(distance_out, height, width, h, w, init_dis); 205 | set_pixel(state, height, width, h, w, init_state); 206 | set_pixel(min_image, height, width, h, w, img_value); 207 | set_pixel(max_image, height, width, h, w, img_value); 208 | } 209 | } 210 | } 211 | 212 | 213 | 214 | 215 | // 
Proceed the propagation from the marker to all pixels in the image 216 | for (int lvl = 0; lvl < 256; lvl++) 217 | { 218 | 219 | while (!Q1[lvl].empty()) 220 | { 221 | Point2D p = Q1[lvl].front(); 222 | Q1[lvl].pop(); 223 | 224 | int state_value = get_pixel(state, height, width, p.h, p.w); 225 | if (state_value == 2) 226 | continue; 227 | 228 | set_pixel(state, height, width, p.h, p.w, 2); 229 | 230 | 231 | for (int n1 = 0 ; n1 < 4 ; n1++) 232 | { 233 | int tmp_h = p.h + dh[n1]; 234 | int tmp_w = p.w + dw[n1]; 235 | 236 | if (tmp_h >= 0 and tmp_h < height and tmp_w >= 0 and tmp_w < width) 237 | { 238 | Point2D r; 239 | r.h = tmp_h; 240 | r.w = tmp_w; 241 | 242 | unsigned char temp_r = get_pixel(distance_out, height, width, r.h, r.w); 243 | unsigned char temp_p = get_pixel(distance_out, height, width, p.h, p.w); 244 | 245 | state_value = get_pixel(state, height, width, r.h, r.w); 246 | 247 | if (state_value == 1 && temp_r> temp_p) 248 | { 249 | unsigned char min_image_value = get_pixel(min_image, height, width, p.h, p.w); 250 | unsigned char max_image_value = get_pixel(max_image, height, width, p.h, p.w); 251 | 252 | set_pixel(min_image, height, width, r.h, r.w, min_image_value); 253 | set_pixel(max_image, height, width, r.h, r.w, max_image_value); 254 | 255 | unsigned char image_value = get_pixel(img, height, width, r.h, r.w); 256 | 257 | if (image_value < min_image_value) 258 | set_pixel(min_image, height, width, r.h, r.w, image_value); 259 | if (image_value > max_image_value) 260 | set_pixel(max_image, height, width, r.h, r.w, image_value); 261 | 262 | temp_r = get_pixel(distance_out, height, width, r.h, r.w); 263 | 264 | min_image_value = get_pixel(min_image, height, width, r.h, r.w); 265 | max_image_value = get_pixel(max_image, height, width, r.h, r.w); 266 | 267 | unsigned char temp_dis = max_image_value - min_image_value; 268 | 269 | if (temp_r > temp_dis) 270 | { 271 | set_pixel(distance_out, height, width, r.h, r.w, temp_dis); 272 | Q1[temp_dis].push(r); 
273 | } 274 | } 275 | 276 | else if (state_value == 0) 277 | { 278 | unsigned char min_image_value = get_pixel(min_image, height, width, p.h, p.w); 279 | unsigned char max_image_value = get_pixel(max_image, height, width, p.h, p.w); 280 | set_pixel(min_image, height, width, r.h, r.w, min_image_value); 281 | set_pixel(max_image, height, width, r.h, r.w, max_image_value); 282 | 283 | unsigned char image_value = get_pixel(img, height, width, r.h, r.w); 284 | 285 | if (image_value < min_image_value) 286 | set_pixel(min_image, height, width, r.h, r.w, image_value); 287 | if (image_value > max_image_value) 288 | set_pixel(max_image, height, width, r.h, r.w, image_value); 289 | 290 | min_image_value = get_pixel(min_image, height, width, r.h, r.w); 291 | max_image_value = get_pixel(max_image, height, width, r.h, r.w); 292 | 293 | unsigned char temp_dis = max_image_value - min_image_value; 294 | 295 | set_pixel(distance_out, height, width, r.h, r.w, temp_dis); 296 | Q1[temp_dis].push(r); 297 | set_pixel(state, height, width, r.h, r.w, 1); 298 | } 299 | else 300 | continue; 301 | 302 | } 303 | } 304 | } 305 | } 306 | 307 | 308 | // segment region by comparing the two distance maps 309 | 310 | for(int h = 0; h < height; h++) 311 | { 312 | for (int w = 0; w < width; w++) 313 | { 314 | unsigned char distance1 = get_pixel(distance_in, height, width, h, w); 315 | unsigned char distance2 = get_pixel(distance_out, height, width, h, w); 316 | if (distance1 < distance2 ) 317 | set_pixel(segment, height, width, h, w, 255); 318 | else 319 | set_pixel(segment, height, width, h, w, 0); 320 | } 321 | } 322 | 323 | // Post processing to select only one connected component from the seed point of th foreground 324 | 325 | while (!Q_segment.empty()) 326 | { 327 | Point2D p = Q_segment.front(); 328 | Q_segment.pop(); 329 | set_pixel(post_segment, height, width, p.h, p.w, 255); 330 | 331 | for (int n1 = 0 ; n1 < 4 ; n1++) 332 | { 333 | int tmp_h = p.h + dh[n1]; 334 | int tmp_w = p.w + dw[n1]; 
335 | 336 | if (tmp_h >= 0 and tmp_h < height and tmp_w >= 0 and tmp_w < width) 337 | { 338 | Point2D r; 339 | r.h = tmp_h; 340 | r.w = tmp_w; 341 | unsigned char temp_r = get_pixel(segment, height, width, r.h, r.w); 342 | unsigned char temp_djv = get_pixel(dejavu, height, width, r.h, r.w); 343 | if (temp_djv == 0 and temp_r == 255) 344 | { 345 | set_pixel(dejavu, height, width, r.h, r.w, 1); 346 | Q_segment.push(r); 347 | } 348 | } 349 | } 350 | } 351 | 352 | for(int h = 0; h < height; h++) 353 | { 354 | for (int w = 0; w < width; w++) 355 | { 356 | Point2D p; 357 | p.h = h; 358 | p.w = w; 359 | for (int n1 = 0 ; n1 < 4 ; n1++) 360 | { 361 | int tmp_h = p.h + dh[n1]; 362 | int tmp_w = p.w + dw[n1]; 363 | 364 | if (tmp_h >= 0 and tmp_h < height and tmp_w >= 0 and tmp_w < width) 365 | { 366 | unsigned char temp_p = get_pixel(post_segment, height, width, p.h, p.w); 367 | unsigned char temp_r = get_pixel(post_segment, height, width, tmp_h, tmp_w); 368 | if (temp_p == 0 and temp_r == 255) 369 | set_pixel(distance, height, width, p.h, p.w, 255); 370 | } 371 | } 372 | } 373 | } 374 | 375 | 376 | delete state; 377 | delete min_image; 378 | delete max_image; 379 | delete distance_in; 380 | delete distance_out; 381 | delete segment; 382 | delete dejavu; 383 | delete post_segment; 384 | 385 | } 386 | 387 | --------------------------------------------------------------------------------