"""MIND-pytorch: Modality Independent Neighbourhood Descriptor (MIND) in PyTorch.

Paper: https://www.sciencedirect.com/science/article/abs/pii/S1361841512000643?via%3Dihub

``MIND`` maps a batch of single-channel images to a per-pixel descriptor over
a square non-local search window; ``MINDLoss`` is an L1 distance between the
descriptors of two image batches.
"""
import math

import torch


def _check_odd(name, value, minimum=1):
    """Raise ValueError unless ``value`` is an odd integer >= ``minimum``."""
    if value < minimum or value % 2 != 1:
        raise ValueError(f"{name} must be an odd integer >= {minimum}, got {value!r}")


def _fixed_conv(in_channels, out_channels, kernel_size, groups):
    """Return a stride-1, bias-free Conv2d zero-padded by (kernel_size - 1) // 2 per side."""
    pad = (kernel_size - 1) // 2
    return torch.nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=(kernel_size, kernel_size),
        stride=1,
        padding=(pad, pad),
        dilation=1,
        groups=groups,
        bias=False,
        padding_mode="zeros",
    )


def _shift_kernels(size):
    """Return one-hot kernels of shape (size * size, 1, size, size).

    Kernel ``i`` has its single 1 at row ``i % size``, column ``i // size``, so
    convolving an image with it yields the image shifted by that offset of the
    ``size`` x ``size`` window.
    """
    count = size * size
    kernels = torch.zeros(count, 1, size, size)
    for i in range(count):
        kernels[i, 0, i % size, i // size] = 1.0
    return kernels


def _gaussian_kernel(size, sigma2):
    """Return a (size, size) float32 kernel with weights exp(-d**2 / sigma2).

    ``d`` is the Euclidean distance to the element at index (size - 1) // 2.
    """
    centre = (size - 1) // 2
    rows = [
        [math.exp(-((x - centre) ** 2 + (y - centre) ** 2) / sigma2) for y in range(size)]
        for x in range(size)
    ]
    return torch.tensor(rows, dtype=torch.float32)


class MIND(torch.nn.Module):
    """Modality independent neighbourhood descriptor for (N, 1, H, W) images.

    For every pixel ``x`` and every offset ``r`` of a square search window of
    ``non_local_region_size`` x ``non_local_region_size`` pixels this computes
    the weighted patch distance

        D(x, r) = sum_p w(p) * (I(x + r + p) - I(x + p)) ** 2

    over a ``patch_size`` x ``patch_size`` patch, divides it by a local
    variance estimate ``V(x)`` and returns ``exp(-D / V)`` normalised to sum to
    one over all offsets. The output has shape
    (N, non_local_region_size ** 2, H, W) for odd ``patch_size``.

    ``V(x)`` is the (unbiased) variance, across the ``neighbor_size`` x
    ``neighbor_size`` neighbour offsets, of the unweighted ``patch_size`` x
    ``patch_size`` box sums of the shifted image. Every convolution
    zero-pads, so pixels near the border see zeros outside the image.

    NOTE(review): this variance estimate and the sum-to-one normalisation are
    not the formulation used in the paper (mean patch distance over the direct
    neighbourhood, normalisation by the maximum); confirm before comparing
    numbers with other MIND implementations.

    All convolution weights are fixed (``requires_grad=False``); gradients
    still flow back to the input image.

    Args:
        non_local_region_size: side of the square search window; odd.
        patch_size: side of the square patch. Odd sizes keep the output at the
            input's spatial size; an even size shrinks it by one pixel.
        neighbor_size: side of the window used for the variance estimate;
            odd and >= 3.
        gaussian_patch_sigma: ``sigma`` of the patch weights
            ``w(p) = exp(-|p| ** 2 / sigma ** 2)``; must be non-zero.

    Raises:
        ValueError: if a size or ``gaussian_patch_sigma`` is invalid.
    """

    def __init__(self, non_local_region_size=9, patch_size=7, neighbor_size=3, gaussian_patch_sigma=3.0):
        super().__init__()
        # Even windows make the padded conv output one pixel smaller than the
        # input, which breaks the subtraction/broadcast in forward().
        _check_odd("non_local_region_size", non_local_region_size)
        # With a single neighbour the unbiased variance is NaN everywhere.
        _check_odd("neighbor_size", neighbor_size, minimum=3)
        if patch_size < 1:
            raise ValueError(f"patch_size must be >= 1, got {patch_size!r}")
        if gaussian_patch_sigma == 0:
            raise ValueError("gaussian_patch_sigma must be non-zero")

        self.nl_size = non_local_region_size
        self.p_size = patch_size
        self.n_size = neighbor_size
        self.sigma2 = gaussian_patch_sigma * gaussian_patch_sigma

        n_shifts = self.nl_size * self.nl_size
        n_neighbors = self.n_size * self.n_size

        # Channel i = the image shifted by the i-th offset of the search window.
        self.image_shifter = _fixed_conv(1, n_shifts, self.nl_size, groups=1)
        # Depthwise Gaussian-weighted patch sum of each squared-difference channel.
        self.summation_patcher = _fixed_conv(n_shifts, n_shifts, self.p_size, groups=n_shifts)
        # Channel i = the image shifted by the i-th neighbour offset.
        self.neighbors = _fixed_conv(1, n_neighbors, self.n_size, groups=1)
        # Depthwise unweighted box sum over each neighbour channel.
        self.neighbor_summation_patcher = _fixed_conv(
            n_neighbors, n_neighbors, self.p_size, groups=n_neighbors
        )

        with torch.no_grad():
            self.image_shifter.weight.copy_(_shift_kernels(self.nl_size))
            gaussian = _gaussian_kernel(self.p_size, self.sigma2)
            self.summation_patcher.weight.copy_(gaussian.expand(n_shifts, 1, -1, -1))
            self.neighbors.weight.copy_(_shift_kernels(self.n_size))
            self.neighbor_summation_patcher.weight.fill_(1.0)

        # These kernels are fixed operators, not learnable parameters.
        self.requires_grad_(False)

    def forward(self, orig):
        """Return the MIND descriptor of ``orig``.

        Args:
            orig: tensor of shape (N, 1, H, W).

        Returns:
            Tensor of shape (N, non_local_region_size ** 2, H', W') whose values
            sum to one along dim 1 (H' = H and W' = W for odd ``patch_size``).

        Raises:
            ValueError: if ``orig`` is not a 4-D single-channel tensor.
        """
        if orig.dim() != 4 or orig.shape[1] != 1:
            raise ValueError(f"expected input of shape (N, 1, H, W), got {tuple(orig.shape)}")

        # orig (N, 1, H, W) broadcasts against the nl*nl shifted channels.
        diff_images = self.image_shifter(orig) - orig
        dx_alpha = self.summation_patcher(diff_images.pow(2))

        neighbor_images = self.neighbor_summation_patcher(self.neighbors(orig))
        vx = neighbor_images.var(dim=1, keepdim=True)

        # The zero offset always has D == 0, so exp(...) == 1 and the
        # denominator is never zero.
        nume = torch.exp(-dx_alpha / (vx + 1e-8))
        return nume / nume.sum(dim=1, keepdim=True)


class MINDLoss(torch.nn.Module):
    """L1 distance between the MIND descriptors of two image batches.

    The absolute descriptor differences are averaged over each sample's
    descriptor elements (channels x height x width) and then SUMMED over the
    batch, so the loss grows linearly with batch size.

    Args are forwarded to ``MIND``.
    """

    def __init__(self, non_local_region_size=9, patch_size=7, neighbor_size=3, gaussian_patch_sigma=3.0):
        super().__init__()
        self.nl_size = non_local_region_size
        self.MIND = MIND(
            non_local_region_size=non_local_region_size,
            patch_size=patch_size,
            neighbor_size=neighbor_size,
            gaussian_patch_sigma=gaussian_patch_sigma,
        )

    def forward(self, input, target):
        """Return the scalar MIND L1 loss between ``input`` and ``target``.

        Both arguments are (N, 1, H, W) tensors of the same shape.
        """
        mind_diff = self.MIND(input) - self.MIND(target)
        l1 = mind_diff.abs().sum()
        # Per-sample element count of the descriptor; equals
        # H * W * nl_size ** 2 for odd patch sizes.
        per_sample = mind_diff.shape[1] * mind_diff.shape[2] * mind_diff.shape[3]
        return l1 / per_sample


if __name__ == "__main__":
    mind = MIND()
    orig = torch.ones(4, 1, 128, 128)
    print(mind(orig).shape)