├── LICENSE
├── README.md
├── data
│   ├── cvpr19_LRW_poster.pdf
│   ├── raw.pytorch
│   ├── seeds.pytorch
│   └── target.pytorch
├── evaluation
│   ├── README.md
│   ├── __init__.py
│   ├── environment.yml
│   └── evaluation_test_cremi.py
├── example2D.py
├── exampleCREMI.py
└── randomwalker
    ├── RandomWalkerModule.py
    ├── __init__.py
    ├── build_laplacian.py
    ├── randomwalker2D.py
    ├── randomwalker_loss_utils.py
    └── randomwalker_tools.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Heidelberg Collaboratory for Image Processing
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-LearnedRandomWalker
2 | Implementation of the LearnedRandomWalker module as described in:
3 | * Paper: [End-To-End Learned Random Walker for Seeded Image Segmentation](https://openaccess.thecvf.com/content_CVPR_2019/papers/Cerrone_End-To-End_Learned_Random_Walker_for_Seeded_Image_Segmentation_CVPR_2019_paper.pdf)
4 | * [Supplementary Material](https://openaccess.thecvf.com/content_CVPR_2019/supplemental/Cerrone_End-To-End_Learned_Random_CVPR_2019_supplemental.pdf)
5 | * [CVPR2019 Poster](./data/cvpr19_LRW_poster.pdf)
6 | 
7 | ## Data processing:
8 | The results reported in the paper are based on a modified version of the [CREMI](https://cremi.org/) challenge dataset.
9 | 
10 | The following processing steps have been performed:
11 | * The `raw` and `labels` have been cropped by 2 pixels in the `x/y` plane to avoid potential
12 | misalignments during the upsampling and downsampling.
13 | * The slice `0` of the `labels` is ignored because the UNet used in the experiments uses 3 z-slices as input.
14 | * Some instances are connected in 3D but are not visually connected in 2D, therefore a slice-by-slice relabeling is
15 | performed.
16 | * Groundtruth segments smaller than `64` pixels in the `x/y` plane are merged with the surrounding segments using the
17 | watershed algorithm (a sketch of this step and the relabeling is shown after this list).
18 | * Groundtruth slices that are corrupted or show extreme artifacts are ignored in both training and testing.
19 | The following slices are removed from the test set: **Cremi B**: 44, 45, 15, 16, **Cremi C**: 14.
20 | * The first 50 **valid** slices from each CREMI volume (A, B, C) are used for testing. The remaining **valid** slices
21 | are used for training.
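For illustration, a minimal sketch of the slice-by-slice relabeling and the small-segment merging described above. This is not the exact code used for the paper: the helper names, the use of `scipy`/`scikit-image`, and the availability of a per-slice boundary map are assumptions.

```python
import numpy as np
from scipy import ndimage
from skimage.segmentation import watershed


def relabel_slice(labels_2d):
    # Split instances that are connected in 3D but fall apart into several
    # 2D components within this slice: every connected component of every
    # original id receives its own new label.
    out = np.zeros_like(labels_2d)
    offset = 0
    for idx in np.unique(labels_2d):
        if idx == 0:
            continue
        cc, n_cc = ndimage.label(labels_2d == idx)
        out[cc > 0] = cc[cc > 0] + offset
        offset += n_cc
    return out


def merge_small_segments(labels_2d, boundary_map, min_size=64):
    # Remove segments smaller than min_size pixels and let the watershed,
    # run on a boundary/membrane map, grow the surrounding segments into
    # the freed area.
    ids, counts = np.unique(labels_2d, return_counts=True)
    small = ids[(counts < min_size) & (ids != 0)]
    markers = np.where(np.isin(labels_2d, small), 0, labels_2d)
    return watershed(boundary_map, markers=markers)


# Per-slice preprocessing (hypothetical usage):
# labels[z] = merge_small_segments(relabel_slice(labels[z]), boundary[z], min_size=64)
```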
22 | 
23 | The final 150 test-set slices (with all the above modifications) are available at:
24 | https://heibox.uni-heidelberg.de/published/cvpr2019_lrw/
25 | 
26 | Additionally, the repository contains the seeds, the learned RW segmentation, the standard WS segmentation, the standard RW segmentation
27 | and the CNN predictions.
28 | 
29 | ## Evaluation:
30 | In the [evaluation](./evaluation) directory you can find all instructions to reproduce the results in the manuscript,
31 | together with the evaluation script used.
32 | 
33 | 
34 | ## Cite:
35 | ```
36 | @inproceedings{cerrone2019,
37 | title={End-to-end learned random walker for seeded image segmentation},
38 | author={Cerrone, Lorenzo and Zeilmann, Alexander and Hamprecht, Fred A},
39 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
40 | pages={12559--12568},
41 | year={2019}
42 | }
43 | ```
44 | 
45 | 
--------------------------------------------------------------------------------
/data/cvpr19_LRW_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sciai-lab/pytorch-LearnedRandomWalker/b2d7b4fe0555bd0f69f2e5ec4fc899b97f108f5c/data/cvpr19_LRW_poster.pdf
--------------------------------------------------------------------------------
/data/raw.pytorch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sciai-lab/pytorch-LearnedRandomWalker/b2d7b4fe0555bd0f69f2e5ec4fc899b97f108f5c/data/raw.pytorch
--------------------------------------------------------------------------------
/data/seeds.pytorch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sciai-lab/pytorch-LearnedRandomWalker/b2d7b4fe0555bd0f69f2e5ec4fc899b97f108f5c/data/seeds.pytorch
--------------------------------------------------------------------------------
/data/target.pytorch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sciai-lab/pytorch-LearnedRandomWalker/b2d7b4fe0555bd0f69f2e5ec4fc899b97f108f5c/data/target.pytorch
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 | # Evaluation script
2 | 
3 | ## Download data
4 | In order to reproduce the results reported in
5 | [End-To-End Learned Random Walker for Seeded Image Segmentation](https://openaccess.thecvf.com/content_CVPR_2019/papers/Cerrone_End-To-End_Learned_Random_Walker_for_Seeded_Image_Segmentation_CVPR_2019_paper.pdf),
6 | please download `cvpr_2109_lrw.h5` and `baseline_lrw.h5` from:
7 | https://heibox.uni-heidelberg.de/published/cvpr2019_lrw/
8 | 
9 | ## Installation
10 | ```
11 | conda env create -f ./environment.yml
12 | conda activate py36_cremi
13 | pip install git+https://github.com/lorenzocerrone/cremi_python.git
14 | ```
15 | 
16 | ## Run Learned RW evaluation
17 | ```
18 | python evaluation_test_cremi.py --gtpath ./cvpr_2109_lrw.h5 --segpath ./cvpr_2109_lrw.h5
19 | ```
20 | 
21 | ## Run baseline evaluation
22 | Standard RandomWalker algorithm:
23 | ```
24 | python evaluation_test_cremi.py --gtpath ./cvpr_2109_lrw.h5 --segpath ./baseline_lrw.h5 --segdataset segmentation_stRW
25 | ```
26 | Standard Watershed algorithm:
27 | ```
28 | python evaluation_test_cremi.py --gtpath ./cvpr_2109_lrw.h5 --segpath ./baseline_lrw.h5 --segdataset segmentation_stWS
29 | ```
30 | 
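The same metrics can also be computed from Python (a minimal sketch, assuming the `py36_cremi` environment above is active and that the code is run from this directory; `compute_scores` and its default datasets `targets`/`segmentation` are taken from `evaluation_test_cremi.py`):
```
from evaluation_test_cremi import compute_scores

# Learned RW: groundtruth ("targets") and segmentation ("segmentation") live in the same file.
scores = compute_scores(gtpath="./cvpr_2109_lrw.h5", segpath="./cvpr_2109_lrw.h5")

# Baseline: standard RW segmentation stored in the baseline file.
baseline = compute_scores(gtpath="./cvpr_2109_lrw.h5", segpath="./baseline_lrw.h5",
                          segdataset="segmentation_stRW")

# Each row of the returned array is (cremi score, VOI split, VOI merge, adapted Rand) for one slice.
```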
-------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sciai-lab/pytorch-LearnedRandomWalker/b2d7b4fe0555bd0f69f2e5ec4fc899b97f108f5c/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/environment.yml: -------------------------------------------------------------------------------- 1 | name: py36_cremi 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - blas=1.0=mkl 7 | - ca-certificates=2020.10.14=0 8 | - certifi=2020.6.20=py36h06a4308_2 9 | - h5py=2.10.0=py36hd6299e0_1 10 | - hdf5=1.10.6=hb1b8bf9_0 11 | - intel-openmp=2020.2=254 12 | - ld_impl_linux-64=2.33.1=h53a641e_7 13 | - libedit=3.1.20191231=h14c3975_1 14 | - libffi=3.3=he6710b0_2 15 | - libgcc-ng=9.1.0=hdf63c60_0 16 | - libgfortran-ng=7.3.0=hdf63c60_0 17 | - libstdcxx-ng=9.1.0=hdf63c60_0 18 | - mkl=2020.2=256 19 | - mkl-service=2.3.0=py36he904b0f_0 20 | - mkl_fft=1.2.0=py36h23d657b_0 21 | - mkl_random=1.1.1=py36h0573a6f_0 22 | - ncurses=6.2=he6710b0_1 23 | - numpy=1.19.2=py36h54aff64_0 24 | - numpy-base=1.19.2=py36hfa32c7d_0 25 | - openssl=1.1.1h=h7b6447c_0 26 | - pip=20.2.4=py36_0 27 | - python=3.6.12=hcff3b4d_2 28 | - readline=8.0=h7b6447c_0 29 | - scipy=1.5.2=py36h0b6359f_0 30 | - setuptools=50.3.0=py36hb0f4dca_1 31 | - six=1.15.0=py_0 32 | - sqlite=3.33.0=h62c20be_0 33 | - tk=8.6.10=hbc83047_0 34 | - wheel=0.35.1=py_0 35 | - xz=5.2.5=h7b6447c_0 36 | - zlib=1.2.11=h7b6447c_3 37 | - pip: 38 | - munkres==1.1.2 39 | prefix: /home/lcerrone/miniconda3/envs/py36_cremi 40 | -------------------------------------------------------------------------------- /evaluation/evaluation_test_cremi.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from cremi import Volume 4 | from cremi.evaluation import NeuronIds 5 | 6 | 7 | def cremi_score(gt, seg, return_all_scores=True, b_thresh=2, data_resolution=(1.0, 1.0, 1.0)): 8 | """compute cremi scores from np.array""" 9 | 10 | if len(gt.shape) == 2: 11 | gt = gt[None, :, :] 12 | seg = seg[None, :, :] 13 | gt_ = Volume(gt, resolution=data_resolution) 14 | seg_ = Volume(seg, resolution=data_resolution) 15 | 16 | metrics = NeuronIds(gt_, b_thresh) 17 | arand = metrics.adapted_rand(seg_) 18 | 19 | vi_s, vi_m = metrics.voi(seg_) 20 | # official cremi score 21 | cs = np.sqrt(vi_s * (vi_m + arand)) 22 | 23 | if return_all_scores: 24 | return cs, vi_s, vi_m, arand 25 | else: 26 | return cs 27 | 28 | 29 | def compute_scores(gtpath, segpath, gtdataset="targets", segdataset="segmentation"): 30 | # Data load 31 | with h5py.File(gtpath, "r") as f: 32 | gt = f[gtdataset][...] 33 | 34 | with h5py.File(segpath, "r") as f: 35 | segmentation = f[segdataset][...] 
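    # Expected data layout (see the top-level README and the argparse help below):
    # both datasets hold 150 aligned 2D label images, slices 0-49 from CREMI A,
    # 50-99 from CREMI B and 100-149 from CREMI C; the per-volume statistics
    # printed below rely on this ordering.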
36 | 37 | # Compute 2D cremi scores 38 | scores = [] 39 | for i in range(gt.shape[0]): 40 | scores.append(cremi_score(gt[i], segmentation[i], b_thresh=2)) 41 | 42 | scores = np.array(scores) 43 | rand, vois, voim = scores[:, 3], scores[:, 1], scores[:, 2] 44 | print(f"results A rand: {np.mean(rand[:50])} pm {np.std(rand[:50])}") 45 | print(f"results A vois: {np.mean(vois[:50])} pm {np.std(vois[:50])}") 46 | print(f"results A voim: {np.mean(voim[:50])} pm {np.std(voim[:50])}\n") 47 | 48 | print(f"results B rand: {np.mean(rand[50:100])} pm {np.std(rand[50:100])}") 49 | print(f"results B vois: {np.mean(vois[50:100])} pm {np.std(vois[50:100])}") 50 | print(f"results B voim: {np.mean(voim[50:100])} pm {np.std(voim[50:100])}\n") 51 | 52 | print(f"results C rand: {np.mean(rand[100:])} pm {np.std(rand[100:])}") 53 | print(f"results C vois: {np.mean(vois[100:])} pm {np.std(vois[100:])}") 54 | print(f"results C voim: {np.mean(voim[100:])} pm {np.std(voim[100:])}\n") 55 | 56 | print(f"results rand: {np.mean(rand)} pm {np.std(rand)}") 57 | print(f"results vois: {np.mean(vois)} pm {np.std(vois)}") 58 | print(f"results voim: {np.mean(voim)} pm {np.std(voim)}") 59 | return scores 60 | 61 | 62 | if __name__ == "__main__": 63 | import argparse 64 | 65 | def _parser(): 66 | parser = argparse.ArgumentParser(description='Run cremi valuation on 2D ' 67 | '(x/y plane, test set A[0:50]B[0:50]C[0:50]).') 68 | parser.add_argument('--gtpath', type=str, help='Path to the groundtruth file (only h5).', 69 | required=True) 70 | parser.add_argument('--segpath', type=str, help='Path to the predicted segmentation file (only h5).', 71 | required=True) 72 | parser.add_argument('--gtdataset', type=str, help='Groundtruth labels dataset.', 73 | default="targets", required=False) 74 | parser.add_argument('--segdataset', type=str, help='Predicted labels dataset.', 75 | default="segmentation", required=False) 76 | return parser.parse_args() 77 | 78 | args = _parser() 79 | compute_scores(args.gtpath, args.segpath, args.gtdataset, args.segdataset) 80 | -------------------------------------------------------------------------------- /example2D.py: -------------------------------------------------------------------------------- 1 | from randomwalker.RandomWalkerModule import RandomWalker 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from randomwalker.randomwalker_loss_utils import NHomogeneousBatchLoss 6 | import os 7 | 8 | if not os.path.exists('results'): 9 | os.makedirs('results') 10 | 11 | 12 | def create_simple_image(size, batch=1): 13 | """ 14 | returns a simple target image with two and a seeds mask 15 | """ 16 | sizex, sizey = size 17 | 18 | target = torch.zeros(batch, 1, sizex, sizey).long() 19 | target[:, :, :, sizey//2:] = 1 20 | 21 | seeds = torch.zeros(batch, sizex, sizey).long() 22 | seeds[:, 3 * sizex//4, 5] = 1 23 | seeds[:, sizex//4, sizey - 5] = 2 24 | 25 | return seeds, target 26 | 27 | 28 | def make_summary_plot(it, output, net_output, seeds, target): 29 | """ 30 | This function create and save a summary figure 31 | """ 32 | f, axarr = plt.subplots(2, 2, figsize=(8, 9.5)) 33 | f.suptitle("RW summary, Iteration: " + repr(it)) 34 | 35 | axarr[0, 0].set_title("Ground Truth Image") 36 | axarr[0, 0].imshow(target[0, 0].detach().numpy(), alpha=0.8, vmin=-3, cmap="prism_r") 37 | seeds_listx, seeds_listy = np.where(seeds[0].data != 0) 38 | axarr[0, 0].scatter(seeds_listy[1], 39 | seeds_listx[1], c="w") 40 | axarr[0, 0].scatter(seeds_listy[0], 41 | seeds_listx[0], c="k") 42 | axarr[0, 
0].axis("off") 43 | 44 | axarr[0, 1].set_title("LRW output (white seed)") 45 | axarr[0, 1].imshow(output[0][0, 0].detach().numpy(), cmap="gray") 46 | axarr[0, 1].axis("off") 47 | 48 | axarr[1, 0].set_title("Vertical Diffusivities") 49 | axarr[1, 0].imshow(net_output[0, 0].detach().numpy(), cmap="gray", vmax=1) 50 | axarr[1, 0].axis("off") 51 | 52 | axarr[1, 1].set_title("Horizontal Diffusivities") 53 | axarr[1, 1].imshow(net_output[0, 1].detach().numpy(), cmap="gray", vmax=1) 54 | axarr[1, 1].axis("off") 55 | 56 | plt.tight_layout() 57 | plt.savefig("./results/%04i.png"%it) 58 | plt.close() 59 | 60 | 61 | if __name__ == '__main__': 62 | # Init parameters 63 | batch_size = 1 64 | iterations = 60 65 | size = (60, 59) 66 | 67 | # Init the random walker modules 68 | rw = RandomWalker(1000, max_backprop=True) 69 | 70 | # Load data and init 71 | seeds, target = create_simple_image(size) 72 | diffusivities = torch.zeros(batch_size, 2, size[0], size[1], requires_grad=True) 73 | 74 | # Init optimizer 75 | optimizer = torch.optim.Adam([diffusivities], lr=0.5) 76 | 77 | # Loss has to been wrapped in order to work with random walker algorithm 78 | loss = NHomogeneousBatchLoss(torch.nn.NLLLoss) 79 | 80 | # Main overfit loop 81 | for it in range(iterations + 1): 82 | optimizer.zero_grad() 83 | 84 | # Diffusivities must be positive 85 | net_output = torch.sigmoid(diffusivities) 86 | 87 | # Random walker 88 | output = rw(net_output, seeds) 89 | 90 | # Loss and diffusivities update 91 | output_log = [torch.log(o) for o in output] 92 | l = loss(output_log, target) 93 | l.backward(retain_graph=True) 94 | optimizer.step() 95 | 96 | # Summary 97 | if it % 5 == 0: 98 | print("Iteration: ", it, " Loss: ", l.item()) 99 | make_summary_plot(it, output, net_output, seeds, target) 100 | -------------------------------------------------------------------------------- /exampleCREMI.py: -------------------------------------------------------------------------------- 1 | from randomwalker.RandomWalkerModule import RandomWalker 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from randomwalker.randomwalker_loss_utils import NHomogeneousBatchLoss 6 | import os 7 | 8 | if not os.path.exists('results'): 9 | os.makedirs('results') 10 | 11 | 12 | def make_summary_plot(it, raw, output, net_output, seeds, target): 13 | """ 14 | This function create and save a summary figure 15 | """ 16 | f, axarr = plt.subplots(2, 2, figsize=(8, 9.5)) 17 | f.suptitle("RW summary, Iteration: " + repr(it)) 18 | 19 | axarr[0, 0].set_title("Ground Truth Image") 20 | axarr[0, 0].imshow(raw[0].detach().numpy(), cmap="gray") 21 | axarr[0, 0].imshow(target[0, 0].detach().numpy(), alpha=0.6, vmin=-3, cmap="prism_r") 22 | seeds_listx, seeds_listy = np.where(seeds[0].data != 0) 23 | axarr[0, 0].scatter(seeds_listy, 24 | seeds_listx, c="r") 25 | axarr[0, 0].axis("off") 26 | 27 | axarr[0, 1].set_title("LRW output (white seed)") 28 | axarr[0, 1].imshow(raw[0].detach().numpy(), cmap="gray") 29 | axarr[0, 1].imshow(np.argmax(output[0][0].detach().numpy(), 0), alpha=0.6, vmin=-3, cmap="prism_r") 30 | axarr[0, 1].axis("off") 31 | 32 | axarr[1, 0].set_title("Vertical Diffusivities") 33 | axarr[1, 0].imshow(net_output[0, 0].detach().numpy(), cmap="gray") 34 | axarr[1, 0].axis("off") 35 | 36 | axarr[1, 1].set_title("Horizontal Diffusivities") 37 | axarr[1, 1].imshow(net_output[0, 1].detach().numpy(), cmap="gray") 38 | axarr[1, 1].axis("off") 39 | 40 | plt.tight_layout() 41 | plt.savefig("./results/%04i.png"%it) 42 | plt.close() 43 | 44 
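# Tensor shapes assumed in the overfit loop below (inferred from the plotting code
# above and from RandomWalkerModule; they are not stated explicitly in the original files):
#   raw            [batch, 128, 128]      grayscale EM slice
#   target         [batch, 1, 128, 128]   integer instance labels
#   seeds          [batch, 128, 128]      0 = unseeded, 1..N = seed ids
#   diffusivities  [batch, 2, 128, 128]   vertical / horizontal edge diffusivities
#   rw output      list (length batch) of [1, N_seeds, 128, 128] probability maps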
| 45 | if __name__ == '__main__': 46 | # Init parameters 47 | batch_size = 1 48 | iterations = 60 49 | size = (128, 128) 50 | datadir = "data/" 51 | 52 | # Init the random walker modules 53 | rw = RandomWalker(1000, max_backprop=True) 54 | 55 | # Load data and init 56 | raw = torch.load(datadir + "raw.pytorch") 57 | target = torch.load(datadir + "target.pytorch") 58 | seeds = torch.load(datadir + "seeds.pytorch") 59 | diffusivities = torch.zeros(batch_size, 2, size[0], size[1], requires_grad=True) 60 | 61 | # Init optimizer 62 | optimizer = torch.optim.Adam([diffusivities], lr=0.9) 63 | 64 | # Loss has to been wrapped in order to work with random walker algorithm 65 | loss = NHomogeneousBatchLoss(torch.nn.NLLLoss) 66 | 67 | # Main overfit loop 68 | for it in range(iterations + 1): 69 | optimizer.zero_grad() 70 | 71 | # Diffusivities must be positive 72 | net_output = torch.sigmoid(diffusivities) 73 | 74 | # Random walker 75 | output = rw(net_output, seeds) 76 | 77 | # Loss and diffusivities update 78 | output_log = [torch.log(o) for o in output] 79 | l = loss(output_log, target) 80 | l.backward(retain_graph=True) 81 | optimizer.step() 82 | 83 | # Summary 84 | if it % 5 == 0: 85 | print("Iteration: ", it, " Loss: ", l.item()) 86 | make_summary_plot(it, raw, output, net_output, seeds, target) 87 | -------------------------------------------------------------------------------- /randomwalker/RandomWalkerModule.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import torch.nn as nn 4 | import torch 5 | from .randomwalker2D import RandomWalker2D as RW2D 6 | 7 | 8 | class RandomWalker(nn.Module): 9 | def __init__(self, num_grad, max_backprop=True): 10 | """ 11 | num_grad: Number of sampled gradients 12 | max_backprop: Compute the loss only on the absolute maximum 13 | """ 14 | super(RandomWalker, self).__init__() 15 | self.rw = RW2D 16 | self.num_grad = num_grad 17 | self.max_backprop = max_backprop 18 | 19 | def forward(self, e, seeds): 20 | """ 21 | e: must be a torch tensors [b x 2 x X, Y] 22 | seeds: must be a torch tensors [b x X, Y] 23 | """ 24 | out_probabilities = [] 25 | for batch in range(e.size(0)): 26 | out_probabilities_ = self.rw(self.num_grad, self.max_backprop)(e[batch].cpu(), seeds[batch]) 27 | 28 | out_probabilities_ = torch.transpose(out_probabilities_, 0, 1) 29 | out_probabilities_ = out_probabilities_.view(1, -1, seeds.size(1), seeds.size(2)) 30 | out_probabilities.append(out_probabilities_) 31 | 32 | return out_probabilities 33 | -------------------------------------------------------------------------------- /randomwalker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sciai-lab/pytorch-LearnedRandomWalker/b2d7b4fe0555bd0f69f2e5ec4fc899b97f108f5c/randomwalker/__init__.py -------------------------------------------------------------------------------- /randomwalker/build_laplacian.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from numba import jit 5 | 6 | 7 | @jit(cache=True) 8 | def build_laplacian2D(elap0, size): 9 | """ 10 | lap0: Graph weights elements, must be (num pixels, 2) 11 | size: size original image, must be a 2D tuple 12 | returns: laplacian elements and COO indices 13 | """ 14 | num_elements_lap = 2 * (size[1] - 1)*size[0] + 2 * (size[0] - 1)*size[1] + size[0]*size[1] 15 | elap = 
np.zeros(num_elements_lap) 16 | 17 | # lap_index and off diagonal lap index 18 | i_ind, j_ind = (np.zeros(num_elements_lap, dtype=np.int), 19 | np.zeros(num_elements_lap, dtype=np.int)) 20 | 21 | cout, sk = 0, 1 22 | for i in range(size[0]): 23 | for j in range(size[1]): 24 | k = i * size[1] + j 25 | 26 | if i > sk - 1: 27 | n = elap0[0, i - sk, j] 28 | elap[cout] = -n 29 | i_ind[cout], j_ind[cout] = k, k - sk * size[1] 30 | cout += 1 31 | else: 32 | n = 0 33 | 34 | if j > sk - 1: 35 | w = elap0[1, i, j - sk] 36 | elap[cout] = -w 37 | i_ind[cout], j_ind[cout] = k, k - sk 38 | cout += 1 39 | else: 40 | w = 0 41 | 42 | if i < size[0] - sk: 43 | s = elap0[0, i, j] 44 | elap[cout] = -s 45 | i_ind[cout], j_ind[cout] = k, k + sk * size[1] 46 | cout += 1 47 | else: 48 | s = 0 49 | 50 | if j < size[1] - sk: 51 | e = elap0[1, i, j] 52 | elap[cout] = -e 53 | i_ind[cout], j_ind[cout] = k, k + sk 54 | cout += 1 55 | else: 56 | e = 0 57 | 58 | norm = (n + w + s + e) 59 | elap[cout] = max(norm, 1e-5) 60 | i_ind[cout], j_ind[cout] = k, k 61 | cout += 1 62 | 63 | return elap, i_ind, j_ind 64 | -------------------------------------------------------------------------------- /randomwalker/randomwalker2D.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | from scipy.sparse import csc_matrix, coo_matrix 5 | from sksparse.cholmod import cholesky 6 | import torch 7 | from torch.autograd import Function 8 | from randomwalker import randomwalker_tools 9 | 10 | 11 | class RandomWalker2D(Function): 12 | def __init__(self, num_grad=1000, max_backprop=True): 13 | super(RandomWalker2D, self).__init__() 14 | """ 15 | num_grad: Number of sampled gradients 16 | max_backprop: Compute the loss only on the absolute maximum 17 | """ 18 | self.num_grad = num_grad 19 | self.max_backprop = max_backprop 20 | self.lap_u = None 21 | self.pu = None 22 | self.gradout = None 23 | self.ch_lap = None 24 | self.c_max = None 25 | 26 | def forward(self, elap_input, seeds_input): 27 | """ 28 | input : essential Laplacian (s, e edge for each pixel) shape: N_pixels x 2 29 | 30 | output: instances probability shape: N_pixels x N_seeds 31 | """ 32 | # Pytorch Tensors to numpy 33 | elap = elap_input.clone().numpy() 34 | seeds = seeds_input.numpy() 35 | elap = np.squeeze(elap) 36 | 37 | # Building laplacian and running the RW 38 | self.pu, self.lap_u = randomwalker_tools.standard_RW(elap, seeds) 39 | 40 | # Fill seeds predictions 41 | p = randomwalker_tools.pu2p(self.pu, seeds) 42 | # save for backward 43 | self.save_for_backward(seeds_input) 44 | 45 | return torch.from_numpy(p) 46 | 47 | def backward(self, grad_output): 48 | """ 49 | input : grad from loss 50 | output: grad from the laplacian backprop 51 | """ 52 | 53 | # Pytorch Tensors to numpy 54 | gradout = grad_output.numpy() 55 | seeds = self.saved_tensors[0].numpy() 56 | 57 | # Remove seeds from grad_output 58 | self.gradout = randomwalker_tools.p2pu(gradout, seeds) 59 | 60 | # Back propagation 61 | grad_input = self.dlap_df() 62 | 63 | grad_input = randomwalker_tools.grad_fill(grad_input, seeds, 2).reshape(-1, seeds.shape[0], seeds.shape[1]) 64 | 65 | grad_input = grad_input[None, ...] 
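        # For reference, the gradient assembled above by dlap_df()/compute_grad()
        # (a sketch of the reasoning; see the paper for the full derivation):
        # the RW solution satisfies L_u p_u = b, so for an edge (k, l) between
        # two unseeded pixels
        #     d p_u / d w_kl = L_u^{-1} dl,  with dl[k] = p_u[l] - p_u[k]
        #                                         dl[l] = p_u[k] - p_u[l],
        # and the loss gradient follows by contracting this with grad_output.
        # The Cholesky factorisation of L_u is cached so the same solve can be
        # reused for every sampled edge.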
66 | 67 | return torch.FloatTensor(grad_input), None 68 | 69 | def dlap_df(self): 70 | """ 71 | Sampled back prop implementation 72 | grad_input: The gradient input for the previous layer 73 | """ 74 | 75 | # Solver + sampling 76 | grad_input = np.zeros((2, self.pu.shape[0])) 77 | lap_u = coo_matrix(self.lap_u) 78 | ind_i, ind_j = lap_u.col, lap_u.row 79 | 80 | # mask n and w direction 81 | mask = (ind_j - ind_i) > 0 82 | ind_i, ind_j = ind_i[mask], ind_j[mask] 83 | 84 | # find the edge direction 85 | mask = ind_j - ind_i == 1 86 | dir_e = np.zeros_like(ind_i) 87 | dir_e[mask] = 1 88 | 89 | # Sampling 90 | if self.num_grad < np.unique(ind_i).shape[0]: 91 | u_ind = np.unique(ind_i) 92 | grad2do = np.random.choice(u_ind, size=self.num_grad, replace=False) 93 | else: 94 | grad2do = np.unique(ind_i) 95 | 96 | # Compute the choalesky decomposition 97 | self.ch_lap = cholesky(csc_matrix(self.lap_u)) 98 | 99 | # find maxgrad for each region 100 | if self.max_backprop: 101 | self.c_max = np.argmax(np.abs(self.gradout), axis=1) 102 | else: 103 | # only biggest 10 104 | self.c_max = np.argsort(np.abs(self.gradout), axis=1) 105 | 106 | # Loops around all the edges 107 | for k, l, e in zip(ind_i, ind_j, dir_e): 108 | if k in grad2do: 109 | grad_input[e, k] = self.compute_grad(k, l) 110 | 111 | return grad_input 112 | 113 | def compute_grad(self, k, l): 114 | """ 115 | k, l: pixel indices, referred to the unseeded laplacian 116 | ch_lap: choaleshy decomposition of the undseeded laplacian 117 | pu: unseeded output probability 118 | gradout: previous layer jacobian 119 | return: grad for the edge k, l 120 | """ 121 | dl = np.zeros_like(self.pu) 122 | dl[l] = self.pu[k] - self.pu[l] 123 | dl[k] = self.pu[l] - self.pu[k] 124 | 125 | partial_grad = self.ch_lap.solve_A(dl[:, self.c_max[k]]) 126 | grad = np.sum(self.gradout[:, self.c_max[k]] * partial_grad) 127 | return grad 128 | -------------------------------------------------------------------------------- /randomwalker/randomwalker_loss_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class NHomogeneousBatchLoss: 4 | def __init__(self, loss, *args, **kwargs): 5 | """ 6 | Apply loss on list of Non-Homogeneous tensors 7 | """ 8 | self.loss = loss(**kwargs) 9 | 10 | def __call__(self, output, target): 11 | """ 12 | output: must be a list of torch tensors [1 x ? 
x spatial] 13 | target: must be a list of torch tensors or a single tensor, shape depends on loss function 14 | """ 15 | assert isinstance(output, list), "output must be a list of torch tensors" 16 | 17 | l, it = 0, 0 18 | for it, o in enumerate(output): 19 | l = l + self.loss(o, target[it]) 20 | 21 | return l / (it + 1) 22 | -------------------------------------------------------------------------------- /randomwalker/randomwalker_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | from scipy.sparse import csc_matrix 5 | from scipy.sparse.linalg import spsolve 6 | from .build_laplacian import build_laplacian2D 7 | 8 | 9 | def pu2p(pu, seeds): 10 | """ 11 | :param pu: unseeded output probability 12 | :param seeds: RW seeds, must be the same size as the image 13 | :return: p: the complete output probability 14 | """ 15 | seeds_r = seeds.ravel() 16 | mask_u = seeds_r == 0 17 | p = np.zeros((seeds_r.shape[0], pu.shape[-1]), dtype=np.float32) 18 | 19 | for s in range(seeds.max()): 20 | pos_s = np.where(seeds_r == s + 1) 21 | p[pos_s, s] = 1 22 | p[mask_u, s] = pu[:, s] 23 | 24 | return p 25 | 26 | 27 | def grad_fill(gradu, seeds, edges=2): 28 | """ 29 | :param gradu: unseeded output probability 30 | :param seeds: RW seeds, must be the same size as the image 31 | :param edges: number of affinities for each pixel 32 | :return: p: the complete output probability 33 | """ 34 | seeds_r = seeds.ravel() 35 | mask_u = seeds_r == 0 36 | grad = np.zeros((edges, seeds_r.shape[0])) 37 | grad[:, mask_u] = gradu 38 | 39 | return grad 40 | 41 | 42 | def p2pu(p, seeds): 43 | """ 44 | :param p: output probability 45 | :param seeds: RW seeds, must be the same size as the image 46 | :return: pu: unseeded output probability 47 | """ 48 | mask_u = seeds.ravel() == 0 49 | pu = p[mask_u] 50 | return pu 51 | 52 | 53 | def lap2lapu_bt(lap, seeds): 54 | mask_u = seeds.ravel() == 0 55 | return lap[mask_u][:, mask_u], - lap[mask_u][:, ~mask_u] 56 | 57 | 58 | def sparse_laplacian(elap, size_image): 59 | """ 60 | :param elap: Graph weights elements, must be (num pixels, 2) 61 | :param size_image: size original image, must be a 2D tuple 62 | :return: graph laplacian as a csc matrix 63 | """ 64 | if elap.shape[0] == 2: 65 | e, i_ind, j_ind = build_laplacian2D(elap, size_image) 66 | laplacian = csc_matrix((e, (i_ind, j_ind)), shape=(np.prod(size_image), np.prod(size_image))) 67 | return laplacian 68 | else: 69 | raise NotImplementedError 70 | 71 | 72 | def sparse_pm(seeds): 73 | """ 74 | :param seeds: RW seeds, must be the same size as the image 75 | :return: poss matrix for the standard RW 76 | """ 77 | k = np.where(seeds.ravel() != 0)[0] 78 | i_ind, j_ind = np.arange(k.shape[0]), seeds.ravel()[k] - 1 79 | val = np.ones_like(k, dtype=np.float) 80 | return csc_matrix((val, (i_ind, j_ind)), shape=(k.shape[0], j_ind.max() + 1)) 81 | 82 | 83 | def standard_RW(elap, seeds): 84 | """ 85 | laplacian: Graph laplacian 86 | seeds: RW seeds, must be the same size as the image 87 | return: the output probability for each pixel and instances 88 | """ 89 | laplacian = sparse_laplacian(elap, seeds.shape) 90 | lap_u, B_T = lap2lapu_bt(laplacian, seeds) 91 | pm = sparse_pm(seeds) 92 | 93 | # Random Walker Solution 94 | pu = spsolve(lap_u, B_T.dot(pm), use_umfpack=True) 95 | 96 | # Save out_put for backward 97 | if type(pu) == np.ndarray: 98 | return np.array(pu, dtype=np.float32), lap_u 99 | else: 100 | return 
np.array(pu.toarray(), dtype=np.float32), lap_u 101 | --------------------------------------------------------------------------------