├── README.md
└── ProtoSeg.py

/README.md:
--------------------------------------------------------------------------------
# Segmentation Ability Map (SAM) for interpreting segmentation neural networks

Sheng He, Yanfang Feng, P. Ellen Grant, Yangming Ou, "Segmentation ability map: Interpret deep features for medical image segmentation", Medical Image Analysis. [`PDF`](https://www.sciencedirect.com/science/article/pii/S1361841522003541)

# How to use it?

This is an example of a feature tensor, which can be extracted from any layer of your network.
If the size of the feature does not match your input size, please resize it to match the target size (see the sketch at the end of this section).

```Python
x = torch.rand(2, 64, 32, 32)
```

This is the probability map obtained from the output of your network, which serves as
a guide for ProtoSeg to compute the prototypes of the target lesion and non-lesion regions.

Note: this is not the ground truth (on the test set the ground truth is not available).
The values of pred_map should be in [0, 1], where 1 represents the target lesion.
If you use a softmax on the last layer, convert its output into a probability map in [0, 1], where 1 represents the target lesion.

```Python
pred_map = torch.rand(2, 1, 32, 32)
neters = ProtoSeg(ndims='2d')
probability_map = neters(x, pred_map, mask=None)
```

You will get a binary map (target lesion: 1, others: 0) based on the input features `x`:
```Python
binary_map = torch.argmax(probability_map, 1)  # Note: this is not differentiable.
```
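
If the feature map from an intermediate layer has a smaller spatial size than the prediction map, one way to resize it is interpolation. A minimal sketch, assuming bilinear interpolation is appropriate for your features; the shapes below are only illustrative:

```Python
import torch
import torch.nn.functional as F

feat = torch.rand(2, 64, 16, 16)     # hypothetical feature from an intermediate layer
pred_map = torch.rand(2, 1, 32, 32)  # prediction map at the input resolution

# Resize the feature to the spatial size of the prediction map before calling ProtoSeg.
feat = F.interpolate(feat, size=pred_map.shape[2:], mode='bilinear', align_corners=False)
```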
--------------------------------------------------------------------------------
/ProtoSeg.py:
--------------------------------------------------------------------------------
#---------------------
#
# This is the code for our ProtoSeg paper:
# Segmentation Ability Map: Interpret deep features for medical image segmentation, Medical Image Analysis
# https://www.sciencedirect.com/science/article/pii/S1361841522003541
#
# @Author: Sheng He
# @Email: heshengxgd@gmail.com
#
#--------------------------

import torch
import torch.nn as nn

class ProtoSeg(nn.Module):
    def __init__(self, ndims='2d'):
        super().__init__()

        # spatial dimensions to average over:
        # 1D signal: (2,), 2D image: (2, 3), 3D image: (2, 3, 4)
        if ndims == '1d':
            self.dims = (2,)
        elif ndims == '2d':
            self.dims = (2, 3)
        elif ndims == '3d':
            self.dims = (2, 3, 4)
        else:
            raise ValueError('ndims must be one of [1d, 2d, 3d]')

        self.softmax = nn.Softmax(dim=1)

    def forward(self, xfeat, pred, mask=None):
        #@ xfeat: the deep features to be interpreted
        #@ pred: the initial segmentation result from the last layer of the network
        #@ mask: optional mask that excludes the image background (regions without any tissue)

        if mask is not None:
            # class prototypes: masked, prediction-weighted averages of the features
            pos_prototype = torch.sum(xfeat * pred * mask, dim=self.dims, keepdim=True)
            num_prototype = torch.sum(pred * mask, dim=self.dims, keepdim=True)
            pos_prototype = pos_prototype / num_prototype

            rpred = 1 - pred

            neg_prototype = torch.sum(xfeat * rpred * mask, dim=self.dims, keepdim=True)
            num_prototype = torch.sum(rpred * mask, dim=self.dims, keepdim=True)
            neg_prototype = neg_prototype / num_prototype

            # negative squared distance of each voxel's feature to the two prototypes
            pfeat = -torch.pow(xfeat - pos_prototype, 2).sum(1, keepdim=True)
            nfeat = -torch.pow(xfeat - neg_prototype, 2).sum(1, keepdim=True)

            disfeat = torch.cat([nfeat, pfeat], dim=1)
            pred = self.softmax(disfeat)
        else:
            pos_prototype = torch.sum(xfeat * pred, dim=self.dims, keepdim=True)
            num_prototype = torch.sum(pred, dim=self.dims, keepdim=True)
            pos_prototype = pos_prototype / num_prototype

            rpred = 1 - pred

            neg_prototype = torch.sum(xfeat * rpred, dim=self.dims, keepdim=True)
            num_prototype = torch.sum(rpred, dim=self.dims, keepdim=True)
            neg_prototype = neg_prototype / num_prototype

            pfeat = -torch.pow(xfeat - pos_prototype, 2).sum(1, keepdim=True)
            nfeat = -torch.pow(xfeat - neg_prototype, 2).sum(1, keepdim=True)

            disfeat = torch.cat([nfeat, pfeat], dim=1)
            pred = self.softmax(disfeat)

        return pred

if __name__ == '__main__':
    # Example of how to use it
    #----------------------------------------
    # This is an example of a feature tensor, which can be extracted from any layer
    # of your network. If the size of the feature does not match your input size,
    # please resize it.

    x = torch.rand(2, 64, 32, 32)

    # This is the probability map obtained from the output of your network, which serves as
    # a guide for ProtoSeg to compute the prototypes of the target lesion and non-lesion regions.
    # Note: this is not the ground truth (on the test set the ground truth is not available).
    # The values of pred_map should be in [0, 1], where 1 represents the target lesion.
    # If you use a softmax on the last layer, convert its output into a probability map
    # in [0, 1], where 1 represents the target lesion.

    pred_map = torch.rand(2, 1, 32, 32)

    neters = ProtoSeg(ndims='2d')

    probability_map = neters(x, pred_map, mask=None)

    # You will get a binary map (target lesion: 1, others: 0) based on the input features "x".
    binary_map = torch.argmax(probability_map, 1)  # Note: this is not differentiable.
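
    # A further sketch with a tissue mask and a 3D volume. The tensors below are
    # illustrative assumptions: features of shape (batch, channel, D, H, W) and a
    # binary mask that is 1 inside tissue and 0 in the background.
    x3d = torch.rand(2, 64, 16, 32, 32)
    pred3d = torch.rand(2, 1, 16, 32, 32)
    mask3d = (torch.rand(2, 1, 16, 32, 32) > 0.2).float()

    neters3d = ProtoSeg(ndims='3d')
    probability_map3d = neters3d(x3d, pred3d, mask=mask3d)
    binary_map3d = torch.argmax(probability_map3d, 1)  # shape (2, 16, 32, 32); lesion: 1, others: 0

--------------------------------------------------------------------------------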