├── CGRmodes ├── CGR.py ├── __pycache__ │ ├── CGR.cpython-36.pyc │ ├── EPG.cpython-35.pyc │ ├── EPG.cpython-36.pyc │ ├── EPG2.cpython-36.pyc │ ├── EPGbaseline.cpython-35.pyc │ ├── layers.cpython-36.pyc │ └── utils.cpython-36.pyc ├── layers.py └── utils.py ├── README.md ├── __pycache__ ├── evaluation.cpython-36.pyc └── transform.cpython-36.pyc ├── data ├── __pycache__ │ ├── dataloader.cpython-36.pyc │ └── dataloader2.cpython-36.pyc ├── dataloader.py ├── dataloader2.py └── utils │ ├── comm.py │ ├── dice3D.py │ ├── direct_field │ ├── __pycache__ │ │ ├── df_cardia.cpython-36.pyc │ │ ├── df_cardia.cpython-37.pyc │ │ ├── utils_df.cpython-36.pyc │ │ └── utils_df.cpython-37.pyc │ ├── df_cardia.py │ └── utils_df.py │ ├── image_list.py │ ├── init_net.py │ ├── metrics.py │ ├── utils_loss.py │ └── vis_utils.py ├── evaluation.py ├── fig ├── 1.png ├── 2.png └── 3.png ├── models ├── ._models.py ├── AttU_Net_model.py ├── BaseNet.py ├── F3net.py ├── GSConv.py ├── InfNet_Res2Net.py ├── LDF.py ├── LDunet.py ├── PraNet_Res2Net.py ├── PraNet_ResNet.py ├── R2U_Net_model.py ├── Res2Net_v1b.py ├── UNet_2Plus.py ├── __pycache__ │ ├── AttU_Net_model.cpython-35.pyc │ ├── BaseNet.cpython-35.pyc │ ├── BaseNet.cpython-36.pyc │ ├── F3net.cpython-35.pyc │ ├── GSConv.cpython-35.pyc │ ├── GSConv.cpython-36.pyc │ ├── InfNet_Res2Net.cpython-35.pyc │ ├── InfNet_Res2Net.cpython-36.pyc │ ├── LDF.cpython-35.pyc │ ├── LDunet.cpython-35.pyc │ ├── PraNet_Res2Net.cpython-35.pyc │ ├── R2U_Net_model.cpython-35.pyc │ ├── Res2Net_v1b.cpython-35.pyc │ ├── UNet_2Plus.cpython-35.pyc │ ├── UNet_2Plus.cpython-36.pyc │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── attention_blocks.cpython-35.pyc │ ├── attention_blocks.cpython-36.pyc │ ├── cenet.cpython-35.pyc │ ├── cenet.cpython-36.pyc │ ├── custom_functions.cpython-35.pyc │ ├── custom_functions.cpython-36.pyc │ ├── deeplab_v3p.cpython-35.pyc │ ├── denseunet_model.cpython-35.pyc │ ├── fcn.cpython-35.pyc │ ├── fcn.cpython-36.pyc │ ├── 
init_weights.cpython-35.pyc │ ├── init_weights.cpython-36.pyc │ ├── layers.cpython-35.pyc │ ├── layers.cpython-36.pyc │ ├── models.cpython-35.pyc │ ├── models.cpython-36.pyc │ ├── multi_scale.cpython-35.pyc │ ├── multi_scale.cpython-36.pyc │ ├── multi_scale_module.cpython-35.pyc │ ├── newnet.cpython-35.pyc │ ├── norm.cpython-35.pyc │ ├── norm.cpython-36.pyc │ ├── resnet.cpython-35.pyc │ ├── resnet.cpython-36.pyc │ ├── unet.cpython-35.pyc │ ├── unet.cpython-36.pyc │ ├── vggunet.cpython-35.pyc │ ├── wassp.cpython-35.pyc │ ├── wassp.cpython-36.pyc │ └── wnet.cpython-35.pyc ├── adaptive_avgmax_pool.py ├── attention_blocks.py ├── backbone │ ├── DenseNet.py │ ├── Res2Net.py │ ├── ResNet.py │ ├── VGGNet.py │ ├── __init__.py │ └── __pycache__ │ │ ├── Res2Net.cpython-35.pyc │ │ ├── Res2Net.cpython-36.pyc │ │ ├── __init__.cpython-35.pyc │ │ └── __init__.cpython-36.pyc ├── cenet.py ├── custom_functions.py ├── deeplab │ ├── LDF.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── aspp.cpython-35.pyc │ │ ├── aspp.cpython-36.pyc │ │ ├── decoder.cpython-35.pyc │ │ └── decoder.cpython-36.pyc │ ├── aspp.py │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── drn.cpython-35.pyc │ │ │ ├── drn.cpython-36.pyc │ │ │ ├── mobilenet.cpython-35.pyc │ │ │ ├── mobilenet.cpython-36.pyc │ │ │ ├── resnet.cpython-35.pyc │ │ │ ├── resnet.cpython-36.pyc │ │ │ ├── xception.cpython-35.pyc │ │ │ └── xception.cpython-36.pyc │ │ ├── drn.py │ │ ├── mobilenet.py │ │ ├── resnet.py │ │ └── xception.py │ ├── decoder.py │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── batchnorm.cpython-35.pyc │ │ ├── batchnorm.cpython-36.pyc │ │ ├── comm.cpython-35.pyc │ │ ├── comm.cpython-36.pyc │ │ ├── replicate.cpython-35.pyc │ │ └── replicate.cpython-36.pyc │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py ├── deeplab_v3p.py 
├── denseunet_model.py ├── fcn.py ├── init_weights.py ├── layers.py ├── models.py ├── multi_scale.py ├── multi_scale_module.py ├── mynn.py ├── net.py ├── new2net.py ├── newnet.py ├── norm.py ├── pretrain │ ├── SAMNet_with_ImageNet_pretrain.pth │ └── __init__.py ├── resnet.py ├── segnet.py ├── test.py ├── unet.py ├── vggunet.py ├── wassp.py └── wnet.py ├── test_XS2021.py ├── transform.py └── trian_CGR_XS.py /CGRmodes/__pycache__/CGR.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/CGR.cpython-36.pyc -------------------------------------------------------------------------------- /CGRmodes/__pycache__/EPG.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/EPG.cpython-35.pyc -------------------------------------------------------------------------------- /CGRmodes/__pycache__/EPG.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/EPG.cpython-36.pyc -------------------------------------------------------------------------------- /CGRmodes/__pycache__/EPG2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/EPG2.cpython-36.pyc -------------------------------------------------------------------------------- /CGRmodes/__pycache__/EPGbaseline.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/EPGbaseline.cpython-35.pyc -------------------------------------------------------------------------------- /CGRmodes/__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /CGRmodes/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /CGRmodes/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class conv(nn.Module): 5 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, padding='same', bias=False, bn=True, relu=False): 6 | super(conv, self).__init__() 7 | if '__iter__' not in dir(kernel_size): 8 | kernel_size = (kernel_size, kernel_size) 9 | if '__iter__' not in dir(stride): 10 | stride = (stride, stride) 11 | if '__iter__' not in dir(dilation): 12 | dilation = (dilation, dilation) 13 | 14 | if padding == 'same': 15 | width_pad_size = kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1) 16 | height_pad_size = kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1) 17 | elif padding == 'valid': 18 | width_pad_size = 0 19 | height_pad_size = 0 20 | else: 21 | if '__iter__' in dir(padding): 22 | width_pad_size = padding[0] * 2 23 | height_pad_size = padding[1] * 2 24 | else: 25 | width_pad_size = padding * 2 26 | height_pad_size = padding * 2 27 | 28 | width_pad_size = width_pad_size // 2 + 
(width_pad_size % 2 - 1) 29 | height_pad_size = height_pad_size // 2 + (height_pad_size % 2 - 1) 30 | pad_size = (width_pad_size, height_pad_size) 31 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_size, dilation, groups, bias=bias) 32 | self.reset_parameters() 33 | 34 | if bn is True: 35 | self.bn = nn.BatchNorm2d(out_channels) 36 | else: 37 | self.bn = None 38 | 39 | if relu is True: 40 | self.relu = nn.ReLU(inplace=True) 41 | else: 42 | self.relu = None 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | if self.bn is not None: 47 | x = self.bn(x) 48 | if self.relu is not None: 49 | x = self.relu(x) 50 | return x 51 | 52 | def reset_parameters(self): 53 | nn.init.kaiming_normal_(self.conv.weight) 54 | 55 | 56 | class self_attn(nn.Module): 57 | def __init__(self, in_channels, mode='hw'): 58 | super(self_attn, self).__init__() 59 | 60 | self.mode = mode 61 | 62 | self.query_conv = conv(in_channels, in_channels // 8, kernel_size=(1, 1)) 63 | self.key_conv = conv(in_channels, in_channels // 8, kernel_size=(1, 1)) 64 | self.value_conv = conv(in_channels, in_channels, kernel_size=(1, 1)) 65 | 66 | self.gamma = nn.Parameter(torch.zeros(1)) 67 | self.softmax = nn.Softmax(dim=-1) 68 | 69 | def forward(self, x): 70 | batch_size, channel, height, width = x.size() 71 | 72 | axis = 1 73 | if 'h' in self.mode: 74 | axis *= height 75 | if 'w' in self.mode: 76 | axis *= width 77 | 78 | view = (batch_size, -1, axis) 79 | 80 | projected_query = self.query_conv(x).view(*view).permute(0, 2, 1) 81 | projected_key = self.key_conv(x).view(*view) 82 | 83 | attention_map = torch.bmm(projected_query, projected_key) 84 | attention = self.softmax(attention_map) 85 | projected_value = self.value_conv(x).view(*view) 86 | 87 | out = torch.bmm(projected_value, attention.permute(0, 2, 1)) 88 | out = out.view(batch_size, channel, height, width) 89 | 90 | out = self.gamma * out + x 91 | return out 
-------------------------------------------------------------------------------- /CGRmodes/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from thop import profile 4 | from thop import clever_format 5 | from scipy.ndimage import map_coordinates 6 | 7 | from torch.optim.lr_scheduler import _LRScheduler 8 | 9 | 10 | class PolyLr(_LRScheduler): 11 | def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1): 12 | self.gamma = gamma 13 | self.max_iteration = max_iteration 14 | self.minimum_lr = minimum_lr 15 | self.warmup_iteration = warmup_iteration 16 | 17 | super(PolyLr, self).__init__(optimizer, last_epoch) 18 | 19 | def poly_lr(self, base_lr, step): 20 | return (base_lr - self.minimum_lr) * ((1 - (step / self.max_iteration)) ** self.gamma) + self.minimum_lr 21 | 22 | def warmup_lr(self, base_lr, alpha): 23 | return base_lr * (1 / 10.0 * (1 - alpha) + alpha) 24 | 25 | def get_lr(self): 26 | if self.last_epoch < self.warmup_iteration: 27 | alpha = self.last_epoch / self.warmup_iteration 28 | lrs = [min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, self.last_epoch)) for base_lr in 29 | self.base_lrs] 30 | else: 31 | lrs = [self.poly_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] 32 | 33 | return lrs 34 | 35 | def clip_gradient(optimizer, grad_clip): 36 | for group in optimizer.param_groups: 37 | for param in group['params']: 38 | if param.grad is not None: 39 | param.grad.data.clamp_(-grad_clip, grad_clip) 40 | 41 | 42 | class AvgMeter(object): 43 | def __init__(self, num=40): 44 | self.num = num 45 | self.reset() 46 | 47 | def reset(self): 48 | self.val = 0 49 | self.avg = 0 50 | self.sum = 0 51 | self.count = 0 52 | self.losses = [] 53 | 54 | def update(self, val, n=1): 55 | self.val = val 56 | self.sum += val * n 57 | self.count += n 58 | self.avg = self.sum / self.count 59 | self.losses.append(val) 60 | 61 | def 
show(self): 62 | return torch.mean(torch.stack(self.losses[np.maximum(len(self.losses)-self.num, 0):])) 63 | 64 | 65 | def CalParams(model, input_tensor): 66 | flops, params = profile(model, inputs=(input_tensor,)) 67 | flops, params = clever_format([flops, params], "%.3f") 68 | print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(flops, params)) 69 | 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CGRNet: Contour-Guided Graph Reasoning Network for Ambiguous Biomedical Image Segmentation 2 | [![MIT License](https://img.shields.io/badge/license-MIT-green.svg)](https://opensource.org/licenses/MIT) [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](http://makeapullrequest.com) 3 | 4 | > **Authors:** 5 | > Kun Wang, 6 | > Xiaohong Zhang, 7 | > Yuting Lu, 8 | > Xiangbo Zhang, 9 | > Wei Zhang. 10 | 11 | 12 | 13 | ### 1.1. 🔥NEWS🔥 : 14 | - [2021/10/30]:fire: Release the inference code! 15 | - [2021/10/28] Create repository. 16 | 17 | 18 | ## Prerequisites 19 | - [Python 3.5](https://www.python.org/) 20 | - [Pytorch 1.1](http://pytorch.org/) 21 | - [OpenCV 4.0](https://opencv.org/) 22 | - [Numpy 1.15](https://numpy.org/) 23 | - [TensorboardX](https://github.com/lanpa/tensorboardX) 24 | 25 | ## Clone repository 26 | ```shell 27 | git clone https://github.com/DLWK/CGRNet.git 28 | cd CGRNet/ 29 | ``` 30 | ## Download dataset 31 | Download the datasets and unzip them into `data` folder 32 | - [COVID-19](https://medicalsegmentation.com/covid19/) 33 | - Download dataset from following [URL](https://drive.google.com/file/d/17Cs2JhKOKwt4usiAYJVJMnXfyZWySn3s/view?usp=sharing) 34 | - You can use our data/dataloader2.py to load the datasets. 
35 | ## Training & Evaluation 36 | ```shell 37 | cd CGRNet/ 38 | python3 trian_CGR_XS.py 39 | ################ 40 | python3 test_XS2021.py 41 | ``` 42 | 43 | ## Demo 44 | ```python 45 | import torch 46 | from CGRmodes.CGR import CGRNet 47 | if __name__ == '__main__': 48 | ras = CGRNet(n_channels=3, n_classes=1).cuda() 49 | input_tensor = torch.randn(4, 3, 352, 352).cuda() 50 | out, out1 = ras(input_tensor) 51 | print(out.shape) 52 | ``` 53 | ### 2.1 Framework overview 54 |

55 |
56 | 57 | 58 |

59 | 60 | ### 2.2 Visualization Results 61 |

62 |
63 | 64 | 65 |

66 | 67 |

68 |
69 | 70 | 71 |

72 | 73 | 74 | ## Citation 75 | - If you find this work helpful, please cite our paper: 76 | ```bibtex 77 | @article{wang2022cgrnet, 78 | title={CGRNet: Contour-guided graph reasoning network for ambiguous biomedical image segmentation}, 79 | author={Wang, Kun and Zhang, Xiaohong and Lu, Yuting and Zhang, Xiangbo and Zhang, Wei}, 80 | journal={Biomedical Signal Processing and Control}, 81 | volume={75}, 82 | pages={103621}, 83 | year={2022}, 84 | publisher={Elsevier} 85 | } 86 | ``` 87 | 88 | 89 | 90 | ## Tips 91 | :fire: If you have any questions about our work, please do not hesitate to contact us by email. 92 | **[⬆ back to top](#cgrnet-contour-guided-graph-reasoning-network-for-ambiguous-biomedical-image-segmentation)** 93 | -------------------------------------------------------------------------------- /__pycache__/evaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/__pycache__/evaluation.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/transform.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/__pycache__/transform.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader2.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/__pycache__/dataloader2.cpython-36.pyc
-------------------------------------------------------------------------------- /data/utils/comm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import pickle 4 | 5 | def get_world_size(): 6 | if not dist.is_available(): 7 | return 1 8 | if not dist.is_initialized(): 9 | return 1 10 | return dist.get_world_size() 11 | 12 | def get_rank(): 13 | if not dist.is_available(): 14 | return 0 15 | if not dist.is_initialized(): 16 | return 0 17 | return dist.get_rank() 18 | 19 | def synchronize(): 20 | """ 21 | Helper function to synchronize (barrier) among all processes when 22 | using distributed training 23 | """ 24 | if not dist.is_available(): 25 | return 26 | if not dist.is_initialized(): 27 | return 28 | world_size = dist.get_world_size() 29 | if world_size == 1: 30 | return 31 | dist.barrier() 32 | 33 | def all_gather(data): 34 | """ 35 | Run all_gather on arbitrary picklable data (not necessarily tensors) 36 | Args: 37 | data: any picklable object 38 | Returns: 39 | list[data]: list of data gathered from each rank 40 | """ 41 | world_size = get_world_size() 42 | if world_size == 1: 43 | return [data] 44 | 45 | # serialized to a Tensor 46 | buffer = pickle.dumps(data) 47 | storage = torch.ByteStorage.from_buffer(buffer) 48 | tensor = torch.ByteTensor(storage).to("cuda") 49 | 50 | # obtain Tensor size of each rank 51 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 52 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 53 | dist.all_gather(size_list, local_size) 54 | size_list = [int(size.item()) for size in size_list] 55 | max_size = max(size_list) 56 | 57 | # receiving Tensor from all ranks 58 | # we pad the tensor because torch all_gather does not support 59 | # gathering tensors of different shapes 60 | tensor_list = [] 61 | for _ in size_list: 62 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 63 | if local_size != 
max_size: 64 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 65 | tensor = torch.cat((tensor, padding), dim=0) 66 | dist.all_gather(tensor_list, tensor) 67 | 68 | data_list = [] 69 | for size, tensor in zip(size_list, tensor_list): 70 | buffer = tensor.cpu().numpy().tobytes()[:size] 71 | data_list.append(pickle.loads(buffer)) 72 | 73 | return data_list -------------------------------------------------------------------------------- /data/utils/dice3D.py: -------------------------------------------------------------------------------- 1 | """ 2 | author: Clément Zotti (clement.zotti@usherbrooke.ca) 3 | date: April 2017 4 | 5 | DESCRIPTION : 6 | The script provide helpers functions to handle nifti image format: 7 | - load_nii() 8 | - save_nii() 9 | 10 | to generate metrics for two images: 11 | - metrics() 12 | 13 | And it is callable from the command line (see below). 14 | Each function provided in this script has comments to understand 15 | how they works. 16 | 17 | HOW-TO: 18 | 19 | This script was tested for python 3.4. 20 | 21 | First, you need to install the required packages with 22 | pip install -r requirements.txt 23 | 24 | After the installation, you have two ways of running this script: 25 | 1) python metrics.py ground_truth/patient001_ED.nii.gz prediction/patient001_ED.nii.gz 26 | 2) python metrics.py ground_truth/ prediction/ 27 | 28 | The first option will print in the console the dice and volume of each class for the given image. 29 | The second option wiil ouput a csv file where each images will have the dice and volume of each class. 
30 | 31 | 32 | Link: http://acdc.creatis.insa-lyon.fr 33 | 34 | """ 35 | 36 | import os 37 | from glob import glob 38 | import time 39 | import re 40 | import argparse 41 | import nibabel as nib 42 | # import pandas as pd 43 | from medpy.metric.binary import hd, dc 44 | import numpy as np 45 | 46 | 47 | 48 | HEADER = ["Name", "Dice LV", "Volume LV", "Err LV(ml)", 49 | "Dice RV", "Volume RV", "Err RV(ml)", 50 | "Dice MYO", "Volume MYO", "Err MYO(ml)"] 51 | 52 | # 53 | # Utils functions used to sort strings into a natural order 54 | # 55 | def conv_int(i): 56 | return int(i) if i.isdigit() else i 57 | 58 | 59 | def natural_order(sord): 60 | """ 61 | Sort a (list,tuple) of strings into natural order. 62 | 63 | Ex: 64 | 65 | ['1','10','2'] -> ['1','2','10'] 66 | 67 | ['abc1def','ab10d','b2c','ab1d'] -> ['ab1d','ab10d', 'abc1def', 'b2c'] 68 | 69 | """ 70 | if isinstance(sord, tuple): 71 | sord = sord[0] 72 | return [conv_int(c) for c in re.split(r'(\d+)', sord)] 73 | 74 | 75 | # 76 | # Utils function to load and save nifti files with the nibabel package 77 | # 78 | def load_nii(img_path): 79 | """ 80 | Function to load a 'nii' or 'nii.gz' file, The function returns 81 | everyting needed to save another 'nii' or 'nii.gz' 82 | in the same dimensional space, i.e. the affine matrix and the header 83 | 84 | Parameters 85 | ---------- 86 | 87 | img_path: string 88 | String with the path of the 'nii' or 'nii.gz' image file name. 89 | 90 | Returns 91 | ------- 92 | Three element, the first is a numpy array of the image values, 93 | the second is the affine transformation of the image, and the 94 | last one is the header of the image. 95 | """ 96 | nimg = nib.load(img_path) 97 | return nimg.get_data(), nimg.affine, nimg.header 98 | 99 | 100 | def save_nii(img_path, data, affine, header): 101 | """ 102 | Function to save a 'nii' or 'nii.gz' file. 
103 | 104 | Parameters 105 | ---------- 106 | 107 | img_path: string 108 | Path to save the image should be ending with '.nii' or '.nii.gz'. 109 | 110 | data: np.array 111 | Numpy array of the image data. 112 | 113 | affine: list of list or np.array 114 | The affine transformation to save with the image. 115 | 116 | header: nib.Nifti1Header 117 | The header that define everything about the data 118 | (pleasecheck nibabel documentation). 119 | """ 120 | nimg = nib.Nifti1Image(data, affine=affine, header=header) 121 | nimg.to_filename(img_path) 122 | 123 | 124 | # 125 | # Functions to process files, directories and metrics 126 | # 127 | def metrics(img_gt, img_pred, voxel_size): 128 | """ 129 | Function to compute the metrics between two segmentation maps given as input. 130 | 131 | Parameters 132 | ---------- 133 | img_gt: np.array 134 | Array of the ground truth segmentation map. 135 | 136 | img_pred: np.array 137 | Array of the predicted segmentation map. 138 | 139 | voxel_size: list, tuple or np.array 140 | The size of a voxel of the images used to compute the volumes. 
141 | 142 | Return 143 | ------ 144 | A list of metrics in this order, [Dice LV, Volume LV, Err LV(ml), 145 | Dice RV, Volume RV, Err RV(ml), Dice MYO, Volume MYO, Err MYO(ml)] 146 | """ 147 | 148 | if img_gt.ndim != img_pred.ndim: 149 | raise ValueError("The arrays 'img_gt' and 'img_pred' should have the " 150 | "same dimension, {} against {}".format(img_gt.ndim, 151 | img_pred.ndim)) 152 | 153 | res = [] 154 | # Loop on each classes of the input images 155 | for c in [3, 1, 2]: 156 | # Copy the gt image to not alterate the input 157 | gt_c_i = np.copy(img_gt) 158 | gt_c_i[gt_c_i != c] = 0 159 | 160 | # Copy the pred image to not alterate the input 161 | pred_c_i = np.copy(img_pred) 162 | pred_c_i[pred_c_i != c] = 0 163 | 164 | # Clip the value to compute the volumes 165 | gt_c_i = np.clip(gt_c_i, 0, 1) 166 | pred_c_i = np.clip(pred_c_i, 0, 1) 167 | 168 | # Compute the Dice 169 | dice = dc(gt_c_i, pred_c_i) 170 | 171 | # Compute volume 172 | volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000. 173 | volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000. 174 | 175 | # res += [dice, volpred, volpred-volgt] 176 | res += [dice] 177 | 178 | return res 179 | 180 | 181 | def compute_metrics_on_files(path_gt, path_pred): 182 | """ 183 | Function to give the metrics for two files 184 | 185 | Parameters 186 | ---------- 187 | 188 | path_gt: string 189 | Path of the ground truth image. 190 | 191 | path_pred: string 192 | Path of the predicted image. 
193 | """ 194 | gt, _, header = load_nii(path_gt) 195 | pred, _, _ = load_nii(path_pred) 196 | zooms = header.get_zooms() 197 | 198 | name = os.path.basename(path_gt) 199 | name = name.split('.')[0] 200 | res = metrics(gt, pred, zooms) 201 | res = ["{:.3f}".format(r) for r in res] 202 | 203 | formatting = "{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}" 204 | print(formatting.format(*HEADER)) 205 | print(formatting.format(name, *res)) 206 | 207 | 208 | def compute_metrics_on_directories(dir_gt, dir_pred): 209 | """ 210 | Function to generate a csv file for each images of two directories. 211 | 212 | Parameters 213 | ---------- 214 | 215 | path_gt: string 216 | Directory of the ground truth segmentation maps. 217 | 218 | path_pred: string 219 | Directory of the predicted segmentation maps. 220 | """ 221 | lst_gt = sorted(glob(os.path.join(dir_gt, '*')), key=natural_order) 222 | lst_pred = sorted(glob(os.path.join(dir_pred, '*')), key=natural_order) 223 | 224 | res = [] 225 | for p_gt, p_pred in zip(lst_gt, lst_pred): 226 | if os.path.basename(p_gt) != os.path.basename(p_pred): 227 | raise ValueError("The two files don't have the same name" 228 | " {}, {}.".format(os.path.basename(p_gt), 229 | os.path.basename(p_pred))) 230 | 231 | gt, _, header = load_nii(p_gt) 232 | pred, _, _ = load_nii(p_pred) 233 | zooms = header.get_zooms() 234 | res.append(metrics(gt, pred, zooms)) 235 | 236 | lst_name_gt = [os.path.basename(gt).split(".")[0] for gt in lst_gt] 237 | res = [[n,] + r for r, n in zip(res, lst_name_gt)] 238 | df = pd.DataFrame(res, columns=HEADER) 239 | df.to_csv("results_{}.csv".format(time.strftime("%Y%m%d_%H%M%S")), index=False) 240 | 241 | def main(path_gt, path_pred): 242 | """ 243 | Main function to select which method to apply on the input parameters. 
244 | """ 245 | if os.path.isfile(path_gt) and os.path.isfile(path_pred): 246 | compute_metrics_on_files(path_gt, path_pred) 247 | elif os.path.isdir(path_gt) and os.path.isdir(path_pred): 248 | compute_metrics_on_directories(path_gt, path_pred) 249 | else: 250 | raise ValueError( 251 | "The paths given needs to be two directories or two files.") 252 | 253 | 254 | if __name__ == "__main__": 255 | # parser = argparse.ArgumentParser( 256 | # description="Script to compute ACDC challenge metrics.") 257 | # parser.add_argument("GT_IMG", type=str, help="Ground Truth image") 258 | # parser.add_argument("PRED_IMG", type=str, help="Predicted image") 259 | # args = parser.parse_args() 260 | # main(args.GT_IMG, args.PRED_IMG) 261 | 262 | ############################################################## 263 | gt = np.random.randint(0, 4, size=(224, 224, 100)) 264 | print(np.unique(gt)) 265 | pred = np.array(gt) 266 | pred[pred==2] = 3 267 | result = metrics(gt, pred, voxel_size=(224, 224, 100)) 268 | print(result) -------------------------------------------------------------------------------- /data/utils/direct_field/__pycache__/df_cardia.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/utils/direct_field/__pycache__/df_cardia.cpython-36.pyc -------------------------------------------------------------------------------- /data/utils/direct_field/__pycache__/df_cardia.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/utils/direct_field/__pycache__/df_cardia.cpython-37.pyc -------------------------------------------------------------------------------- /data/utils/direct_field/__pycache__/utils_df.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/utils/direct_field/__pycache__/utils_df.cpython-36.pyc -------------------------------------------------------------------------------- /data/utils/direct_field/__pycache__/utils_df.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/utils/direct_field/__pycache__/utils_df.cpython-37.pyc -------------------------------------------------------------------------------- /data/utils/direct_field/df_cardia.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | import math 4 | import cv2 5 | from PIL import Image 6 | 7 | def direct_field(a, norm=True): 8 | """ a: np.ndarray, (h, w) 9 | """ 10 | if a.ndim == 3: 11 | a = np.squeeze(a) 12 | 13 | h, w = a.shape 14 | 15 | a_Image = Image.fromarray(a) 16 | a = a_Image.resize((w, h), Image.NEAREST) 17 | a = np.array(a) 18 | 19 | accumulation = np.zeros((2, h, w), dtype=np.float32) 20 | for i in np.unique(a)[1:]: 21 | # b, ind = ndimage.distance_transform_edt(a==i, return_indices=True) 22 | # c = np.indices((h, w)) 23 | # diff = c - ind 24 | # dr = np.sqrt(np.sum(diff ** 2, axis=0)) 25 | 26 | img = (a == i).astype(np.uint8) 27 | dst, labels = cv2.distanceTransformWithLabels(img, cv2.DIST_L2, cv2.DIST_MASK_PRECISE, labelType=cv2.DIST_LABEL_PIXEL) 28 | index = np.copy(labels) 29 | index[img > 0] = 0 30 | place = np.argwhere(index > 0) 31 | nearCord = place[labels-1,:] 32 | x = nearCord[:, :, 0] 33 | y = nearCord[:, :, 1] 34 | nearPixel = np.zeros((2, h, w)) 35 | nearPixel[0,:,:] = x 36 | nearPixel[1,:,:] = y 37 | grid = np.indices(img.shape) 38 | grid = grid.astype(float) 39 | diff = grid - nearPixel 40 | if norm: 41 | dr = np.sqrt(np.sum(diff**2, axis = 0)) 42 | else: 43 | dr = np.ones_like(img) 44 | 45 | # direction = 
np.zeros((2, h, w), dtype=np.float32) 46 | # direction[0, b>0] = np.divide(diff[0, b>0], dr[b>0]) 47 | # direction[1, b>0] = np.divide(diff[1, b>0], dr[b>0]) 48 | 49 | direction = np.zeros((2, h, w), dtype=np.float32) 50 | direction[0, img>0] = np.divide(diff[0, img>0], dr[img>0]) 51 | direction[1, img>0] = np.divide(diff[1, img>0], dr[img>0]) 52 | 53 | accumulation[:, img>0] = 0 54 | accumulation = accumulation + direction 55 | 56 | # mag, angle = cv2.cartToPolar(accumulation[0, ...], accumulation[1, ...]) 57 | # for l in np.unique(a)[1:]: 58 | # mag_i = mag[a==l].astype(float) 59 | # t = 1 / mag_i * mag_i.max() 60 | # mag[a==l] = t 61 | # x, y = cv2.polarToCart(mag, angle) 62 | # accumulation = np.stack([x, y], axis=0) 63 | 64 | return accumulation 65 | 66 | 67 | if __name__ == "__main__": 68 | import matplotlib.pyplot as plt 69 | # gt_p = "/home/ffbian/chencheng/XieheCardiac/npydata/dianfen/16100000/gts/16100000_CINE_segmented_SAX_b3.npy" 70 | # gt = np.load(gt_p)[..., 9] # uint8 71 | # print(gt.shape) 72 | 73 | # a_Image = Image.fromarray(gt) 74 | # a = a_Image.resize((224, 224), Image.NEAREST) 75 | # a = np.array(a) # uint8 76 | # print(a.shape, np.unique(a)) 77 | 78 | # # plt.imshow(a) 79 | # # plt.show() 80 | 81 | # ############################################################ 82 | # direction = direct_field(gt) 83 | 84 | # theta = np.arctan2(direction[1,...], direction[0,...]) 85 | # degree = theta * 180 / math.pi 86 | # degree = (degree + 180) / 360 87 | 88 | # plt.imshow(degree) 89 | # plt.show() 90 | 91 | ######################################################## 92 | import json, time, pdb, h5py 93 | data_list = "/home/ffbian/chencheng/XieheCardiac/2DUNet/UNet/libs/datasets/train_new.json" 94 | data_list = "/root/chengfeng/Cardiac/source_code/libs/datasets/jsonLists/acdcList/Dense_TestList.json" 95 | with open(data_list, 'r') as f: 96 | data_infos = json.load(f) 97 | 98 | mag_stat = [] 99 | st = time.time() 100 | for i, di in enumerate(data_infos): 101 | # 
img_p, times_idx = di 102 | # gt_p = img_p.replace("/imgs/", "/gts/") 103 | # gt = np.load(gt_p)[..., times_idx] 104 | 105 | img = h5py.File(di,'r')['image'] 106 | gt = h5py.File(di,'r')['label'] 107 | gt = np.array(gt).astype(np.float32) 108 | 109 | print(gt.shape) 110 | direction = direct_field(gt, False) 111 | # theta = np.arctan2(direction[1,...], direction[0,...]) 112 | mag, angle = cv2.cartToPolar(direction[0, ...], direction[1, ...]) 113 | # degree = theta * 180 / math.pi 114 | # degree = (degree + 180) / 360 115 | degree = angle / (2 * math.pi) * 255 116 | # degree = (theta - theta.min()) / (theta.max() - theta.min()) * 255 117 | # mag = np.sqrt(np.sum(direction ** 2, axis=0, keepdims=False)) 118 | 119 | 120 | # 归一化 121 | # for l in np.unique(gt)[1:]: 122 | # mag_i = mag[gt==l].astype(float) 123 | # # mag[gt==l] = 1. - mag[gt==l] / np.max(mag[gt==l]) 124 | # t = (mag_i - mag_i.min()) / (mag_i.max() - mag_i.min()) 125 | # mag[gt==l] = np.exp(-10*t) 126 | # print(mag_i.max(), mag_i.min()) 127 | 128 | # for l in np.unique(gt)[1:]: 129 | # mag_i = mag[gt==l].astype(float) 130 | # t = 1 / (mag_i) * mag_i.max() 131 | # # t = np.exp(-0.8*mag_i) * mag_i.max() 132 | # # t = 1 / np.sqrt(mag_i+1) * mag_i.max() 133 | # mag[gt==l] = t 134 | # # print(mag_i.max(), mag_i.min()) 135 | 136 | # mag[mag>0] = 2 * np.exp(-0.8*(mag[mag>0]-1)) 137 | # mag[mag>0] = 2 * np.exp(0.8*(mag[mag>0]-1)) 138 | 139 | 140 | mag_stat.append(mag.max()) 141 | # pdb.set_trace() 142 | 143 | # plt.imshow(degree) 144 | # plt.show() 145 | 146 | ###################### 147 | fig, axs = plt.subplots(1, 3) 148 | axs[0].imshow(degree) 149 | axs[1].imshow(gt) 150 | axs[2].imshow(mag) 151 | plt.show() 152 | 153 | ###################### 154 | if i % 100 == 0: 155 | print("\r\r{}/{} {:.4}s".format(i+1, len(data_infos), time.time()-st)) 156 | print() 157 | 158 | print("total time: ", time.time()-st) 159 | print("Average time: ", (time.time()-st) / len(data_infos)) 160 | # total time: 865.811030626297 161 | # 
Average time: 0.012969593759126428 162 | 163 | plt.hist(mag_stat) 164 | plt.show() 165 | print(mag_stat) -------------------------------------------------------------------------------- /data/utils/direct_field/utils_df.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.ndimage import distance_transform_edt as distance 4 | from utils.utils_loss import one_hot, simplex, class2one_hot 5 | 6 | def one_hot2dist(seg: np.ndarray) -> np.ndarray: 7 | assert one_hot(torch.Tensor(seg), axis=0) 8 | C: int = len(seg) 9 | 10 | res = np.zeros_like(seg) 11 | for c in range(C): 12 | posmask = seg[c].astype(np.bool) 13 | 14 | if posmask.any(): 15 | negmask = ~posmask 16 | res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask 17 | return res 18 | 19 | def class2dist(seg: np.ndarray, C=4) -> np.ndarray: 20 | """ res: (C, H, W) 21 | """ 22 | if seg.ndim == 2: 23 | seg_tensor = torch.Tensor(seg) 24 | elif seg.ndim == 3: 25 | seg_tensor = torch.Tensor(seg[0]) 26 | elif seg.ndim == 4: 27 | seg_tensor = torch.Tensor(seg[0, ..., 0]) 28 | 29 | seg_onehot = class2one_hot(seg_tensor, C).to(torch.float32) 30 | 31 | assert simplex(seg_onehot) 32 | res = one_hot2dist(seg_onehot[0].numpy()) 33 | return res 34 | 35 | -------------------------------------------------------------------------------- /data/utils/image_list.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from __future__ import division 3 | 4 | import torch 5 | import numpy as np 6 | 7 | class ImageList(object): 8 | """ 9 | Structure that holds a list of images (of possibly 10 | varying sizes) as a single tensor. 
11 | This works by padding the images to the same size, 12 | and storing in a field the original sizes of each image 13 | """ 14 | 15 | def __init__(self, tensors, image_sizes): 16 | """ 17 | Arguments: 18 | tensors (tensor) 19 | image_sizes (list[tuple[int, int]]) 20 | """ 21 | self.tensors = tensors 22 | self.image_sizes = image_sizes 23 | 24 | def to(self, *args, **kwargs): 25 | cast_tensor = self.tensors.to(*args, **kwargs) 26 | return ImageList(cast_tensor, self.image_sizes) 27 | 28 | 29 | def to_image_list(tensors, size_divisible=0, return_size=False): 30 | """ 31 | tensors can be an ImageList, a torch.Tensor or 32 | an iterable of Tensors. It can't be a numpy array. 33 | When tensors is an iterable of Tensors, it pads 34 | the Tensors with zeros so that they have the same 35 | shape 36 | """ 37 | if isinstance(tensors, torch.Tensor) and size_divisible > 0: 38 | tensors = [tensors] 39 | 40 | if isinstance(tensors, ImageList): 41 | return tensors 42 | elif isinstance(tensors, torch.Tensor): 43 | # single tensor shape can be inferred 44 | if tensors.dim() == 3: 45 | tensors = tensors[None] 46 | assert tensors.dim() == 4 47 | image_sizes = [tensor.shape[-2:] for tensor in tensors] 48 | return ImageList(tensors, image_sizes) 49 | elif isinstance(tensors, (tuple, list, np.ndarray)): 50 | max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) 51 | 52 | # TODO Ideally, just remove this and let me model handle arbitrary 53 | # input sizs 54 | if size_divisible > 0: 55 | import math 56 | 57 | stride = size_divisible 58 | max_size = list(max_size) 59 | max_size[1] = int(math.ceil(max_size[1] / stride) * stride) 60 | max_size[2] = int(math.ceil(max_size[2] / stride) * stride) 61 | max_size = tuple(max_size) 62 | 63 | batch_shape = (len(tensors),) + max_size 64 | batched_imgs = tensors[0].new(*batch_shape).zero_() 65 | for img, pad_img in zip(tensors, batched_imgs): 66 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 67 | 68 | 
image_sizes = [im.shape[-2:] for im in tensors] 69 | 70 | # return ImageList(batched_imgs, image_sizes) 71 | if return_size: 72 | return batched_imgs, image_sizes 73 | else: 74 | return batched_imgs 75 | else: 76 | raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) 77 | 78 | -------------------------------------------------------------------------------- /data/utils/init_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def init_weights(net, init_type='normal', gain=0.02): 4 | def init_func(m): 5 | classname = m.__class__.__name__ 6 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 7 | if init_type == 'normal': 8 | nn.init.normal_(m.weight.data, 0.0, gain) 9 | elif init_type == 'xavier': 10 | nn.init.xavier_normal_(m.weight.data, gain=gain) 11 | elif init_type == 'kaiming': 12 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 13 | elif init_type == 'orthogonal': 14 | nn.init.orthogonal_(m.weight.data, gain=gain) 15 | else: 16 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 17 | if hasattr(m, 'bias') and m.bias is not None: 18 | nn.init.constant_(m.bias.data, 0.0) 19 | elif classname.find('BatchNorm2d') != -1: 20 | nn.init.normal_(m.weight.data, 1.0, gain) 21 | nn.init.constant_(m.bias.data, 0.0) 22 | 23 | print('initialize network with %s' % init_type) 24 | net.apply(init_func) -------------------------------------------------------------------------------- /data/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from hausdorff import hausdorff_distance 4 | from medpy.metric.binary import hd, dc 5 | 6 | def dice(pred, target): 7 | pred = pred.contiguous() 8 | target = target.contiguous() 9 | smooth = 0.00001 10 | 11 | # intersection = (pred * target).sum(dim=2).sum(dim=2) 12 | 
pred_flat = pred.view(1, -1) 13 | target_flat = target.view(1, -1) 14 | 15 | intersection = (pred_flat * target_flat).sum().item() 16 | 17 | # loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth))) 18 | dice = (2 * intersection + smooth) / (pred_flat.sum().item() + target_flat.sum().item() + smooth) 19 | return dice 20 | 21 | def dice3D(img_gt, img_pred, voxel_size): 22 | """ 23 | Function to compute the metrics between two segmentation maps given as input. 24 | 25 | Parameters 26 | ---------- 27 | img_gt: np.array 28 | Array of the ground truth segmentation map. 29 | 30 | img_pred: np.array 31 | Array of the predicted segmentation map. 32 | 33 | voxel_size: list, tuple or np.array 34 | The size of a voxel of the images used to compute the volumes. 35 | 36 | Return 37 | ------ 38 | A list of metrics in this order, [Dice LV, Volume LV, Err LV(ml), 39 | Dice RV, Volume RV, Err RV(ml), Dice MYO, Volume MYO, Err MYO(ml)] 40 | """ 41 | 42 | if img_gt.ndim != img_pred.ndim: 43 | raise ValueError("The arrays 'img_gt' and 'img_pred' should have the " 44 | "same dimension, {} against {}".format(img_gt.ndim, 45 | img_pred.ndim)) 46 | 47 | res = [] 48 | # Loop on each classes of the input images 49 | for c in [3, 1, 2]: 50 | # Copy the gt image to not alterate the input 51 | gt_c_i = np.copy(img_gt) 52 | gt_c_i[gt_c_i != c] = 0 53 | 54 | # Copy the pred image to not alterate the input 55 | pred_c_i = np.copy(img_pred) 56 | pred_c_i[pred_c_i != c] = 0 57 | 58 | # Clip the value to compute the volumes 59 | gt_c_i = np.clip(gt_c_i, 0, 1) 60 | pred_c_i = np.clip(pred_c_i, 0, 1) 61 | 62 | # Compute the Dice 63 | dice = dc(gt_c_i, pred_c_i) 64 | 65 | # Compute volume 66 | # volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000. 67 | # volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000. 
68 | 69 | # res += [dice, volpred, volpred-volgt] 70 | res += [dice] 71 | 72 | return res 73 | 74 | def hd_3D(img_pred, img_gt, labels=[3, 1, 2]): 75 | res = [] 76 | for c in labels: 77 | gt_c_i = np.copy(img_gt) 78 | gt_c_i[gt_c_i != c] = 0 79 | 80 | pred_c_i = np.copy(img_pred) 81 | pred_c_i[pred_c_i != c] = 0 82 | 83 | gt_c_i = np.clip(gt_c_i, 0, 1) 84 | pred_c_i = np.clip(pred_c_i, 0, 1) 85 | 86 | if np.sum(pred_c_i) == 0 or np.sum(gt_c_i) == 0: 87 | hausdorff = 0 88 | else: 89 | hausdorff = hd(pred_c_i, gt_c_i) 90 | 91 | res += [hausdorff] 92 | 93 | return res 94 | 95 | def cal_hausdorff_distance(pred,target): 96 | 97 | pred = np.array(pred.contiguous()) 98 | target = np.array(target.contiguous()) 99 | result = hausdorff_distance(pred,target,distance="euclidean") 100 | 101 | return result 102 | 103 | def make_one_hot(input, num_classes): 104 | """Convert class index tensor to one hot encoding tensor. 105 | Args: 106 | input: A tensor of shape [N, 1, *] 107 | num_classes: An int of number of class 108 | Returns: 109 | A tensor of shape [N, num_classes, *] 110 | """ 111 | shape = np.array(input.shape) 112 | shape[1] = num_classes 113 | shape = tuple(shape) 114 | result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1) 115 | # result = result.scatter_(1, input.cpu(), 1) 116 | 117 | return result 118 | 119 | def match_pred_gt(pred, gt): 120 | """ pred: (1, C, H, W) 121 | gt: (1, C, H, W) 122 | """ 123 | gt_labels = torch.unique(gt, sorted=True)[1:] 124 | pred_labels = torch.unique(pred, sorted=True)[1:] 125 | 126 | if len(gt_labels) != 0 and len(pred_labels) != 0: 127 | dice_Matrix = torch.zeros((len(pred_labels), len(gt_labels))) 128 | for i, pl in enumerate(pred_labels): 129 | pred_i = torch.tensor(pred==pl, dtype=torch.float) 130 | for j, gl in enumerate(gt_labels): 131 | dice_Matrix[i, j] = dice(make_one_hot(pred_i, 2)[0], make_one_hot(gt==gl, 2)[0]) 132 | 133 | # max_axis0 = np.max(dice_Matrix, axis=0) 134 | max_arg0 = np.argmax(dice_Matrix, axis=0) 135 
| else: 136 | return torch.zeros_like(pred) 137 | 138 | pred_match = torch.zeros_like(pred) 139 | for i, arg in enumerate(max_arg0): 140 | pred_match[pred==pred_labels[arg]] = i + 1 141 | return pred_match 142 | 143 | if __name__ == "__main__": 144 | npy_path = "/home/fcheng/Cardia/source_code/logs/logs_df_50000/eval_pp_test/200.npy" 145 | pred_df, gt_df = np.load(npy_p) 146 | -------------------------------------------------------------------------------- /data/utils/utils_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import List, Set, Iterable 4 | 5 | def uniq(a: Tensor) -> Set: 6 | return set(torch.unique(a.cpu()).numpy()) 7 | 8 | def sset(a: Tensor, sub: Iterable) -> bool: 9 | return uniq(a).issubset(sub) 10 | 11 | def simplex(t: Tensor, axis=1) -> bool: 12 | _sum = t.sum(axis).type(torch.float32) 13 | _ones = torch.ones_like(_sum, dtype=torch.float32) 14 | return torch.allclose(_sum, _ones) 15 | 16 | def one_hot(t: Tensor, axis=1) -> bool: 17 | return simplex(t, axis) and sset(t, [0, 1]) 18 | 19 | # switch between representations 20 | def probs2class(probs: Tensor) -> Tensor: 21 | b, _, w, h = probs.shape # type: Tuple[int, int, int, int] 22 | assert simplex(probs) 23 | 24 | res = probs.argmax(dim=1) 25 | assert res.shape == (b, w, h) 26 | 27 | return res 28 | 29 | def class2one_hot(seg: Tensor, C: int) -> Tensor: 30 | if len(seg.shape) == 2: # Only w, h, used by the dataloader 31 | seg = seg.unsqueeze(dim=0) 32 | assert sset(seg, list(range(C))) 33 | 34 | b, w, h = seg.shape # type: Tuple[int, int, int] 35 | 36 | res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32) 37 | assert res.shape == (b, C, w, h) 38 | assert one_hot(res) 39 | 40 | return res 41 | 42 | def probs2one_hot(probs: Tensor) -> Tensor: 43 | _, C, _, _ = probs.shape 44 | assert simplex(probs) 45 | 46 | res = class2one_hot(probs2class(probs), C) 47 | assert res.shape == 
probs.shape 48 | assert one_hot(res) 49 | 50 | return res -------------------------------------------------------------------------------- /data/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import colorsys 3 | import random 4 | import cv2 5 | 6 | def get_n_hls_colors(num): 7 | hls_colors = [] 8 | i = 0 9 | step = 360.0 / num 10 | while i < 360: 11 | h = i 12 | s = 90 + random.random() * 10 13 | l = 50 + random.random() * 10 14 | _hlsc = [h / 360.0, l / 100.0, s / 100.0] 15 | hls_colors.append(_hlsc) 16 | i += step 17 | 18 | return hls_colors 19 | 20 | def ncolors(num): 21 | rgb_colors = [] 22 | if num < 1: 23 | return np.array(rgb_colors) 24 | hls_colors = get_n_hls_colors(num) 25 | for hlsc in hls_colors: 26 | _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2]) 27 | rgb_colors.append([_r, _g, _b]) 28 | # r, g, b = [int(x * 255.0) for x in (_r, _g, _b)] 29 | # rgb_colors.append([r, g, b]) 30 | 31 | return np.array(rgb_colors) 32 | 33 | def random_colors(N, bright=True): 34 | """ 35 | Generate random colors. 36 | To get visually distinct colors, generate them in HSV space then 37 | convert to RGB. 
38 | """ 39 | brightness = 1.0 if bright else 0.7 40 | hsv = [(i / N, 1, brightness) for i in range(N)] 41 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 42 | # random.shuffle(colors) 43 | return colors 44 | 45 | def mask2png(mask, file_name=None, suffix="png"): 46 | """ mask: (w, h) 47 | img_rgb: (w, h, rgb) 48 | """ 49 | nums = np.unique(mask)[1:] 50 | if len(nums) < 1: 51 | colors = np.array([[0,0,0]]) 52 | else: 53 | # colors = ncolors(len(nums)) 54 | colors = (np.array(random_colors(len(nums))) * 255).astype(int) 55 | colors = np.insert(colors, 0, [0,0,0], 0) 56 | 57 | # 保证mask中的值为1-N连续 58 | mask_ordered = np.zeros_like(mask) 59 | for cnt, l in enumerate(nums, 1): 60 | mask_ordered[mask==l] = cnt 61 | 62 | im_rgb = colors[mask_ordered.astype(int)] 63 | if file_name is not None: 64 | cv2.imwrite(file_name+"."+suffix, im_rgb[:, :, ::-1]) 65 | return im_rgb 66 | 67 | def apply_mask(image, mask, color, alpha=0.5, scale=1): 68 | """Apply the given mask to the image. 69 | """ 70 | for c in range(3): 71 | image[:, :, c] = np.where(mask == 1, 72 | image[:, :, c] * 73 | (1 - alpha) + alpha * color[c] * scale, 74 | image[:, :, c]) 75 | return image 76 | 77 | def img_mask_png(image, mask, file_name=None, alpha=0.5, suffix="png"): 78 | """ mask: (h, w) 79 | image: (h, w, rgb) 80 | """ 81 | nums = np.unique(mask)[1:] 82 | if len(nums) < 1: 83 | colors = np.array([[0,0,0]]) 84 | else: 85 | colors = ncolors(len(nums)) 86 | colors = np.insert(colors, 0, [0,0,0], 0) 87 | 88 | # 保证mask中的值为1-N连续 89 | mask_ordered = np.zeros_like(mask) 90 | for cnt, l in enumerate(nums, 1): 91 | mask_ordered[mask==l] = cnt 92 | 93 | # mask_rgb = colors[mask_ordered.astype(int)] 94 | mix_im = image.copy() 95 | for i in np.unique(mask_ordered)[1:]: 96 | mix_im = apply_mask(mix_im, mask_ordered==i, colors[int(i)], alpha=alpha, scale=255) 97 | 98 | if file_name is not None: 99 | cv2.imwrite(file_name+"."+suffix, mix_im[:, :, ::-1]) 100 | return mix_im 101 | 102 | def 
_find_contour(mask): 103 | # _, contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) # 顶点 104 | _, contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 105 | 106 | cont = np.zeros_like(mask) 107 | for contour in contours: 108 | cont[contour[:,:,1], contour[:,:,0]] = 1 109 | return cont 110 | 111 | def masks_to_contours(masks): 112 | # 包含多个区域 113 | nums = np.unique(masks)[1:] 114 | cont_mask = np.zeros_like(masks) 115 | for i in nums: 116 | cont_mask += _find_contour(masks==i) 117 | return (cont_mask>0).astype(int) 118 | 119 | def batchToColorImg(batch, minv=None, maxv=None, scale=255.): 120 | """ batch: (N, H, W, C) 121 | """ 122 | if batch.ndim == 3: 123 | N, H, W = batch.shape 124 | elif batch.ndim == 4: 125 | N, H, W, _ = batch.shape 126 | colorImg = np.zeros(shape=(N, H, W, 3)) 127 | for i in range(N): 128 | if minv is None: 129 | a = (batch[i] - batch[i].min()) / (batch[i].max() - batch[i].min()) * 255 130 | else: 131 | a = (batch[i] - minv) / (maxv - minv) * scale 132 | a = cv2.applyColorMap(a.astype(np.uint8), cv2.COLORMAP_JET) 133 | colorImg[i, ...] = a[..., ::-1] / 255. 
134 | return colorImg 135 | 136 | if __name__ == "__main__": 137 | a = np.zeros((100, 100)) 138 | a[0:5, 3:8] = 1 139 | a[75:85, 85:95] = 2 140 | 141 | # colors = ncolors(2)[::-1] 142 | colors = np.array(random_colors(2)) * 255 143 | colors = np.insert(colors, 0, [0, 0, 0], 0) 144 | 145 | b = colors[a.astype(int)].astype(np.uint8) 146 | import cv2, skimage 147 | # skimage.io.imsave("test_io.png", b) 148 | # # cv2.imwrite("test.jpg", b[:, :, ::-1]) 149 | # print() 150 | 151 | # mask2png(a, "test") 152 | 153 | ################################################ 154 | # img_mask_png(b, a, "test") 155 | 156 | ############################################# 157 | # cont_mask = find_contours(a) 158 | # print() 159 | # # skimage.io.imsave("test_cont.png", cont_mask) 160 | # b[cont_mask>0, :] = [255, 255, 255] 161 | # skimage.io.imsave("test_cont.png", b) 162 | 163 | gt0 = skimage.io.imread("gt0.png", as_gray=False) 164 | print() 165 | gt0[gt0==54] = 0 166 | # cont_mask = find_contours(gt0==237) # Array([ 18, 54, 73, 237], dtype=uint8) 167 | # cont_mask += find_contours(gt0==18) 168 | # cont_mask += find_contours(gt0==73) 169 | cont_mask = masks_to_contours(gt0) 170 | 171 | colors = np.array(random_colors(1)) * 255 172 | colors = np.insert(colors, 0, [0, 0, 0], 0) 173 | cont_mask = colors[cont_mask.astype(int)].astype(np.uint8) 174 | 175 | skimage.io.imsave("test_cont.png", cont_mask) 176 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # SR : Segmentation Result 4 | # GT : Ground Truth 5 | 6 | def get_accuracy(SR,GT,threshold=0.5): 7 | SR = SR > threshold 8 | GT = GT == torch.max(GT) 9 | corr = torch.sum(SR==GT) 10 | tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3) 11 | acc = float(corr)/float(tensor_size) 12 | 13 | return acc 14 | 15 | def get_sensitivity(SR,GT,threshold=0.5): 16 | # Sensitivity == Recall 17 | SR 
= SR > threshold 18 | GT = GT == torch.max(GT) 19 | 20 | # TP : True Positive 21 | # FN : False Negative 22 | TP = ((SR==1)+(GT==1))==2 23 | FN = ((SR==0)+(GT==1))==2 24 | 25 | SE = float(torch.sum(TP))/(float(torch.sum(TP+FN)) + 1e-6) 26 | 27 | return SE 28 | 29 | def get_specificity(SR,GT,threshold=0.5): 30 | SR = SR > threshold 31 | GT = GT == torch.max(GT) 32 | 33 | # TN : True Negative 34 | # FP : False Positive 35 | TN = ((SR==0)+(GT==0))==2 36 | FP = ((SR==1)+(GT==0))==2 37 | 38 | SP = float(torch.sum(TN))/(float(torch.sum(TN+FP)) + 1e-6) 39 | 40 | return SP 41 | 42 | def get_precision(SR,GT,threshold=0.5): 43 | SR = SR > threshold 44 | GT = GT == torch.max(GT) 45 | 46 | # TP : True Positive 47 | # FP : False Positive 48 | TP = ((SR==1)+(GT==1))==2 49 | FP = ((SR==1)+(GT==0))==2 50 | 51 | PC = float(torch.sum(TP))/(float(torch.sum(TP+FP)) + 1e-6) 52 | 53 | return PC 54 | 55 | def get_F1(SR,GT,threshold=0.5): 56 | # Sensitivity == Recall 57 | SE = get_sensitivity(SR,GT,threshold=threshold) 58 | PC = get_precision(SR,GT,threshold=threshold) 59 | 60 | F1 = 2*SE*PC/(SE+PC + 1e-6) 61 | 62 | return F1 63 | 64 | def get_JS(SR,GT,threshold=0.5): 65 | # JS : Jaccard similarity 66 | SR = SR > threshold 67 | GT = GT == torch.max(GT) 68 | 69 | Inter = torch.sum((SR+GT)==2) 70 | Union = torch.sum((SR+GT)>=1) 71 | 72 | JS = float(Inter)/(float(Union) + 1e-6) 73 | 74 | return JS 75 | 76 | def get_DC(SR,GT,threshold=0.5): 77 | 78 | # DC : Dice Coefficient 79 | SR = SR > threshold 80 | GT = GT == torch.max(GT) 81 | 82 | Inter = torch.sum((SR+GT)==2) 83 | DC = float(2*Inter)/(float(torch.sum(SR)+torch.sum(GT)) + 1e-6) 84 | 85 | return DC 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /fig/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/fig/1.png 
-------------------------------------------------------------------------------- /fig/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/fig/2.png -------------------------------------------------------------------------------- /fig/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/fig/3.png -------------------------------------------------------------------------------- /models/._models.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/._models.py -------------------------------------------------------------------------------- /models/GSConv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.nn.modules.conv import _ConvNd 10 | from torch.nn.modules.utils import _pair 11 | import numpy as np 12 | import math 13 | from . import norm as mynn 14 | from . 
import custom_functions as myF 15 | 16 | class GatedSpatialConv2d(_ConvNd): 17 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 18 | padding=0, dilation=1, groups=1, bias=False): 19 | """ 20 | 21 | :param in_channels: 22 | :param out_channels: 23 | :param kernel_size: 24 | :param stride: 25 | :param padding: 26 | :param dilation: 27 | :param groups: 28 | :param bias: 29 | """ 30 | 31 | kernel_size = _pair(kernel_size) 32 | stride = _pair(stride) 33 | padding = _pair(padding) 34 | dilation = _pair(dilation) 35 | super(GatedSpatialConv2d, self).__init__( 36 | in_channels, out_channels, kernel_size, stride, padding, dilation, 37 | False, _pair(0), groups, bias, 'zeros') 38 | self._gate_conv = nn.Sequential( 39 | mynn.Norm2d(in_channels+1), 40 | nn.Conv2d(in_channels+1, in_channels+1, 1), 41 | nn.ReLU(), 42 | nn.Conv2d(in_channels+1, 1, 1), 43 | mynn.Norm2d(1), 44 | nn.Sigmoid() 45 | ) 46 | 47 | def forward(self, input_features, gating_features): 48 | """ 49 | 50 | :param input_features: [NxCxHxW] featuers comming from the shape branch (canny branch). 51 | :param gating_features: [Nx1xHxW] features comming from the texture branch (resnet). Only one channel feature map. 
52 | :return: 53 | """ 54 | alphas = self._gate_conv(torch.cat([input_features, gating_features], dim=1)) 55 | input_features = (input_features * (alphas + 1)) 56 | return F.conv2d(input_features, self.weight, self.bias, self.stride, 57 | self.padding, self.dilation, self.groups) 58 | 59 | def reset_parameters(self): 60 | nn.init.xavier_normal_(self.weight) 61 | if self.bias is not None: 62 | nn.init.zeros_(self.bias) 63 | 64 | 65 | class Conv2dPad(nn.Conv2d): 66 | def forward(self, input): 67 | return myF.conv2d_same(input,self.weight,self.groups) 68 | 69 | class HighFrequencyGatedSpatialConv2d(_ConvNd): 70 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 71 | padding=0, dilation=1, groups=1, bias=False): 72 | """ 73 | 74 | :param in_channels: 75 | :param out_channels: 76 | :param kernel_size: 77 | :param stride: 78 | :param padding: 79 | :param dilation: 80 | :param groups: 81 | :param bias: 82 | """ 83 | 84 | kernel_size = _pair(kernel_size) 85 | stride = _pair(stride) 86 | padding = _pair(padding) 87 | dilation = _pair(dilation) 88 | super(HighFrequencyGatedSpatialConv2d, self).__init__( 89 | in_channels, out_channels, kernel_size, stride, padding, dilation, 90 | False, _pair(0), groups, bias) 91 | 92 | self._gate_conv = nn.Sequential( 93 | mynn.Norm2d(in_channels+1), 94 | nn.Conv2d(in_channels+1, in_channels+1, 1), 95 | nn.ReLU(), 96 | nn.Conv2d(in_channels+1, 1, 1), 97 | mynn.Norm2d(1), 98 | nn.Sigmoid() 99 | ) 100 | 101 | kernel_size = 7 102 | sigma = 3 103 | 104 | x_cord = torch.arange(kernel_size).float() 105 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size).float() 106 | y_grid = x_grid.t().float() 107 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 108 | 109 | mean = (kernel_size - 1)/2. 110 | variance = sigma**2. 
111 | gaussian_kernel = (1./(2.*math.pi*variance)) *\ 112 | torch.exp( 113 | -torch.sum((xy_grid - mean)**2., dim=-1) /\ 114 | (2*variance) 115 | ) 116 | 117 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 118 | 119 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 120 | gaussian_kernel = gaussian_kernel.repeat(in_channels, 1, 1, 1) 121 | 122 | self.gaussian_filter = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, padding=3, 123 | kernel_size=kernel_size, groups=in_channels, bias=False) 124 | 125 | self.gaussian_filter.weight.data = gaussian_kernel 126 | self.gaussian_filter.weight.requires_grad = False 127 | 128 | self.cw = nn.Conv2d(in_channels * 2, in_channels, 1) 129 | 130 | self.procdog = nn.Sequential( 131 | nn.Conv2d(in_channels, in_channels, 1), 132 | mynn.Norm2d(in_channels), 133 | nn.Sigmoid() 134 | ) 135 | 136 | def forward(self, input_features, gating_features): 137 | """ 138 | 139 | :param input_features: [NxCxHxW] featuers comming from the shape branch (canny branch). 140 | :param gating_features: [Nx1xHxW] features comming from the texture branch (resnet). Only one channel feature map. 
141 | :return: 142 | """ 143 | n, c, h, w = input_features.size() 144 | smooth_features = self.gaussian_filter(input_features) 145 | dog_features = input_features - smooth_features 146 | dog_features = self.cw(torch.cat((dog_features, input_features), dim=1)) 147 | 148 | alphas = self._gate_conv(torch.cat([input_features, gating_features], dim=1)) 149 | 150 | dog_features = dog_features * (alphas + 1) 151 | 152 | return F.conv2d(dog_features, self.weight, self.bias, self.stride, 153 | self.padding, self.dilation, self.groups) 154 | 155 | def reset_parameters(self): 156 | nn.init.xavier_normal_(self.weight) 157 | if self.bias is not None: 158 | nn.init.zeros_(self.bias) 159 | 160 | def t(): 161 | import matplotlib.pyplot as plt 162 | 163 | canny_map_filters_in = 8 164 | canny_map = np.random.normal(size=(1, canny_map_filters_in, 10, 10)) # NxCxHxW 165 | resnet_map = np.random.normal(size=(1, 1, 10, 10)) # NxCxHxW 166 | plt.imshow(canny_map[0, 0]) 167 | plt.show() 168 | 169 | canny_map = torch.from_numpy(canny_map).float() 170 | resnet_map = torch.from_numpy(resnet_map).float() 171 | 172 | gconv = GatedSpatialConv2d(canny_map_filters_in, canny_map_filters_in, 173 | kernel_size=3, stride=1, padding=1) 174 | output_map = gconv(canny_map, resnet_map) 175 | print('done') 176 | 177 | 178 | if __name__ == "__main__": 179 | t() 180 | 181 | -------------------------------------------------------------------------------- /models/PraNet_Res2Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Res2Net_v1b import res2net50_v1b_26w_4s 5 | 6 | 7 | class BasicConv2d(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 9 | super(BasicConv2d, self).__init__() 10 | self.conv = nn.Conv2d(in_planes, out_planes, 11 | kernel_size=kernel_size, stride=stride, 12 | padding=padding, dilation=dilation, bias=False) 13 | 
self.bn = nn.BatchNorm2d(out_planes) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | x = self.bn(x) 19 | return x 20 | 21 | 22 | class RFB_modified(nn.Module): 23 | def __init__(self, in_channel, out_channel): 24 | super(RFB_modified, self).__init__() 25 | self.relu = nn.ReLU(True) 26 | self.branch0 = nn.Sequential( 27 | BasicConv2d(in_channel, out_channel, 1), 28 | ) 29 | self.branch1 = nn.Sequential( 30 | BasicConv2d(in_channel, out_channel, 1), 31 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)), 32 | BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)), 33 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3) 34 | ) 35 | self.branch2 = nn.Sequential( 36 | BasicConv2d(in_channel, out_channel, 1), 37 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)), 38 | BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)), 39 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5) 40 | ) 41 | self.branch3 = nn.Sequential( 42 | BasicConv2d(in_channel, out_channel, 1), 43 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)), 44 | BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)), 45 | BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7) 46 | ) 47 | self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1) 48 | self.conv_res = BasicConv2d(in_channel, out_channel, 1) 49 | 50 | def forward(self, x): 51 | x0 = self.branch0(x) 52 | x1 = self.branch1(x) 53 | x2 = self.branch2(x) 54 | x3 = self.branch3(x) 55 | x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1)) 56 | 57 | x = self.relu(x_cat + self.conv_res(x)) 58 | return x 59 | 60 | 61 | class aggregation(nn.Module): 62 | # dense aggregation, it can be replaced by other aggregation previous, such as DSS, amulet, and so on. 
# used after MSF
    def __init__(self, channel):
        """Build the partial-decoder layers.

        Args:
            channel: channel width expected for each of the three inputs of
                `forward` (PraNet below passes its own `channel`, 32 by default).
        """
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)  # NOTE(review): not used by forward below

        # Each application doubles H and W.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1)

        self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        # Final 1x1 projection down to a single output channel.
        self.conv5 = nn.Conv2d(3*channel, 1, 1)

    def forward(self, x1, x2, x3):
        """Fuse three 4-D feature maps into one single-channel map.

        Args:
            x1: coarsest map.
            x2: map with twice x1's height/width.
            x3: map with twice x2's height/width.
            Each input is multiplied element-wise with `channel`-wide conv
            outputs, so each is expected to carry `channel` channels.

        Returns:
            A 1-channel tensor at x3's spatial resolution; no activation is
            applied.
        """
        x1_1 = x1
        # Gate x2 by upsampled x1, and x3 by upsampled x1 and x2 (element-wise products).
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
               * self.conv_upsample3(self.upsample(x2)) * x3

        # Concatenate gated x2 with upsampled x1 features -> 2*channel at x2 resolution.
        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)

        # Concatenate gated x3 with upsampled x2_2 -> 3*channel at x3 resolution.
        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)

        x = self.conv4(x3_2)
        x = self.conv5(x)

        return x


class PraNet(nn.Module):
    # res2net based encoder decoder
    """Res2Net-50 encoder, partial decoder and three reverse-attention branches.

    The first stem convolution of the backbone is replaced by a 1-input-channel
    conv, so the network takes single-channel images.

    `forward(x)` returns four single-channel maps
    `(lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2)`. Each is
    interpolated back by the factor that the inline shape notes show maps its
    feature resolution to the input resolution. No sigmoid is applied to them.

    NOTE(review): the fixed scale factors (0.25, 2, 8, 16, 32) only line up
    with the backbone feature sizes when H and W are multiples of 32; other
    sizes may make `.mul(x4)` / `.mul(x3)` / `.mul(x2)` fail. TODO confirm.
    """
    def __init__(self, channel=32):
        super(PraNet, self).__init__()
        # ---- ResNet Backbone ----
        self.resnet = res2net50_v1b_26w_4s(pretrained=False)
        # Swap the first stem conv so the model accepts 1-channel input.
        self.resnet.conv1[0]=nn.Conv2d(1,32,kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        # ---- Receptive Field Block like module ----
        self.rfb2_1 = RFB_modified(512, channel)
        self.rfb3_1 = RFB_modified(1024, channel)
        self.rfb4_1 = RFB_modified(2048, channel)
        # ---- Partial Decoder ----
        self.agg1 = aggregation(channel)
        # ---- reverse attention branch 4 ----
        self.ra4_conv1 = BasicConv2d(2048, 256, kernel_size=1)
        self.ra4_conv2 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv3 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv4 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv5 = BasicConv2d(256, 1, kernel_size=1)
        # ---- reverse attention branch 3 ----
        self.ra3_conv1 = BasicConv2d(1024, 64, kernel_size=1)
        self.ra3_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra3_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra3_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
        # ---- reverse attention branch 2 ----
        self.ra2_conv1 = BasicConv2d(512, 64, kernel_size=1)
        self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x):
        # Shape comments below are for a (bs, 1, 352, 352) input.
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)      # bs, 64, 88, 88
        # ---- low-level features ----
        x1 = self.resnet.layer1(x)      # bs, 256, 88, 88
        x2 = self.resnet.layer2(x1)     # bs, 512, 44, 44

        x3 = self.resnet.layer3(x2)     # bs, 1024, 22, 22
        x4 = self.resnet.layer4(x3)     # bs, 2048, 11, 11
        x2_rfb = self.rfb2_1(x2)        # channel -> 32
        x3_rfb = self.rfb3_1(x3)        # channel -> 32
        x4_rfb = self.rfb4_1(x4)        # channel -> 32

        # Coarse global map at x2's resolution (the decoder output follows its third input).
        ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb)
        lateral_map_5 = F.interpolate(ra5_feat, scale_factor=8, mode='bilinear')    # NOTES: Sup-1 (bs, 1, 44, 44) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_4 ----
        # Downsample the running prediction to x4's resolution (x2 -> x4 is 1/4).
        crop_4 = F.interpolate(ra5_feat, scale_factor=0.25, mode='bilinear')
        # Reverse attention: 1 - sigmoid(pred) weights regions not yet predicted.
        x = -1*(torch.sigmoid(crop_4)) + 1
        x = x.expand(-1, 2048, -1, -1).mul(x4)
        x = self.ra4_conv1(x)
        x = F.relu(self.ra4_conv2(x))
        x = F.relu(self.ra4_conv3(x))
        x = F.relu(self.ra4_conv4(x))
        ra4_feat = self.ra4_conv5(x)
        # Residual refinement of the previous prediction.
        x = ra4_feat + crop_4
        lateral_map_4 = F.interpolate(x, scale_factor=32, mode='bilinear')  # NOTES: Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_3 ----
        crop_3 = F.interpolate(x, scale_factor=2, mode='bilinear')
        x = -1*(torch.sigmoid(crop_3)) + 1
        x = x.expand(-1, 1024, -1, -1).mul(x3)
        x = self.ra3_conv1(x)
        x = F.relu(self.ra3_conv2(x))
        x = F.relu(self.ra3_conv3(x))
        ra3_feat = self.ra3_conv4(x)
        x = ra3_feat + crop_3
        lateral_map_3 = F.interpolate(x, scale_factor=16, mode='bilinear')  # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_2 ----
        crop_2 = F.interpolate(x, scale_factor=2, mode='bilinear')
        x = -1*(torch.sigmoid(crop_2)) + 1
        x = x.expand(-1, 512, -1, -1).mul(x2)
        x = self.ra2_conv1(x)
        x = F.relu(self.ra2_conv2(x))
        x = F.relu(self.ra2_conv3(x))
        ra2_feat = self.ra2_conv4(x)
        x = ra2_feat + crop_2
        lateral_map_2 = F.interpolate(x, scale_factor=8, mode='bilinear')   # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352)

        return lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2


if __name__ == '__main__':
    # Smoke test; requires a CUDA device.
    ras = PraNet().cuda()
    input_tensor = torch.randn(1, 1, 96, 96).cuda()

    out = ras(input_tensor)
    print(out[0].shape)
# --------------------------------------------------------------------------------
# /models/R2U_Net_model.py:
# --------------------------------------------------------------------------------
# """ Full assembly of the parts to form the complete network """
# """Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
def init_weights(net, init_type='normal', gain=0.02):
    """Initialise Conv*/Linear and BatchNorm2d parameters of every submodule of `net`.

    Modules are matched by class-name substring: any class whose name contains
    'Conv' or 'Linear' (and has a `weight`) gets `init_type`. Any class whose
    name contains 'BatchNorm2d' gets weight ~ N(1.0, gain) and bias 0.

    Args:
        net: module traversed with `net.apply`.
        init_type: 'normal' (N(0, gain)), 'xavier' (xavier_normal_ with `gain`),
            'kaiming' (kaiming_normal_, fan_in; ignores `gain`) or
            'orthogonal' (with `gain`).
        gain: std / gain as described above.

    Raises:
        NotImplementedError: for an unknown `init_type`. It is raised while
            visiting the first matching Conv/Linear module, after the log line
            has already been printed.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            # Biases of Conv/Linear layers are zeroed.
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
# class UNet(nn.Module):
#     def __init__(self, n_channels, n_classes, bilinear=True):
#         super(UNet, self).__init__()
#         self.n_channels = n_channels
#         self.n_classes = n_classes
#         self.bilinear = bilinear

#         self.inc = DoubleConv(n_channels, 64)
#         self.down1 = Down(64, 128)
#         self.down2 = Down(128, 256)
#         self.down3 = Down(256, 512)
#         self.down4 = Down(512, 512)
#         self.up1 = Up(1024, 256, bilinear)
#         self.up2 = Up(512, 128, bilinear)
#         self.up3 = Up(256, 64, bilinear)
#         self.up4 = Up(128, 64, bilinear)
#         self.outc = OutConv(64, n_classes)

#     def forward(self, x):
#         x1 = self.inc(x)
#         x2 = self.down1(x1)
#         x3 = self.down2(x2)
#         x4 = self.down3(x3)
#         x5 = self.down4(x4)
#         x = self.up1(x5, x4)
#         x = self.up2(x, x3)
#         x = self.up3(x, x2)
#         x = self.up4(x, x1)
#         logits = self.outc(x)
#         return logits

class up_conv(nn.Module):
    """Upsample x2, project to `ch_out` channels, then pad and join with a skip map.

    Unlike a plain up-convolution, `forward(x1, x2)` also performs the skip
    concatenation and returns `cat([x2, up(x1)], dim=1)`. The output therefore
    has `x2.channels + ch_out` channels at x2's spatial size.
    """
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # inputs are NCHW: dims 2 and 3 are H and W
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # Pad the upsampled map so odd-sized skip maps still line up.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim = 1)
        return x

class Recurrent_block(nn.Module):
    """Recurrent convolution unit with one shared conv-BN-ReLU stage.

    `forward` applies the stage to x once, then t times to `x + previous
    output`, for t+1 applications in total.
    NOTE(review): with t=0 the loop never runs, `x1` is never bound and
    forward raises UnboundLocalError.
    """
    def __init__(self,ch_out,t=2):
        super(Recurrent_block,self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        for i in range(self.t):

            if i==0:
                x1 = self.conv(x)

            x1 = self.conv(x+x1)
        return x1

class RRCNN_block(nn.Module):
    """Recurrent residual block.

    A 1x1 projection to `ch_out` channels is followed by two stacked
    Recurrent_blocks, with the projection added back as a residual.
    """
    def __init__(self,ch_in,ch_out,t=2):
        super(RRCNN_block,self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out,t=t),
            Recurrent_block(ch_out,t=t)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)

    def forward(self,x):
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x+x1
class R2U_Net(nn.Module):
    """Recurrent residual U-Net (R2U-Net).

    The encoder has five RRCNN levels (64, 128, 256, 512, 1024 channels)
    separated by 2x2 max-pooling. The decoder has four up_conv + RRCNN
    levels, each concatenating the matching encoder map.

    Args:
        img_ch: input channels.
        output_ch: channels of the final 1x1 conv output. No activation is
            applied.
        t: recurrence steps passed to every Recurrent_block.
    """
    def __init__(self,img_ch,output_ch,t=2):
        super(R2U_Net,self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)  # NOTE(review): not used by forward

        self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)

        self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)

        self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)

        self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)

        self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)


        # Each Up_RRCNN takes 2*ch_out inputs because up_conv concatenates the skip map.
        self.Up5 = up_conv(ch_in=1024,ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)

        self.Up4 = up_conv(ch_in=512,ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)

        self.Up3 = up_conv(ch_in=256,ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)

        self.Up2 = up_conv(ch_in=128,ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)

        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)


    def forward(self,x):
        # encoding path
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        # decoding + concat path (up_conv returns the concatenation with the skip map)
        d5 = self.Up5(x5,x4)
        d5 = self.Up_RRCNN5(d5)
        d4 = self.Up4(d5,x3)

        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4,x2)

        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3,x1)

        d2 = self.Up_RRCNN2(d2)
        d1 = self.Conv_1x1(d2)

        return d1

if __name__ == '__main__':
    net = R2U_Net(img_ch=1, output_ch=1)
    print(net)


# --------------------------------------------------------------------------------
# /models/UNet_2Plus.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.layers import unetConv2, unetUp_origin
from models.init_weights import init_weights
import numpy as np
from torchvision import models
class UNet_2Plus(nn.Module):
    """UNet++ style nested U-Net with five encoder levels (64..1024 channels).

    Node X_ij is encoder level i after j nested upsampling steps. Each
    decoder node is built from the next-deeper node plus all earlier nodes
    on its own level.

    Args:
        in_channels: input channels.
        n_classes: output channels of each 1x1 head.
        feature_scale: stored on the instance, but the line that would scale
            `filters` by it is commented out, so it has no effect here.
        is_deconv: forwarded to every unetUp_origin.
        is_batchnorm: forwarded to every unetConv2.
        is_ds: if True, forward returns the mean of the four heads
            final_1..final_4 (deep supervision); otherwise only final_4.
    """

    def __init__(self, in_channels=1, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True, is_ds=True):
        super(UNet_2Plus, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.is_ds = is_ds
        self.feature_scale = feature_scale

        # filters = [32, 64, 128, 256, 512]
        filters = [64, 128, 256, 512, 1024]
        # filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv00 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool0 = nn.MaxPool2d(kernel_size=2)
        self.conv10 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv20 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv30 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv40 = unetConv2(filters[3], filters[4], self.is_batchnorm)


        # upsampling
        # NOTE(review): the optional 4th argument presumably is the number of
        # concatenated inputs (2 by default); verify against models/layers.py.
        self.up_concat01 = unetUp_origin(filters[1], filters[0], self.is_deconv)
        self.up_concat11 = unetUp_origin(filters[2], filters[1], self.is_deconv)
        self.up_concat21 = unetUp_origin(filters[3], filters[2], self.is_deconv)
        self.up_concat31 = unetUp_origin(filters[4], filters[3], self.is_deconv)

        self.up_concat02 = unetUp_origin(filters[1], filters[0], self.is_deconv, 3)
        self.up_concat12 = unetUp_origin(filters[2], filters[1], self.is_deconv, 3)
        self.up_concat22 = unetUp_origin(filters[3], filters[2], self.is_deconv, 3)

        self.up_concat03 = unetUp_origin(filters[1], filters[0], self.is_deconv, 4)
        self.up_concat13 = unetUp_origin(filters[2], filters[1], self.is_deconv, 4)

        self.up_concat04 = unetUp_origin(filters[1], filters[0], self.is_deconv, 5)

        # final conv (without any concat)
        self.final_1 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_2 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_3 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_4 = nn.Conv2d(filters[0], n_classes, 1)

        # initialise weights
        # NOTE(review): uses init_weights from models.init_weights (imported above),
        # whose exact behaviour for BatchNorm2d is not visible here.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        # column : 0
        X_00 = self.conv00(inputs)
        maxpool0 = self.maxpool0(X_00)
        X_10 = self.conv10(maxpool0)
        maxpool1 = self.maxpool1(X_10)
        X_20 = self.conv20(maxpool1)
        maxpool2 = self.maxpool2(X_20)
        X_30 = self.conv30(maxpool2)
        maxpool3 = self.maxpool3(X_30)
        X_40 = self.conv40(maxpool3)

        # column : 1
        X_01 = self.up_concat01(X_10, X_00)
        X_11 = self.up_concat11(X_20, X_10)
        X_21 = self.up_concat21(X_30, X_20)
        X_31 = self.up_concat31(X_40, X_30)
        # column : 2
        X_02 = self.up_concat02(X_11, X_00, X_01)
        X_12 = self.up_concat12(X_21, X_10, X_11)
        X_22 = self.up_concat22(X_31, X_20, X_21)
        # column : 3
        X_03 = self.up_concat03(X_12, X_00, X_01, X_02)
        X_13 = self.up_concat13(X_22, X_10, X_11, X_12)
        # column : 4
        X_04 = self.up_concat04(X_13, X_00, X_01, X_02, X_03)

        # final layer
        final_1 = self.final_1(X_01)
        final_2 = self.final_2(X_02)
        final_3 = self.final_3(X_03)
        final_4 = self.final_4(X_04)

        # Deep supervision: average the four heads (raw outputs, no activation).
        final = (final_1 + final_2 + final_3 + final_4) / 4

        if self.is_ds:
            return final
        else:
            return final_4

        # if self.is_ds:
        #     return F.sigmoid(final)
        # else:
        #     return F.sigmoid(final_4)

# model = UNet_2Plus()
# print('# generator parameters:', 1.0 * sum(param.numel() for param in model.parameters())/1000000)
# params = list(model.named_parameters())
# for i in range(len(params)):
#     (name, param) = params[i]
#     print(name)
#     print(param.shape)
# --------------------------------------------------------------------------------
/models/__pycache__/AttU_Net_model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/AttU_Net_model.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/BaseNet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/BaseNet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/BaseNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/BaseNet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/F3net.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/F3net.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/GSConv.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/GSConv.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/GSConv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/GSConv.cpython-36.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/InfNet_Res2Net.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/InfNet_Res2Net.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/InfNet_Res2Net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/InfNet_Res2Net.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/LDF.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/LDF.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/LDunet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/LDunet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/PraNet_Res2Net.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/PraNet_Res2Net.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/R2U_Net_model.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/R2U_Net_model.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/Res2Net_v1b.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/Res2Net_v1b.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/UNet_2Plus.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/UNet_2Plus.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/UNet_2Plus.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/UNet_2Plus.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/attention_blocks.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/attention_blocks.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/attention_blocks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/attention_blocks.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/cenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/cenet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/cenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/cenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/custom_functions.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/custom_functions.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/custom_functions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/custom_functions.cpython-36.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/deeplab_v3p.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/deeplab_v3p.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/denseunet_model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/denseunet_model.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/fcn.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/fcn.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/fcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/fcn.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/init_weights.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/init_weights.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/init_weights.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/init_weights.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/layers.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/layers.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/models.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/models.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/multi_scale.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/multi_scale.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/multi_scale.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/multi_scale.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/multi_scale_module.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/multi_scale_module.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/newnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/newnet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/norm.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/norm.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/norm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/norm.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/resnet.cpython-35.pyc -------------------------------------------------------------------------------- 
/models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/unet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/unet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vggunet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/vggunet.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/wassp.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/wassp.cpython-35.pyc -------------------------------------------------------------------------------- /models/__pycache__/wassp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/wassp.cpython-36.pyc -------------------------------------------------------------------------------- 
# /models/__pycache__/wnet.cpython-35.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/__pycache__/wnet.cpython-35.pyc
# --------------------------------------------------------------------------------
# /models/adaptive_avgmax_pool.py:
# --------------------------------------------------------------------------------
""" PyTorch selectable adaptive pooling
Adaptive pooling with the ability to select the type of pooling from:
    * 'avg' - Average pooling
    * 'max' - Max pooling
    * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
    * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim

Both a functional and a nn.Module version of the pooling are provided.

Author: Ross Wightman (rwightman)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def pooling_factor(pool_type='avg'):
    """Return the channel multiplier produced by `pool_type` (2 for 'avgmaxc', else 1)."""
    return 2 if pool_type == 'avgmaxc' else 1


def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
    """Selectable global pooling function with dynamic input kernel size.

    The kernel covers the full spatial extent of `x` (dims 2 and 3), so with
    the default `padding=0` every feature map is reduced to 1x1.

    Args:
        x: 4-D tensor (N, C, H, W).
        pool_type: 'avg', 'max', 'avgmax' or 'avgmaxc'. Any other value prints
            a warning and falls back to average pooling.
        padding: forwarded to the pooling ops.
        count_include_pad: forwarded to `F.avg_pool2d`.

    Returns:
        The pooled tensor; 'avgmaxc' doubles the channel dimension.
    """
    kernel_size = (x.size(2), x.size(3))
    if pool_type == 'avgmaxc':
        x = torch.cat([
            F.avg_pool2d(
                x, kernel_size=kernel_size, padding=padding, count_include_pad=count_include_pad),
            F.max_pool2d(x, kernel_size=kernel_size, padding=padding)
        ], dim=1)
    elif pool_type == 'avgmax':
        x_avg = F.avg_pool2d(
            x, kernel_size=kernel_size, padding=padding, count_include_pad=count_include_pad)
        x_max = F.max_pool2d(x, kernel_size=kernel_size, padding=padding)
        x = 0.5 * (x_avg + x_max)
    elif pool_type == 'max':
        x = F.max_pool2d(x, kernel_size=kernel_size, padding=padding)
    else:
        if pool_type != 'avg':
            print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
        x = F.avg_pool2d(
            x, kernel_size=kernel_size, padding=padding, count_include_pad=count_include_pad)
    return x


class AdaptiveAvgMaxPool2d(torch.nn.Module):
    """Selectable global pooling layer with dynamic input kernel size.

    Args:
        output_size: target spatial size passed to the adaptive pooling layers.
        pool_type: 'avg', 'max', 'avgmax' or 'avgmaxc'. Any other value prints
            a warning and falls back to average pooling.
    """
    def __init__(self, output_size=1, pool_type='avg'):
        super(AdaptiveAvgMaxPool2d, self).__init__()
        self.output_size = output_size
        self.pool_type = pool_type
        if pool_type == 'avgmaxc' or pool_type == 'avgmax':
            # Index 0 is the average pool, index 1 the max pool.
            self.pool = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size), nn.AdaptiveMaxPool2d(output_size)])
        elif pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool2d(output_size)
        else:
            if pool_type != 'avg':
                print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
            self.pool = nn.AdaptiveAvgPool2d(output_size)

    def forward(self, x):
        if self.pool_type == 'avgmaxc':
            x = torch.cat([p(x) for p in self.pool], dim=1)
        elif self.pool_type == 'avgmax':
            # Add the two pooled maps directly. The previous
            # sum(stack(...), 0).squeeze(dim=0) dropped the batch dimension
            # whenever the batch size was 1.
            avg_pool, max_pool = self.pool
            x = 0.5 * (avg_pool(x) + max_pool(x))
        else:
            x = self.pool(x)
        return x

    def factor(self):
        """Channel multiplier of this layer (see `pooling_factor`)."""
        return pooling_factor(self.pool_type)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + 'output_size=' + str(self.output_size) \
               + ', pool_type=' + self.pool_type + ')'
# --------------------------------------------------------------------------------
# /models/backbone/ResNet.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
import math


def conv3x3(in_planes, out_planes, stride=1):
    """
    3x3 convolution with padding (no bias)
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 conv-BN layers with an identity (or `downsample`) shortcut."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 (strided) -> 1x1 bottleneck; output has `planes * 4` channels."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    # ResNet50 with two branches
    """ResNet-50 trunk (Bottleneck counts 3, 4, 6, 3) that splits after layer2.

    Branch 1 is `layer3 -> layer4`. Branch 2 is an independent copy of that
    structure, `layer3_2 -> layer4_2`, fed by the same layer2 output.
    `forward(x)` takes a 3-channel image and returns `(x1, x2)`, each with
    2048 channels at 1/32 of the input resolution.
    """
    def __init__(self):
        super(ResNet, self).__init__()
        # self.inplanes = 128
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        split_planes = self.inplanes  # channels leaving layer2, input of both branches

        # Branch 1 keeps the original `layer3`/`layer4` names so existing
        # attribute access and state_dict keys are unchanged.
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)

        # Branch 2: forward always needed these modules, but they were never
        # created, so every call raised AttributeError. Rewind `inplanes` so
        # the branch starts from layer2's output width.
        self.inplanes = split_planes
        self.layer3_2 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4_2 = self._make_layer(Bottleneck, 512, 3, stride=2)

        # He-style normal init with fan-out n = k_h * k_w * out_channels.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` blocks of type `block`; the first may stride/downsample.

        Side effect: advances `self.inplanes` to `planes * block.expansion`.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)

        x1 = self.layer3(x)
        x1 = self.layer4(x1)

        x2 = self.layer3_2(x)
        x2 = self.layer4_2(x2)

        return x1, x2
# --------------------------------------------------------------------------------
# /models/backbone/VGGNet.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn


class B2_VGG(nn.Module):
    def __init__(self):
        super(B2_VGG, self).__init__()
8 | conv1 = nn.Sequential() 9 | conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1)) 10 | conv1.add_module('relu1_1', nn.ReLU(inplace=True)) 11 | conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1)) 12 | conv1.add_module('relu1_2', nn.ReLU(inplace=True)) 13 | 14 | self.conv1 = conv1 15 | conv2 = nn.Sequential() 16 | conv2.add_module('pool1', nn.AvgPool2d(2, stride=2)) 17 | conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1)) 18 | conv2.add_module('relu2_1', nn.ReLU()) 19 | conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1)) 20 | conv2.add_module('relu2_2', nn.ReLU()) 21 | self.conv2 = conv2 22 | 23 | conv3 = nn.Sequential() 24 | conv3.add_module('pool2', nn.AvgPool2d(2, stride=2)) 25 | conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1)) 26 | conv3.add_module('relu3_1', nn.ReLU()) 27 | conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1)) 28 | conv3.add_module('relu3_2', nn.ReLU()) 29 | conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1)) 30 | conv3.add_module('relu3_3', nn.ReLU()) 31 | self.conv3 = conv3 32 | 33 | conv4_1 = nn.Sequential() 34 | conv4_1.add_module('pool3_1', nn.AvgPool2d(2, stride=2)) 35 | conv4_1.add_module('conv4_1_1', nn.Conv2d(256, 512, 3, 1, 1)) 36 | conv4_1.add_module('relu4_1_1', nn.ReLU()) 37 | conv4_1.add_module('conv4_2_1', nn.Conv2d(512, 512, 3, 1, 1)) 38 | conv4_1.add_module('relu4_2_1', nn.ReLU()) 39 | conv4_1.add_module('conv4_3_1', nn.Conv2d(512, 512, 3, 1, 1)) 40 | conv4_1.add_module('relu4_3_1', nn.ReLU()) 41 | self.conv4_1 = conv4_1 42 | 43 | conv5_1 = nn.Sequential() 44 | conv5_1.add_module('pool4_1', nn.AvgPool2d(2, stride=2)) 45 | conv5_1.add_module('conv5_1_1', nn.Conv2d(512, 512, 3, 1, 1)) 46 | conv5_1.add_module('relu5_1_1', nn.ReLU()) 47 | conv5_1.add_module('conv5_2_1', nn.Conv2d(512, 512, 3, 1, 1)) 48 | conv5_1.add_module('relu5_2_1', nn.ReLU()) 49 | conv5_1.add_module('conv5_3_1', nn.Conv2d(512, 512, 3, 1, 1)) 50 | conv5_1.add_module('relu5_3_1', nn.ReLU()) 51 | self.conv5_1 = conv5_1 52 | 53 
| conv4_2 = nn.Sequential() 54 | conv4_2.add_module('pool3_2', nn.AvgPool2d(2, stride=2)) 55 | conv4_2.add_module('conv4_1_2', nn.Conv2d(256, 512, 3, 1, 1)) 56 | conv4_2.add_module('relu4_1_2', nn.ReLU()) 57 | conv4_2.add_module('conv4_2_2', nn.Conv2d(512, 512, 3, 1, 1)) 58 | conv4_2.add_module('relu4_2_2', nn.ReLU()) 59 | conv4_2.add_module('conv4_3_2', nn.Conv2d(512, 512, 3, 1, 1)) 60 | conv4_2.add_module('relu4_3_2', nn.ReLU()) 61 | self.conv4_2 = conv4_2 62 | 63 | conv5_2 = nn.Sequential() 64 | conv5_2.add_module('pool4_2', nn.AvgPool2d(2, stride=2)) 65 | conv5_2.add_module('conv5_1_2', nn.Conv2d(512, 512, 3, 1, 1)) 66 | conv5_2.add_module('relu5_1_2', nn.ReLU()) 67 | conv5_2.add_module('conv5_2_2', nn.Conv2d(512, 512, 3, 1, 1)) 68 | conv5_2.add_module('relu5_2_2', nn.ReLU()) 69 | conv5_2.add_module('conv5_3_2', nn.Conv2d(512, 512, 3, 1, 1)) 70 | conv5_2.add_module('relu5_3_2', nn.ReLU()) 71 | self.conv5_2 = conv5_2 72 | 73 | pre_train = torch.load('./Snapshots/pre_trained/vgg16-397923af.pth') 74 | self._initialize_weights(pre_train) 75 | 76 | def forward(self, x): 77 | x = self.conv1(x) 78 | x = self.conv2(x) 79 | x = self.conv3(x) 80 | x1 = self.conv4_1(x) 81 | x1 = self.conv5_1(x1) 82 | x2 = self.conv4_2(x) 83 | x2 = self.conv5_2(x2) 84 | return x1, x2 85 | 86 | def _initialize_weights(self, pre_train): 87 | keys = list(pre_train.keys()) 88 | self.conv1.conv1_1.weight.data.copy_(pre_train[keys[0]]) 89 | self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]]) 90 | self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]]) 91 | self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]]) 92 | self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]]) 93 | self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]]) 94 | self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]]) 95 | self.conv4_1.conv4_1_1.weight.data.copy_(pre_train[keys[14]]) 96 | self.conv4_1.conv4_2_1.weight.data.copy_(pre_train[keys[16]]) 97 | 
self.conv4_1.conv4_3_1.weight.data.copy_(pre_train[keys[18]]) 98 | self.conv5_1.conv5_1_1.weight.data.copy_(pre_train[keys[20]]) 99 | self.conv5_1.conv5_2_1.weight.data.copy_(pre_train[keys[22]]) 100 | self.conv5_1.conv5_3_1.weight.data.copy_(pre_train[keys[24]]) 101 | self.conv4_2.conv4_1_2.weight.data.copy_(pre_train[keys[14]]) 102 | self.conv4_2.conv4_2_2.weight.data.copy_(pre_train[keys[16]]) 103 | self.conv4_2.conv4_3_2.weight.data.copy_(pre_train[keys[18]]) 104 | self.conv5_2.conv5_1_2.weight.data.copy_(pre_train[keys[20]]) 105 | self.conv5_2.conv5_2_2.weight.data.copy_(pre_train[keys[22]]) 106 | self.conv5_2.conv5_3_2.weight.data.copy_(pre_train[keys[24]]) 107 | 108 | self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]]) 109 | self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]]) 110 | self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]]) 111 | self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]]) 112 | self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]]) 113 | self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]]) 114 | self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]]) 115 | self.conv4_1.conv4_1_1.bias.data.copy_(pre_train[keys[15]]) 116 | self.conv4_1.conv4_2_1.bias.data.copy_(pre_train[keys[17]]) 117 | self.conv4_1.conv4_3_1.bias.data.copy_(pre_train[keys[19]]) 118 | self.conv5_1.conv5_1_1.bias.data.copy_(pre_train[keys[21]]) 119 | self.conv5_1.conv5_2_1.bias.data.copy_(pre_train[keys[23]]) 120 | self.conv5_1.conv5_3_1.bias.data.copy_(pre_train[keys[25]]) 121 | self.conv4_2.conv4_1_2.bias.data.copy_(pre_train[keys[15]]) 122 | self.conv4_2.conv4_2_2.bias.data.copy_(pre_train[keys[17]]) 123 | self.conv4_2.conv4_3_2.bias.data.copy_(pre_train[keys[19]]) 124 | self.conv5_2.conv5_1_2.bias.data.copy_(pre_train[keys[21]]) 125 | self.conv5_2.conv5_2_2.bias.data.copy_(pre_train[keys[23]]) 126 | self.conv5_2.conv5_3_2.bias.data.copy_(pre_train[keys[25]]) 127 | -------------------------------------------------------------------------------- 
/models/backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/backbone/__init__.py -------------------------------------------------------------------------------- /models/backbone/__pycache__/Res2Net.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/backbone/__pycache__/Res2Net.cpython-35.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/Res2Net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/backbone/__pycache__/Res2Net.cpython-36.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/backbone/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /models/backbone/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/backbone/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/custom_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torchvision.transforms.functional import pad 9 | import numpy as np 10 | 11 | 12 | def calc_pad_same(in_siz, out_siz, stride, ksize): 13 | """Calculate same padding width. 14 | Args: 15 | ksize: kernel size [I, J]. 16 | Returns: 17 | pad_: Actual padding width. 18 | """ 19 | return (out_siz - 1) * stride + ksize - in_siz 20 | 21 | 22 | def conv2d_same(input, kernel, groups,bias=None,stride=1,padding=0,dilation=1): 23 | n, c, h, w = input.shape 24 | kout, ki_c_g, kh, kw = kernel.shape 25 | pw = calc_pad_same(w, w, 1, kw) 26 | ph = calc_pad_same(h, h, 1, kh) 27 | pw_l = pw // 2 28 | pw_r = pw - pw_l 29 | ph_t = ph // 2 30 | ph_b = ph - ph_t 31 | 32 | input_ = F.pad(input, (pw_l, pw_r, ph_t, ph_b)) 33 | result = F.conv2d(input_, kernel, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 34 | assert result.shape == input.shape 35 | return result 36 | 37 | 38 | def gradient_central_diff(input, cuda): 39 | return input, input 40 | kernel = [[1, 0, -1]] 41 | kernel_t = 0.5 * torch.Tensor(kernel) * -1. 
# pytorch implements correlation instead of conv 42 | if type(cuda) is int: 43 | if cuda != -1: 44 | kernel_t = kernel_t.cuda(device=cuda) 45 | else: 46 | if cuda is True: 47 | kernel_t = kernel_t.cuda() 48 | n, c, h, w = input.shape 49 | 50 | x = conv2d_same(input, kernel_t.unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), c) 51 | y = conv2d_same(input, kernel_t.t().unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), c) 52 | return x, y 53 | 54 | 55 | def compute_single_sided_diferences(o_x, o_y, input): 56 | # n,c,h,w 57 | #input = input.clone() 58 | o_y[:, :, 0, :] = input[:, :, 1, :].clone() - input[:, :, 0, :].clone() 59 | o_x[:, :, :, 0] = input[:, :, :, 1].clone() - input[:, :, :, 0].clone() 60 | # -- 61 | o_y[:, :, -1, :] = input[:, :, -1, :].clone() - input[:, :, -2, :].clone() 62 | o_x[:, :, :, -1] = input[:, :, :, -1].clone() - input[:, :, :, -2].clone() 63 | return o_x, o_y 64 | 65 | 66 | def numerical_gradients_2d(input, cuda=False): 67 | """ 68 | numerical gradients implementation over batches using torch group conv operator. 69 | the single sided differences are re-computed later. 
70 | it matches np.gradient(image) with the difference than here output=x,y for an image while there output=y,x 71 | :param input: N,C,H,W 72 | :param cuda: whether or not use cuda 73 | :return: X,Y 74 | """ 75 | n, c, h, w = input.shape 76 | assert h > 1 and w > 1 77 | x, y = gradient_central_diff(input, cuda) 78 | return x, y 79 | 80 | 81 | def convTri(input, r, cuda=False): 82 | """ 83 | Convolves an image by a 2D triangle filter (the 1D triangle filter f is 84 | [1:r r+1 r:-1:1]/(r+1)^2, the 2D version is simply conv2(f,f')) 85 | :param input: 86 | :param r: integer filter radius 87 | :param cuda: move the kernel to gpu 88 | :return: 89 | """ 90 | if (r <= 1): 91 | raise ValueError() 92 | n, c, h, w = input.shape 93 | return input 94 | f = list(range(1, r + 1)) + [r + 1] + list(reversed(range(1, r + 1))) 95 | kernel = torch.Tensor([f]) / (r + 1) ** 2 96 | if type(cuda) is int: 97 | if cuda != -1: 98 | kernel = kernel.cuda(device=cuda) 99 | else: 100 | if cuda is True: 101 | kernel = kernel.cuda() 102 | 103 | # padding w 104 | input_ = F.pad(input, (1, 1, 0, 0), mode='replicate') 105 | input_ = F.pad(input_, (r, r, 0, 0), mode='reflect') 106 | input_ = [input_[:, :, :, :r], input, input_[:, :, :, -r:]] 107 | input_ = torch.cat(input_, 3) 108 | t = input_ 109 | 110 | # padding h 111 | input_ = F.pad(input_, (0, 0, 1, 1), mode='replicate') 112 | input_ = F.pad(input_, (0, 0, r, r), mode='reflect') 113 | input_ = [input_[:, :, :r, :], t, input_[:, :, -r:, :]] 114 | input_ = torch.cat(input_, 2) 115 | 116 | output = F.conv2d(input_, 117 | kernel.unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), 118 | padding=0, groups=c) 119 | output = F.conv2d(output, 120 | kernel.t().unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), 121 | padding=0, groups=c) 122 | return output 123 | 124 | 125 | def compute_normal(E, cuda=False): 126 | if torch.sum(torch.isnan(E)) != 0: 127 | print('nans found here') 128 | import ipdb; 129 | ipdb.set_trace() 130 | E_ = convTri(E, 4, cuda) 131 | Ox, 
Oy = numerical_gradients_2d(E_, cuda) 132 | Oxx, _ = numerical_gradients_2d(Ox, cuda) 133 | Oxy, Oyy = numerical_gradients_2d(Oy, cuda) 134 | 135 | aa = Oyy * torch.sign(-(Oxy + 1e-5)) / (Oxx + 1e-5) 136 | t = torch.atan(aa) 137 | O = torch.remainder(t, np.pi) 138 | 139 | if torch.sum(torch.isnan(O)) != 0: 140 | print('nans found here') 141 | import ipdb; 142 | ipdb.set_trace() 143 | 144 | return O 145 | 146 | 147 | def compute_normal_2(E, cuda=False): 148 | if torch.sum(torch.isnan(E)) != 0: 149 | print('nans found here') 150 | import ipdb; 151 | ipdb.set_trace() 152 | E_ = convTri(E, 4, cuda) 153 | Ox, Oy = numerical_gradients_2d(E_, cuda) 154 | Oxx, _ = numerical_gradients_2d(Ox, cuda) 155 | Oxy, Oyy = numerical_gradients_2d(Oy, cuda) 156 | 157 | aa = Oyy * torch.sign(-(Oxy + 1e-5)) / (Oxx + 1e-5) 158 | t = torch.atan(aa) 159 | O = torch.remainder(t, np.pi) 160 | 161 | if torch.sum(torch.isnan(O)) != 0: 162 | print('nans found here') 163 | import ipdb; 164 | ipdb.set_trace() 165 | 166 | return O, (Oyy, Oxx) 167 | 168 | 169 | def compute_grad_mag(E, cuda=False): 170 | E_ = convTri(E, 4, cuda) 171 | Ox, Oy = numerical_gradients_2d(E_, cuda) 172 | mag = torch.sqrt(torch.mul(Ox,Ox) + torch.mul(Oy,Oy) + 1e-6) 173 | mag = mag / mag.max(); 174 | 175 | return mag 176 | -------------------------------------------------------------------------------- /models/deeplab/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/deeplab/__pycache__/aspp.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/aspp.cpython-35.pyc 
# --------------------------------------------------------------------------------
# /models/deeplab/__pycache__/aspp.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/aspp.cpython-36.pyc
# --------------------------------------------------------------------------------
# /models/deeplab/__pycache__/decoder.cpython-35.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/decoder.cpython-35.pyc
# --------------------------------------------------------------------------------
# /models/deeplab/__pycache__/decoder.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/decoder.cpython-36.pyc
# --------------------------------------------------------------------------------
# /models/deeplab/aspp.py:
# --------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d


def _init_module_weights(root):
    """Re-initialise every Conv2d / BatchNorm found in ``root.modules()``.

    Conv weights get Kaiming-normal init; synchronized and plain BatchNorm
    layers get weight 1 and bias 0.
    """
    for mod in root.modules():
        if isinstance(mod, nn.Conv2d):
            torch.nn.init.kaiming_normal_(mod.weight)
        elif isinstance(mod, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
            mod.weight.data.fill_(1)
            mod.bias.data.zero_()


class _ASPPModule(nn.Module):
    """One ASPP branch: (atrous) Conv2d -> BatchNorm -> ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                     stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        _init_module_weights(self)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Four parallel conv branches (one 1x1, three 3x3 atrous) plus an image-level
    pooling branch, each yielding 256 channels. The five outputs are
    concatenated (1280 channels), fused by a 1x1 conv + BatchNorm + ReLU and
    passed through Dropout(0.5).

    Args:
        backbone: 'drn' -> 512 input channels, 'mobilenet' -> 320, otherwise 2048.
        output_stride: 16 -> atrous rates (1, 6, 12, 18); 8 -> (1, 12, 24, 36);
            any other value raises NotImplementedError.
        BatchNorm: normalisation layer class, called as BatchNorm(num_channels).
    """

    def __init__(self, backbone, output_stride, BatchNorm):
        super(ASPP, self).__init__()
        inplanes = (512 if backbone == 'drn'
                    else 320 if backbone == 'mobilenet'
                    else 2048)
        if output_stride not in (16, 8):
            raise NotImplementedError
        # Rates follow (1, k, 2k, 3k) with k = 6 at stride 16 and k = 12 at stride 8.
        step = 6 if output_stride == 16 else 12
        dilations = [1, step, 2 * step, 3 * step]

        def branch(kernel_size, pad, rate):
            return _ASPPModule(inplanes, 256, kernel_size, padding=pad,
                               dilation=rate, BatchNorm=BatchNorm)

        # Construction/registration order is deliberate: it fixes RNG consumption
        # during layer creation and the traversal order of self.modules().
        self.aspp1 = branch(1, 0, dilations[0])
        self.aspp2 = branch(3, dilations[1], dilations[1])
        self.aspp3 = branch(3, dilations[2], dilations[2])
        self.aspp4 = branch(3, dilations[3], dilations[3])

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
        )
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)  # 5 branches x 256 channels
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        feats = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        # Image-level branch: upsample the pooled features back to the branch
        # resolution so they can be concatenated along channels.
        pooled = self.global_avg_pool(x)
        feats.append(F.interpolate(pooled, size=feats[-1].size()[2:],
                                   mode='bilinear', align_corners=True))
        fused = self.relu(self.bn1(self.conv1(torch.cat(feats, dim=1))))
        return self.dropout(fused)

    def _init_weight(self):
        _init_module_weights(self)


def build_aspp(backbone, output_stride, BatchNorm):
    """Factory returning ``ASPP(backbone, output_stride, BatchNorm)``."""
    return ASPP(backbone, output_stride, BatchNorm)

# --------------------------------------------------------------------------------
# /models/deeplab/backbone/__init__.py:
# --------------------------------------------------------------------------------
from models.deeplab.backbone import resnet, xception, drn, mobilenet


def build_backbone(backbone, output_stride, BatchNorm):
    """Instantiate the backbone selected by name.

    'resnet' -> ResNet101, 'xception' -> AlignedXception, 'drn' -> drn_d_54
    (which ignores output_stride), 'mobilenet' -> MobileNetV2.

    Raises:
        NotImplementedError: for any other name.
    """
    factories = (
        ('resnet', lambda: resnet.ResNet101(output_stride, BatchNorm)),
        ('xception', lambda: xception.AlignedXception(output_stride, BatchNorm)),
        ('drn', lambda: drn.drn_d_54(BatchNorm)),
        ('mobilenet', lambda: mobilenet.MobileNetV2(output_stride, BatchNorm)),
    )
    for name, make in factories:
        if backbone == name:
            return make()
    raise NotImplementedError

# --------------------------------------------------------------------------------
# /models/deeplab/backbone/__pycache__/__init__.cpython-35.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/__init__.cpython-35.pyc
# --------------------------------------------------------------------------------
# /models/deeplab/backbone/__pycache__/__init__.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/__init__.cpython-36.pyc
# --------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/drn.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/drn.cpython-35.pyc -------------------------------------------------------------------------------- /models/deeplab/backbone/__pycache__/drn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/drn.cpython-36.pyc -------------------------------------------------------------------------------- /models/deeplab/backbone/__pycache__/mobilenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/mobilenet.cpython-35.pyc -------------------------------------------------------------------------------- /models/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/deeplab/backbone/__pycache__/resnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/resnet.cpython-35.pyc -------------------------------------------------------------------------------- /models/deeplab/backbone/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/deeplab/backbone/__pycache__/xception.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/xception.cpython-35.pyc -------------------------------------------------------------------------------- /models/deeplab/backbone/__pycache__/xception.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/xception.cpython-36.pyc -------------------------------------------------------------------------------- /models/deeplab/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | def conv_bn(inp, oup, stride, BatchNorm): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | BatchNorm(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def fixed_padding(inputs, kernel_size, dilation): 17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 18 | pad_total = kernel_size_effective - 1 19 | pad_beg = pad_total // 2 20 | pad_end = pad_total - pad_beg 21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 22 | return padded_inputs 23 | 24 | 25 | class InvertedResidual(nn.Module): 26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 27 | super(InvertedResidual, 
self).__init__() 28 | self.stride = stride 29 | assert stride in [1, 2] 30 | 31 | hidden_dim = round(inp * expand_ratio) 32 | self.use_res_connect = self.stride == 1 and inp == oup 33 | self.kernel_size = 3 34 | self.dilation = dilation 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 40 | BatchNorm(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 44 | BatchNorm(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 50 | BatchNorm(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 58 | BatchNorm(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 63 | if self.use_res_connect: 64 | x = x + self.conv(x_pad) 65 | else: 66 | x = self.conv(x_pad) 67 | return x 68 | 69 | 70 | class MobileNetV2(nn.Module): 71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): 72 | super(MobileNetV2, self).__init__() 73 | block = InvertedResidual 74 | input_channel = 32 75 | current_stride = 1 76 | rate = 1 77 | interverted_residual_setting = [ 78 | # t, c, n, s 79 | [1, 16, 1, 1], 80 | [6, 24, 2, 2], 81 | [6, 32, 3, 2], 82 | [6, 64, 4, 2], 83 | [6, 96, 3, 1], 84 | [6, 160, 3, 2], 85 | [6, 320, 1, 1], 86 | ] 87 | 88 | # building first layer 89 | input_channel = int(input_channel * width_mult) 90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 91 | current_stride *= 2 92 | # building inverted residual blocks 93 | for t, c, n, s in interverted_residual_setting: 94 | if current_stride == output_stride: 95 
| stride = 1 96 | dilation = rate 97 | rate *= s 98 | else: 99 | stride = s 100 | dilation = 1 101 | current_stride *= s 102 | output_channel = int(c * width_mult) 103 | for i in range(n): 104 | if i == 0: 105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 106 | else: 107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 108 | input_channel = output_channel 109 | self.features = nn.Sequential(*self.features) 110 | self._initialize_weights() 111 | 112 | if pretrained: 113 | self._load_pretrained_model() 114 | 115 | self.low_level_features = self.features[0:4] 116 | self.high_level_features = self.features[4:] 117 | 118 | def forward(self, x): 119 | low_level_feat = self.low_level_features(x) 120 | x = self.high_level_features(low_level_feat) 121 | return x, low_level_feat 122 | 123 | def _load_pretrained_model(self): 124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 125 | model_dict = {} 126 | state_dict = self.state_dict() 127 | for k, v in pretrain_dict.items(): 128 | if k in state_dict: 129 | model_dict[k] = v 130 | state_dict.update(model_dict) 131 | self.load_state_dict(state_dict) 132 | 133 | def _initialize_weights(self): 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 138 | torch.nn.init.kaiming_normal_(m.weight) 139 | elif isinstance(m, SynchronizedBatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | elif isinstance(m, nn.BatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | 146 | if __name__ == "__main__": 147 | input = torch.rand(1, 3, 512, 512) 148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 149 | output, low_level_feat = model(input) 150 | print(output.size()) 151 | print(low_level_feat.size()) 152 | -------------------------------------------------------------------------------- /models/deeplab/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = BatchNorm(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.bn2 = BatchNorm(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = BatchNorm(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | self.dilation = dilation 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | 
return out 44 | 45 | class ResNet(nn.Module): 46 | 47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 48 | self.inplanes = 64 49 | super(ResNet, self).__init__() 50 | blocks = [1, 2, 4] 51 | if output_stride == 16: 52 | strides = [1, 2, 2, 1] 53 | dilations = [1, 1, 1, 2] 54 | elif output_stride == 8: 55 | strides = [1, 2, 1, 1] 56 | dilations = [1, 1, 2, 4] 57 | else: 58 | raise NotImplementedError 59 | 60 | # Modules 61 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = BatchNorm(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 72 | self._init_weight() 73 | 74 | if pretrained: 75 | self._load_pretrained_model() 76 | 77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 78 | downsample = None 79 | if stride != 1 or self.inplanes != planes * block.expansion: 80 | downsample = nn.Sequential( 81 | nn.Conv2d(self.inplanes, planes * block.expansion, 82 | kernel_size=1, stride=stride, bias=False), 83 | BatchNorm(planes * block.expansion), 84 | ) 85 | 86 | layers = [] 87 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 88 | self.inplanes = planes * block.expansion 89 | for i in range(1, blocks): 90 | layers.append(block(self.inplanes, planes, 
dilation=dilation, BatchNorm=BatchNorm)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 95 | downsample = None 96 | if stride != 1 or self.inplanes != planes * block.expansion: 97 | downsample = nn.Sequential( 98 | nn.Conv2d(self.inplanes, planes * block.expansion, 99 | kernel_size=1, stride=stride, bias=False), 100 | BatchNorm(planes * block.expansion), 101 | ) 102 | 103 | layers = [] 104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 105 | downsample=downsample, BatchNorm=BatchNorm)) 106 | self.inplanes = planes * block.expansion 107 | for i in range(1, len(blocks)): 108 | layers.append(block(self.inplanes, planes, stride=1, 109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, input): 114 | x = self.conv1(input) 115 | x = self.bn1(x) 116 | x = self.relu(x) 117 | x = self.maxpool(x) 118 | 119 | x = self.layer1(x) 120 | low_level_feat = x 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | return x, low_level_feat 125 | 126 | def _init_weight(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 131 | elif isinstance(m, SynchronizedBatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | 138 | def _load_pretrained_model(self): 139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 140 | model_dict = {} 141 | state_dict = self.state_dict() 142 | for k, v in pretrain_dict.items(): 143 | if k in state_dict: 144 | model_dict[k] = v 145 | state_dict.update(model_dict) 146 | self.load_state_dict(state_dict) 147 | 148 | def ResNet101(output_stride, BatchNorm, pretrained=True): 149 | """Constructs a ResNet-101 model. 150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=False) 154 | # model.features[0]=nn.Conv2d(1, 64, kernel_size=3, padding=1) 155 | return model 156 | 157 | if __name__ == "__main__": 158 | import torch 159 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 160 | input = torch.rand(1, 3, 512, 512) 161 | output, low_level_feat = model(input) 162 | print(output.size()) 163 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /models/deeplab/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, num_classes, backbone, BatchNorm): 9 | super(Decoder, self).__init__() 10 | if backbone == 'resnet' or backbone == 'drn': 11 | low_level_inplanes = 256 12 | elif backbone == 'xception': 13 | low_level_inplanes = 128 14 | elif backbone == 'mobilenet': 15 | low_level_inplanes = 24 16 | else: 17 | raise 
NotImplementedError 18 | 19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 20 | self.bn1 = BatchNorm(48) 21 | self.relu = nn.ReLU() 22 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 23 | BatchNorm(256), 24 | nn.ReLU(), 25 | nn.Dropout(0.5), 26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 27 | BatchNorm(256), 28 | nn.ReLU(), 29 | nn.Dropout(0.1), 30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 31 | self._init_weight() 32 | 33 | 34 | def forward(self, x, low_level_feat): 35 | low_level_feat = self.conv1(low_level_feat) 36 | low_level_feat = self.bn1(low_level_feat) 37 | low_level_feat = self.relu(low_level_feat) 38 | 39 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 40 | x = torch.cat((x, low_level_feat), dim=1) 41 | x = self.last_conv(x) 42 | 43 | return x 44 | 45 | def _init_weight(self): 46 | for m in self.modules(): 47 | if isinstance(m, nn.Conv2d): 48 | torch.nn.init.kaiming_normal_(m.weight) 49 | elif isinstance(m, SynchronizedBatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | elif isinstance(m, nn.BatchNorm2d): 53 | m.weight.data.fill_(1) 54 | m.bias.data.zero_() 55 | 56 | def build_decoder(num_classes, backbone, BatchNorm): 57 | return Decoder(num_classes, backbone, BatchNorm) -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/__pycache__/comm.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/comm.cpython-35.pyc -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/__pycache__/replicate.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/replicate.cpython-35.pyc -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/replicate.cpython-36.pyc -------------------------------------------------------------------------------- /models/deeplab/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe single-slot future, used only as a one-to-one pipe.

    One thread calls :meth:`put`; another blocks in :meth:`get` until the value
    arrives. The slot is cleared on every :meth:`get`.
    """

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        # The condition shares the lock, so put() and get() are mutually exclusive.
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` and wake one waiter; asserts the previous value was already fetched."""
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a value is stored (if none yet), then return it and clear the slot."""
        with self._lock:
            # NOTE(review): a single `if` relies on wait() only returning after
            # put() has stored a value; the threading.Condition docs recommend
            # re-checking the predicate in a `while` loop. Switching would make
            # get() block forever if put(None) were ever used, so keep in mind.
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


# Per-slave record kept by the master: holds the FutureResult the master writes to.
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        """Send ``msg`` to the master, wait for the reply, then acknowledge it.

        Returns:
            the value the master put into this slave's FutureResult.
        """
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        # Acknowledgement consumed by the final loop in SyncMaster.run_master.
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, data parallelism triggers a callback on each module;
      every slave device should call `register_slave(id)` to obtain a
      `SlavePipe` for communicating with the master.
    - During the forward pass, the master device invokes `run_master`; all
      messages from slave devices are collected and passed to the registered
      callback.
    - The callback's results are then sent back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback invoked with the list of
                ``(identifier, message)`` pairs collected from all devices
                (master first); it must return a list of
                ``(identifier, reply)`` pairs whose first entry belongs to the
                master (identifier 0).
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is pickled; queue/registry are rebuilt on unpickle.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        The first registration after a `run_master` call clears the previous
        registry (asserting the queue was fully drained).

        Args:
            identifier: an identifier, usually the device id.
        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.

        Messages are first collected from every device (including the master),
        then the callback computes the message to send back to each device
        (including the master).

        Args:
            master_msg: the message the master sends to itself. It is placed
                first (with identifier 0) when calling `master_callback`. See
                `_SynchronizedBatchNorm` for an example.
        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        # Blocks until every registered slave has sent its message via run_slave.
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Wait for every slave's acknowledgement (the `True` put by run_slave).
        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of currently registered slave devices."""
        return len(self._registry)

--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    """Empty attribute bag shared by all replicas of one sub-module."""
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback is invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Because all replicas are isomorphic, sub-modules at the same position in
    `modules()` order share one context object across devices; through it,
    different copies can share information. The callbacks on the master copy
    (the first copy) run before those of any slave copy.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    # Outer loop over replicas (master is index 0), inner over sub-modules, so
    # the master's callbacks all run first.
    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    The replication callback `__data_parallel_replicate__` of each module is
    invoked after the module is created by the original `replicate` function.
    The callback is invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object to add the replication callback.

    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    # Instance attribute shadows the bound method for this object only.
    data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /models/deeplab_v3p.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | sys.path.append('..') 6 | from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | from models.deeplab.aspp import build_aspp 8 | from models.deeplab.decoder import build_decoder 9 | from models.deeplab.backbone import build_backbone 10 | 11 | 12 | class DeepLab(nn.Module): 13 | def __init__(self, num_classes, backbone='resnet', output_stride=16, 14 | sync_bn=True, freeze_bn=False): 15 | super(DeepLab, self).__init__() 16 | if backbone == 'drn': 17 | output_stride = 8 18 | 19 | if sync_bn == True: 20 | BatchNorm = SynchronizedBatchNorm2d 21 | else: 22 | BatchNorm = nn.BatchNorm2d 23 | 24 | self.backbone = build_backbone(backbone, output_stride, BatchNorm) 25 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 26 | self.decoder = build_decoder(num_classes, backbone, BatchNorm) 27 | 28 | if freeze_bn: 29 | self.freeze_bn() 30 | 31 | def forward(self, input): 32 | x, low_level_feat = self.backbone(input) 33 | x = self.aspp(x) 34 | x = self.decoder(x, low_level_feat) 35 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 36 | 37 | return x 38 | 39 | def 
freeze_bn(self): 40 | for m in self.modules(): 41 | if isinstance(m, SynchronizedBatchNorm2d): 42 | m.eval() 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.eval() 45 | 46 | def get_1x_lr_params(self): 47 | modules = [self.backbone] 48 | for i in range(len(modules)): 49 | for m in modules[i].named_modules(): 50 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 51 | or isinstance(m[1], nn.BatchNorm2d): 52 | for p in m[1].parameters(): 53 | if p.requires_grad: 54 | yield p 55 | 56 | def get_10x_lr_params(self): 57 | modules = [self.aspp, self.decoder] 58 | for i in range(len(modules)): 59 | for m in modules[i].named_modules(): 60 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 61 | or isinstance(m[1], nn.BatchNorm2d): 62 | for p in m[1].parameters(): 63 | if p.requires_grad: 64 | yield p 65 | 66 | 67 | if __name__ == "__main__": 68 | import os 69 | os.environ['CUDA_VISIBLE_DEVICES'] = '5' 70 | 71 | # model = DeepLab(backbone='mobilenet', output_stride=16, num_classes=1).cuda() 72 | # model.eval() 73 | # input = torch.rand(1, 3, 512, 512).cuda() 74 | # output = model(input) 75 | # print(output.size()) 76 | 77 | 78 | -------------------------------------------------------------------------------- /models/denseunet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | from collections import OrderedDict 6 | 7 | 8 | class DenseUnet(nn.Module): 9 | def __init__(self, in_ch=3, num_classes=3, hybrid=False): 10 | super().__init__() 11 | self.hybrid = hybrid 12 | num_init_features = 96 13 | backbone = torchvision.models.densenet161(pretrained=False) 14 | self.first_convblock = nn.Sequential(OrderedDict([ 15 | ('conv0', nn.Conv2d(in_ch, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 16 | ('norm0', nn.BatchNorm2d(num_init_features)), 17 | ('relu0', 
nn.ReLU(inplace=True)), 18 | # ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 19 | ])) 20 | self.pool0 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 21 | 22 | self.denseblock1 = backbone.features.denseblock1 23 | self.transition1 = backbone.features.transition1 24 | self.denseblock2 = backbone.features.denseblock2 25 | self.transition2 = backbone.features.transition2 26 | self.denseblock3 = backbone.features.denseblock3 27 | self.transition3 = backbone.features.transition3 28 | self.denseblock4 = backbone.features.denseblock4 29 | self.bn5 = backbone.features.norm5 30 | self.relu = nn.ReLU(inplace=True) 31 | 32 | self.conv3 = nn.Conv2d(2112, 2208, kernel_size=1, stride=1) 33 | 34 | self.convblock43 = ConvBlock(2208, 768) 35 | self.convblock32 = ConvBlock(768, 384, kernel_size=3, stride=1, padding=1) 36 | self.convblock21 = ConvBlock(384, 96, kernel_size=3, stride=1, padding=1) 37 | self.convblock10 = ConvBlock(96, 96, kernel_size=3, stride=1, padding=1) 38 | self.convblock00 = ConvBlock(96, 64, kernel_size=3, stride=1, padding=1) 39 | self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1, stride=1) 40 | 41 | def forward(self, x): 42 | db0 = self.first_convblock(x) 43 | x = self.pool0(db0) 44 | db1 = self.denseblock1(x) 45 | x = self.transition1(db1) 46 | db2 = self.denseblock2(x) 47 | x = self.transition2(db2) 48 | db3 = self.denseblock3(x) 49 | x = self.transition3(db3) 50 | x = self.denseblock4(x) 51 | x = self.bn5(x) 52 | db4 = self.relu(x) 53 | 54 | up4 = torch.nn.functional.interpolate(db4, scale_factor=2, mode='bilinear', align_corners=True) 55 | db3 = self.conv3(db3) 56 | db43 = torch.add(up4, db3) 57 | db43 = self.convblock43(db43) 58 | 59 | up43 = torch.nn.functional.interpolate(db43, scale_factor=2, mode='bilinear', align_corners=True) 60 | db432 = torch.add(up43, db2) 61 | db432 = self.convblock32(db432) 62 | 63 | up432 = torch.nn.functional.interpolate(db432, scale_factor=2, mode='bilinear', align_corners=True) 64 | db4321 = 
torch.add(up432, db1) 65 | db4321 = self.convblock21(db4321) 66 | 67 | up4321 = torch.nn.functional.interpolate(db4321, scale_factor=2, mode='bilinear', align_corners=True) 68 | db43210 = torch.add(up4321, db0) 69 | db43210 = self.convblock10(db43210) 70 | 71 | up43210 = torch.nn.functional.interpolate(db43210, scale_factor=2, mode='bilinear', align_corners=True) 72 | db43210 = self.convblock00(up43210) 73 | 74 | out = self.final_conv(db43210) 75 | 76 | if self.hybrid: 77 | return db43210, out 78 | else: 79 | return out 80 | 81 | 82 | class ConvBlock(torch.nn.Module): 83 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): 84 | super().__init__() 85 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 86 | self.bn = nn.BatchNorm2d(out_channels) 87 | self.relu = nn.ReLU() 88 | 89 | def forward(self, input): 90 | x = self.conv1(input) 91 | x = self.bn(x) 92 | x = self.relu(x) 93 | return x 94 | 95 | 96 | 97 | class Dense_Block(nn.Module): 98 | ''' Build a dense_block where the output of each conv_block is fed to subsequent ones 99 | # Arguments 100 | x: input tensor 101 | stage: index for dense block 102 | nb_layers: the number of layers of conv_block to append to the model. 
103 | nb_filter: number of filters 104 | growth_rate: growth rate 105 | dropout_rate: dropout rate 106 | weight_decay: weight decay factor 107 | grow_nb_filters: flag to decide to allow number of filters to grow 108 | ''' 109 | def __init__(self, x, stage, nb_layers, nb_filter, growth_rate, dropout_rate=None, weight_decay=1e-4, 110 | grow_nb_filters=True): 111 | super().__init__() 112 | torchvision.models.densenet161() 113 | 114 | def forward(self, x): 115 | pass 116 | 117 | 118 | if __name__ == '__main__': 119 | # model = torchvision.models.densenet161() 120 | # for name, param in model.named_parameters(): 121 | # print(name) 122 | 123 | import os 124 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 125 | # model_bb = torchvision.models.densenet161(pretrained=True).cuda() 126 | 127 | # model = DenseUnet().cuda() 128 | # data = torch.randn((2, 3, 224, 224)).cuda() 129 | # pred = model(data) 130 | # print(pred.shape) 131 | 132 | 133 | -------------------------------------------------------------------------------- /models/init_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | def weights_init_normal(m): 6 | classname = m.__class__.__name__ 7 | #print(classname) 8 | if classname.find('Conv') != -1: 9 | init.normal_(m.weight.data, 0.0, 0.02) 10 | elif classname.find('Linear') != -1: 11 | init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | init.normal_(m.weight.data, 1.0, 0.02) 14 | init.constant_(m.bias.data, 0.0) 15 | 16 | 17 | def weights_init_xavier(m): 18 | classname = m.__class__.__name__ 19 | #print(classname) 20 | if classname.find('Conv') != -1: 21 | init.xavier_normal_(m.weight.data, gain=1) 22 | elif classname.find('Linear') != -1: 23 | init.xavier_normal_(m.weight.data, gain=1) 24 | elif classname.find('BatchNorm') != -1: 25 | init.normal_(m.weight.data, 1.0, 0.02) 26 | init.constant_(m.bias.data, 0.0) 27 
| 28 | 29 | def weights_init_kaiming(m): 30 | classname = m.__class__.__name__ 31 | #print(classname) 32 | if classname.find('Conv') != -1: 33 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 34 | elif classname.find('Linear') != -1: 35 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 36 | elif classname.find('BatchNorm') != -1: 37 | init.normal_(m.weight.data, 1.0, 0.02) 38 | init.constant_(m.bias.data, 0.0) 39 | 40 | 41 | def weights_init_orthogonal(m): 42 | classname = m.__class__.__name__ 43 | #print(classname) 44 | if classname.find('Conv') != -1: 45 | init.orthogonal_(m.weight.data, gain=1) 46 | elif classname.find('Linear') != -1: 47 | init.orthogonal_(m.weight.data, gain=1) 48 | elif classname.find('BatchNorm') != -1: 49 | init.normal_(m.weight.data, 1.0, 0.02) 50 | init.constant_(m.bias.data, 0.0) 51 | 52 | 53 | def init_weights(net, init_type='normal'): 54 | #print('initialization method [%s]' % init_type) 55 | if init_type == 'normal': 56 | net.apply(weights_init_normal) 57 | elif init_type == 'xavier': 58 | net.apply(weights_init_xavier) 59 | elif init_type == 'kaiming': 60 | net.apply(weights_init_kaiming) 61 | elif init_type == 'orthogonal': 62 | net.apply(weights_init_orthogonal) 63 | else: 64 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 65 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.init_weights import init_weights 5 | 6 | 7 | class unetConv2(nn.Module): 8 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 9 | super(unetConv2, self).__init__() 10 | self.n = n 11 | self.ks = ks 12 | self.stride = stride 13 | self.padding = padding 14 | s = stride 15 | p = padding 16 | if is_batchnorm: 17 | for i in range(1, n + 1): 18 | conv 
= nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 19 | nn.BatchNorm2d(out_size), 20 | nn.ReLU(inplace=True), ) 21 | setattr(self, 'conv%d' % i, conv) 22 | in_size = out_size 23 | 24 | else: 25 | for i in range(1, n + 1): 26 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 27 | nn.ReLU(inplace=True), ) 28 | setattr(self, 'conv%d' % i, conv) 29 | in_size = out_size 30 | 31 | # initialise the blocks 32 | for m in self.children(): 33 | init_weights(m, init_type='kaiming') 34 | 35 | def forward(self, inputs): 36 | x = inputs 37 | for i in range(1, self.n + 1): 38 | conv = getattr(self, 'conv%d' % i) 39 | x = conv(x) 40 | 41 | return x 42 | 43 | class unetUp(nn.Module): 44 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 45 | super(unetUp, self).__init__() 46 | # self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False) 47 | self.conv = unetConv2(out_size*2, out_size, False) 48 | if is_deconv: 49 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1) 50 | else: 51 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 52 | 53 | # initialise the blocks 54 | for m in self.children(): 55 | if m.__class__.__name__.find('unetConv2') != -1: continue 56 | init_weights(m, init_type='kaiming') 57 | 58 | def forward(self, inputs0, *input): 59 | # print(self.n_concat) 60 | # print(input) 61 | outputs0 = self.up(inputs0) 62 | for i in range(len(input)): 63 | outputs0 = torch.cat([outputs0, input[i]], 1) 64 | return self.conv(outputs0) 65 | 66 | class unetUp_origin(nn.Module): 67 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 68 | super(unetUp_origin, self).__init__() 69 | # self.conv = unetConv2(out_size*2, out_size, False) 70 | if is_deconv: 71 | self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False) 72 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1) 73 | else: 74 | self.conv = unetConv2(in_size + (n_concat - 2) * out_size, 
out_size, False) 75 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 76 | 77 | # initialise the blocks 78 | for m in self.children(): 79 | if m.__class__.__name__.find('unetConv2') != -1: continue 80 | init_weights(m, init_type='kaiming') 81 | 82 | def forward(self, inputs0, *input): 83 | # print(self.n_concat) 84 | # print(input) 85 | outputs0 = self.up(inputs0) 86 | for i in range(len(input)): 87 | outputs0 = torch.cat([outputs0, input[i]], 1) 88 | return self.conv(outputs0) 89 | -------------------------------------------------------------------------------- /models/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from config import cfg 7 | import torch.nn as nn 8 | from math import sqrt 9 | import torch 10 | from torch.autograd.function import InplaceFunction 11 | from itertools import repeat 12 | from torch.nn.modules import Module 13 | from torch.utils.checkpoint import checkpoint 14 | 15 | 16 | def Norm2d(in_channels): 17 | """ 18 | Custom Norm Function to allow flexible switching 19 | """ 20 | layer = getattr(cfg.MODEL,'BNFUNC') 21 | normalizationLayer = layer(in_channels) 22 | return normalizationLayer 23 | 24 | 25 | def initialize_weights(*models): 26 | for model in models: 27 | for module in model.modules(): 28 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 29 | nn.init.kaiming_normal(module.weight) 30 | if module.bias is not None: 31 | module.bias.data.zero_() 32 | elif isinstance(module, nn.BatchNorm2d): 33 | module.weight.data.fill_(1) 34 | module.bias.data.zero_() 35 | -------------------------------------------------------------------------------- /models/new2net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 
from backbone.origin.from_origin import Backbone_ResNet50_in3, Backbone_VGG16_in3
from module.BaseBlocks import BasicConv2d
from module.MyModule import AIM, SIM
from utils.tensor_ops import cus_sample, upsample_add


class MINet_VGG16(nn.Module):
    """Single-output-channel MINet variant built on a VGG16 backbone.

    The five encoder stages returned by ``Backbone_VGG16_in3`` run in sequence,
    their outputs are passed jointly through ``AIM``, and a top-down decoder then
    repeatedly fuses the deeper decoded feature with the next shallower one via
    ``upsample_add`` before refining it with ``SIM`` + ``BasicConv2d``.  A final
    1x1 convolution produces a 1-channel map (no activation is applied).
    """

    def __init__(self):
        super(MINet_VGG16, self).__init__()
        self.upsample_add = upsample_add
        self.upsample = cus_sample

        (
            self.encoder1,
            self.encoder2,
            self.encoder4,
            self.encoder8,
            self.encoder16,
        ) = Backbone_VGG16_in3()

        self.trans = AIM((64, 128, 256, 512, 512), (32, 64, 64, 64, 64))

        self.sim16 = SIM(64, 32)
        self.sim8 = SIM(64, 32)
        self.sim4 = SIM(64, 32)
        self.sim2 = SIM(64, 32)
        self.sim1 = SIM(32, 16)

        self.upconv16 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.upconv8 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.upconv4 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.upconv2 = BasicConv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.upconv1 = BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)

        self.classifier = nn.Conv2d(32, 1, 1)

    def forward(self, in_data):
        # Encoder: each stage consumes the previous stage's output.
        in_data_1 = self.encoder1(in_data)
        in_data_2 = self.encoder2(in_data_1)
        in_data_4 = self.encoder4(in_data_2)
        in_data_8 = self.encoder8(in_data_4)
        in_data_16 = self.encoder16(in_data_8)

        in_data_1, in_data_2, in_data_4, in_data_8, in_data_16 = self.trans(
            in_data_1, in_data_2, in_data_4, in_data_8, in_data_16
        )

        # Decoder: refine the deepest feature, then fuse it into each shallower
        # stage with ``upsample_add`` and refine again.
        out_data_16 = self.upconv16(self.sim16(in_data_16))

        out_data_8 = self.upsample_add(out_data_16, in_data_8)
        out_data_8 = self.upconv8(self.sim8(out_data_8))

        out_data_4 = self.upsample_add(out_data_8, in_data_4)
        out_data_4 = self.upconv4(self.sim4(out_data_4))

        out_data_2 = self.upsample_add(out_data_4, in_data_2)
        out_data_2 = self.upconv2(self.sim2(out_data_2))

        out_data_1 = self.upsample_add(out_data_2, in_data_1)
        out_data_1 = self.upconv1(self.sim1(out_data_1))

        out_data = self.classifier(out_data_1)

        return out_data


class MINet_Res50(nn.Module):
    """Single-output-channel MINet variant built on a ResNet-50 backbone.

    Same top-down decoder scheme as ``MINet_VGG16``, but the backbone exposes
    five stages (``div_2`` ... ``div_32``) and the decoder ends with an extra
    ``cus_sample`` upsampling by 2 before the last ``BasicConv2d``, since there is
    no encoder feature at that resolution to fuse with.
    """

    def __init__(self):
        super(MINet_Res50, self).__init__()
        self.div_2, self.div_4, self.div_8, self.div_16, self.div_32 = Backbone_ResNet50_in3()

        self.upsample_add = upsample_add
        self.upsample = cus_sample

        self.trans = AIM(iC_list=(64, 256, 512, 1024, 2048), oC_list=(64, 64, 64, 64, 64))

        self.sim32 = SIM(64, 32)
        self.sim16 = SIM(64, 32)
        self.sim8 = SIM(64, 32)
        self.sim4 = SIM(64, 32)
        self.sim2 = SIM(64, 32)

        self.upconv32 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.upconv16 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.upconv8 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.upconv4 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.upconv2 = BasicConv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.upconv1 = BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)

        self.classifier = nn.Conv2d(32, 1, 1)

    def forward(self, in_data):
        # Encoder: each stage consumes the previous stage's output.
        in_data_2 = self.div_2(in_data)
        in_data_4 = self.div_4(in_data_2)
        in_data_8 = self.div_8(in_data_4)
        in_data_16 = self.div_16(in_data_8)
        in_data_32 = self.div_32(in_data_16)

        in_data_2, in_data_4, in_data_8, in_data_16, in_data_32 = self.trans(
            in_data_2, in_data_4, in_data_8, in_data_16, in_data_32
        )

        # Decoder: refine the deepest feature, then fuse it into each shallower
        # stage with ``upsample_add`` and refine again.
        out_data_32 = self.upconv32(self.sim32(in_data_32))

        out_data_16 = self.upsample_add(out_data_32, in_data_16)
        out_data_16 = self.upconv16(self.sim16(out_data_16))

        out_data_8 = self.upsample_add(out_data_16, in_data_8)
        out_data_8 = self.upconv8(self.sim8(out_data_8))

        out_data_4 = self.upsample_add(out_data_8, in_data_4)
        out_data_4 = self.upconv4(self.sim4(out_data_4))

        out_data_2 = self.upsample_add(out_data_4, in_data_2)
        out_data_2 = self.upconv2(self.sim2(out_data_2))

        # No encoder feature exists at the last resolution: just upsample by 2.
        out_data_1 = self.upconv1(self.upsample(out_data_2, scale_factor=2))
        out_data = self.classifier(out_data_1)

        return out_data


if __name__ == "__main__":
    in_data = torch.randn((1, 3, 320, 320))
    net = MINet_VGG16()
    print(sum(x.nelement() for x in net.parameters()))
    # Smoke-test a forward pass with the example input.
    with torch.no_grad():
        print(net(in_data).shape)

# --------------------------------------------------------------------------------
# /models/norm.py
# --------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

from config import cfg
import torch.nn as nn
from math import sqrt
import torch
from torch.autograd.function import InplaceFunction
from itertools import repeat
from torch.nn.modules import Module
from torch.utils.checkpoint import checkpoint


def Norm2d(in_channels):
    """
    Custom Norm Function to allow flexible switching

    Builds the normalization layer class configured as ``cfg.MODEL.BNFUNC``
    for ``in_channels`` channels.
    """
    layer = cfg.MODEL.BNFUNC
    normalizationLayer = layer(in_channels)
    return normalizationLayer


def initialize_weights(*models):
    """Initialize the parameters of every module contained in ``models``.

    ``nn.Conv2d`` / ``nn.Linear`` weights get Kaiming-normal initialization and
    their biases (if any) are zeroed; ``nn.BatchNorm2d`` layers get weight 1 and
    bias 0. All other module types are left untouched.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # ``nn.init.kaiming_normal`` (no trailing underscore) is the
                # deprecated alias; the in-place form is the supported API.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has no weight/bias to reset.
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
# --------------------------------------------------------------------------------
/models/pretrain/SAMNet_with_ImageNet_pretrain.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/pretrain/SAMNet_with_ImageNet_pretrain.pth -------------------------------------------------------------------------------- /models/pretrain/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ModelBuilder, SegmentationModule, SAUNet, VGG19UNet, VGG19UNet_without_boudary, VGGUNet 2 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from lib.nn import SynchronizedBatchNorm2d 7 | 8 | try: 9 | from urllib import urlretrieve 10 | except ImportError: 11 | from urllib.request import urlretrieve 12 | 13 | 14 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101'] # resnet101 is coming soon! 
model_urls = {
    'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
    'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
    'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 conv + SynchronizedBatchNorm2d layers with a residual shortcut.

    ``downsample`` (optional) is applied to the input to form the shortcut when
    the main path changes resolution or channel count.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = SynchronizedBatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = SynchronizedBatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = SynchronizedBatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = SynchronizedBatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = SynchronizedBatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet with a three-conv "deep stem" (3 -> 64 -> 64 -> 128 channels).

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final ``nn.Linear``.

    ``forward`` pools with ``AvgPool2d(7)`` and then feeds a
    ``Linear(512 * block.expansion, ...)``, so the flatten only matches the fc
    layer when ``layer4`` outputs 7x7 maps (overall stride 32, e.g. 224x224
    inputs).
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        # Channel count entering layer1: the stem's conv3 outputs 128 channels.
        self.inplanes = 128
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = SynchronizedBatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = SynchronizedBatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = SynchronizedBatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-normal init for convs (fan-out based), identity init for norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the shortcut so it matches the block's output shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                SynchronizedBatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet18']))
    return model


# ResNet-34 is disabled: ``model_urls`` has no 'resnet34' entry.
# def resnet34(pretrained=False, **kwargs):
#     """Constructs a ResNet-34 model.
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#     """
#     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
#     if pretrained:
#         model.load_state_dict(load_url(model_urls['resnet34']))
#     return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
    return model

# def resnet152(pretrained=False, **kwargs):
#     """Constructs a ResNet-152 model.
#
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#     """
#     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
#     if pretrained:
#         model.load_state_dict(load_url(model_urls['resnet152']))
#     return model


def load_url(url, model_dir='./pretrained', map_location=None):
    """Download ``url`` into ``model_dir`` once, then ``torch.load`` the cached copy.

    The cache file name is the last path component of ``url``. The download
    goes to a ``.part`` file that is renamed into place only after it
    completes. An interrupted download therefore never leaves a truncated
    checkpoint that later calls would treat as cached.

    Args:
        url: location of the checkpoint (any scheme ``urlretrieve`` supports).
        model_dir: cache directory; created if missing.
        map_location: forwarded to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the cached file.
    """
    os.makedirs(model_dir, exist_ok=True)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        partial_file = cached_file + '.part'
        try:
            urlretrieve(url, partial_file)
            os.replace(partial_file, cached_file)
        finally:
            # Only non-empty if the download or rename failed.
            if os.path.exists(partial_file):
                os.remove(partial_file)
    return torch.load(cached_file, map_location=map_location)

# --------------------------------------------------------------------------------
# /models/segnet.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class SegNet(nn.Module):
    """SegNet-style encoder/decoder that unpools with the encoder's max indices.

    Args:
        input_nbr: number of input channels.
        label_nbr: number of output channels (raw scores; no activation).

    Unlike the usual SegNet, every pooling uses ``kernel_size=2, stride=1``, so
    each of the five encoder stages shrinks the spatial size by only 1 pixel,
    and the matching ``max_unpool2d`` restores it. The input therefore needs a
    height and width of at least 6.
    """

    def __init__(self, input_nbr, label_nbr):
        super(SegNet, self).__init__()

        batchNorm_momentum = 0.1

        # ---- encoder ----
        self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)

        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)

        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)

        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        # ---- decoder ----
        self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)

        self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)

        self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
        self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)

        self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
        self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1)
        # Kept for compatibility; forward() returns raw scores (see below).
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):

        # Stage 1
        x11 = F.relu(self.bn11(self.conv11(x)))
        x12 = F.relu(self.bn12(self.conv12(x11)))
        x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=1,return_indices=True)

        # Stage 2
        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22 = F.relu(self.bn22(self.conv22(x21)))
        x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=1,return_indices=True)

        # Stage 3
        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33 = F.relu(self.bn33(self.conv33(x32)))
        x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=1,return_indices=True)

        # Stage 4
        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43 = F.relu(self.bn43(self.conv43(x42)))
        x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=1,return_indices=True)

        # Stage 5
        x51 = F.relu(self.bn51(self.conv51(x4p)))
        x52 = F.relu(self.bn52(self.conv52(x51)))
        x53 = F.relu(self.bn53(self.conv53(x52)))
        x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=1,return_indices=True)

        # Stage 5d
        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=1)
        x53d = F.relu(self.bn53d(self.conv53d(x5d)))
        x52d = F.relu(self.bn52d(self.conv52d(x53d)))
        x51d = F.relu(self.bn51d(self.conv51d(x52d)))

        # Stage 4d
        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=1)
        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
        x41d = F.relu(self.bn41d(self.conv41d(x42d)))

        # Stage 3d
        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=1)
        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
        x31d = F.relu(self.bn31d(self.conv31d(x32d)))

        # Stage 2d
        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=1)
        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
        x21d = F.relu(self.bn21d(self.conv21d(x22d)))

        # Stage 1d
        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=1)
        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
        x11d = self.conv11d(x12d)
        # x11d = self.sigmoid(x11d)
        return x11d

# --------------------------------------------------------------------------------
# /models/test.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings(action='ignore')


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    """Input block: a single ``double_conv``."""

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    """2x2 max-pool followed by ``double_conv``."""

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    """Upsample ``x1``, pad it to ``x2``'s size, concat ``[x2, x1]``, then ``double_conv``."""

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Pad so odd input sizes still line up with the skip connection.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    """1x1 convolution head."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x


class GEncoder(nn.Module):
    """UNet-style encoder only: ``inconv`` followed by four ``down`` stages.

    ``forward`` returns the deepest feature map (1024 channels, spatial size
    divided by 16). ``n_classes`` is accepted for signature compatibility with
    ``UNet`` but is unused.

    Raises:
        NotImplementedError: from ``forward`` when ``deep_supervision`` is True.
            There is no decoder, so there are no side outputs to supervise.
    """

    def __init__(self, n_channels, n_classes, deep_supervision = False):
        super(GEncoder, self).__init__()
        self.deep_supervision = deep_supervision
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 1024)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        if self.deep_supervision:
            # The previous branch referenced decoder outputs (x0, x11, ...) and
            # dsoutc* heads that this encoder never defines, so it could only
            # fail with UnboundLocalError. Fail with an explicit message instead.
            raise NotImplementedError(
                "GEncoder has no decoder; deep_supervision=True is not supported")
        return x5


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ras = GEncoder(n_channels=1, n_classes=1).to(device)
    input_tensor = torch.randn(4, 1, 96, 96, device=device)
    out = ras(input_tensor)
    print(out[0].shape)

# --------------------------------------------------------------------------------
# /models/unet.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings(action='ignore')


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    """Input block: a single ``double_conv``."""

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    """2x2 max-pool followed by ``double_conv``."""

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    """Upsample ``x1``, pad it to ``x2``'s size, concat ``[x2, x1]``, then ``double_conv``."""

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Pad so odd input sizes still line up with the skip connection.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    """1x1 convolution head."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x


class UNet(nn.Module):
    """Standard 4-level UNet with optional deep supervision.

    Returns ``x0`` (``n_classes`` channels at input resolution). With
    ``deep_supervision=True`` it returns ``(x0, x11, x22, x33, x44)``. Each
    extra output is a 1x1-conv head on a decoder stage, bilinearly resized to
    ``x0``'s spatial size.
    """

    def __init__(self, n_channels, n_classes, deep_supervision = False):
        super(UNet, self).__init__()
        self.deep_supervision = deep_supervision
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

        # Side-output heads used only when deep_supervision is enabled.
        self.dsoutc4 = outconv(256, n_classes)
        self.dsoutc3 = outconv(128, n_classes)
        self.dsoutc2 = outconv(64, n_classes)
        self.dsoutc1 = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x44 = self.up1(x5, x4)
        x33 = self.up2(x44, x3)
        x22 = self.up3(x33, x2)
        x11 = self.up4(x22, x1)
        x0 = self.outc(x11)
        if self.deep_supervision:
            # align_corners=False is the PyTorch default; stated explicitly to
            # avoid the deprecation warning without changing results.
            size = x0.shape[2:]
            x11 = F.interpolate(self.dsoutc1(x11), size, mode='bilinear', align_corners=False)
            x22 = F.interpolate(self.dsoutc2(x22), size, mode='bilinear', align_corners=False)
            x33 = F.interpolate(self.dsoutc3(x33), size, mode='bilinear', align_corners=False)
            x44 = F.interpolate(self.dsoutc4(x44), size, mode='bilinear', align_corners=False)

            return x0, x11, x22, x33, x44
        else:
            return x0


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ras = UNet(n_channels=1, n_classes=1).to(device)
    input_tensor = torch.randn(4, 1, 96, 96, device=device)
    out = ras(input_tensor)
    print(out[0].shape)

# --------------------------------------------------------------------------------
# /models/wassp.py
# --------------------------------------------------------------------------------
import torch
from torchvision import models
import torch.nn as nn

# from resnet import resnet34
# import resnet
from torch.nn import functional as F


class ConvBnRelu(nn.Module):
    """Conv2d, then an optional normalization layer, then an optional ReLU.

    ``norm_layer`` is the normalization class instantiated with ``out_planes``
    when ``has_bn`` is True (default ``nn.BatchNorm2d``).
    """

    def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
                 groups=1, has_bn=True, norm_layer=nn.BatchNorm2d,
                 has_relu=True, inplace=True, has_bias=False):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize,
                              stride=stride, padding=pad,
                              dilation=dilation, groups=groups, bias=has_bias)
        self.has_bn = has_bn
        if self.has_bn:
            # Previously hard-coded to nn.BatchNorm2d, silently ignoring norm_layer.
            self.bn = norm_layer(out_planes)
        self.has_relu = has_relu
        if self.has_relu:
            self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.bn(x)
        if self.has_relu:
            x = self.relu(x)

        return x


class SAPP(nn.Module):
    """Pyramid of weight-shared dilated 3x3 convs fused by learned attention.

    One 3x3 kernel is applied at dilation 1, 2 and 4. Branches 1 and 2 are
    fused with a 2-way softmax attention map plus a sigmoid-gated term. That
    result is then fused with branch 3 in the same way. The output blends the
    fusion with the input via the learnable scalar ``gamma`` (initialized to 0)
    and passes through a final 1x1 ``ConvBnRelu``. Output channels equal
    ``in_channels``.
    """

    def __init__(self, in_channels):
        super(SAPP, self).__init__()
        self.conv3x3=nn.Conv2d(in_channels=in_channels, out_channels=in_channels,dilation=1,kernel_size=3, padding=1)

        self.bn=nn.ModuleList([nn.BatchNorm2d(in_channels),nn.BatchNorm2d(in_channels),nn.BatchNorm2d(in_channels)])
        # NOTE(review): only index [0] of the three ModuleLists below is used in
        # forward(); index [1] is allocated (and saved in state_dict) but unused.
        # Confirm whether the second fusion stage was meant to use [1].
        self.conv1x1=nn.ModuleList([nn.Conv2d(in_channels=2*in_channels, out_channels=in_channels,dilation=1,kernel_size=1, padding=0),
                                    nn.Conv2d(in_channels=2*in_channels, out_channels=in_channels,dilation=1,kernel_size=1, padding=0)])
        self.conv3x3_1=nn.ModuleList([nn.Conv2d(in_channels=in_channels, out_channels=in_channels//2,dilation=1,kernel_size=3, padding=1),
                                      nn.Conv2d(in_channels=in_channels, out_channels=in_channels//2,dilation=1,kernel_size=3, padding=1)])
        self.conv3x3_2=nn.ModuleList([nn.Conv2d(in_channels=in_channels//2, out_channels=2,dilation=1,kernel_size=3, padding=1),
                                      nn.Conv2d(in_channels=in_channels//2, out_channels=2,dilation=1,kernel_size=3, padding=1)])
        self.conv_last=ConvBnRelu(in_planes=in_channels,out_planes=in_channels,ksize=1,stride=1,pad=0,dilation=1)
        self.norm = nn.Sigmoid()
        self.conv1= nn.Conv2d(in_channels*2, 1, kernel_size=1, padding=0)
        self.dconv1=nn.Conv2d(in_channels*2, in_channels, kernel_size=1, padding=0)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.relu=nn.ReLU(inplace=True)

    def forward(self, x):
        # Three receptive fields from one shared kernel. The dilated branches
        # call F.conv2d with the weight only, so conv3x3's bias is not applied
        # to them.
        branches_1=self.conv3x3(x)
        branches_1=self.bn[0](branches_1)

        branches_2=F.conv2d(x,self.conv3x3.weight,padding=2,dilation=2)#share weight
        branches_2=self.bn[1](branches_2)

        branches_3=F.conv2d(x,self.conv3x3.weight,padding=4,dilation=4)#share weight
        branches_3=self.bn[2](branches_3)

        # ---- fuse branch 1 and branch 2 ----
        feat=torch.cat([branches_1,branches_2],dim=1)

        # Spatial sigmoid gate on the concatenation, projected back to C channels.
        feat_g =feat
        feat_g1 = self.relu(self.conv1(feat_g))
        feat_g1 = self.norm(feat_g1)

        out1 = feat_g * feat_g1
        out1 = self.dconv1(out1)

        # 2-way softmax attention choosing between branch 1 and branch 2.
        feat=self.relu(self.conv1x1[0](feat))
        feat=self.relu(self.conv3x3_1[0](feat))
        att=self.conv3x3_2[0](feat)
        att = F.softmax(att, dim=1)

        att_1=att[:,0,:,:].unsqueeze(1)
        att_2=att[:,1,:,:].unsqueeze(1)

        fusion_1_2=att_1*branches_1+att_2*branches_2 +out1

        # ---- fuse (branch 1+2) with branch 3 (same modules reused) ----
        feat1=torch.cat([fusion_1_2,branches_3],dim=1)

        feat_g =feat1
        feat_g1 = self.relu(self.conv1(feat_g))
        feat_g1 = self.norm(feat_g1)
        out2 = feat_g * feat_g1
        out2 = self.dconv1(out2)

        feat1=self.relu(self.conv1x1[0](feat1))
        feat1=self.relu(self.conv3x3_1[0](feat1))
        att1=self.conv3x3_2[0](feat1)
        att1 = F.softmax(att1, dim=1)

        att_1_2=att1[:,0,:,:].unsqueeze(1)
        att_3=att1[:,1,:,:].unsqueeze(1)

        # gamma starts at 0, so the block initially passes x through (after ReLU).
        ax=self.relu(self.gamma*(att_1_2*fusion_1_2+att_3*branches_3 +out2)+(1-self.gamma)*x)
        ax=self.conv_last(ax)

        return ax


# if __name__=='__main__':
#     x=torch.randn(1,512,6,6)
#
#     net = SAPP(512)
#     ax =net(x)
#
#     print(ax.shape)

# --------------------------------------------------------------------------------
# /models/wnet.py
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from models.multi_scale_module import FoldConv_aspp


class GateNet(nn.Module):
    def __init__(self):
        super(GateNet, self).__init__()
################################vgg16####################################### 12 | feats = list(models.vgg16_bn(pretrained=True).features.children()) 13 | feats[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 14 | self.conv1 = nn.Sequential(*feats[:6]) 15 | # print(self.conv1) 16 | self.conv2 = nn.Sequential(*feats[6:13]) 17 | self.conv3 = nn.Sequential(*feats[13:23]) 18 | self.conv4 = nn.Sequential(*feats[23:33]) 19 | self.conv5 = nn.Sequential(*feats[33:43]) 20 | ################################Gate####################################### 21 | self.attention_feature5 = nn.Sequential(nn.Conv2d(64+32, 2, kernel_size=3, padding=1)) 22 | self.attention_feature4 = nn.Sequential(nn.Conv2d(128+64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(), 23 | nn.Conv2d(64, 2, kernel_size=3, padding=1)) 24 | self.attention_feature3 = nn.Sequential(nn.Conv2d(256+128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 25 | nn.Conv2d(128, 2, kernel_size=3, padding=1)) 26 | self.attention_feature2 = nn.Sequential(nn.Conv2d(512+256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU(), 27 | nn.Conv2d(256, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(), 28 | nn.Conv2d(64, 2, kernel_size=3, padding=1)) 29 | self.attention_feature1 = nn.Sequential(nn.Conv2d(512+512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU(), 30 | nn.Conv2d(512, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 31 | nn.Conv2d(128, 2, kernel_size=3, padding=1)) 32 | ###############################Transition Layer######################################## 33 | self.dem1 = nn.Sequential(FoldConv_aspp(in_channel=512, 34 | out_channel=512, 35 | out_size=384 // 16, 36 | kernel_size=3, 37 | stride=1, 38 | padding=2, 39 | dilation=2, 40 | win_size=2, 41 | win_padding=0, 42 | 43 | ), nn.BatchNorm2d(512), nn.PReLU()) 44 | self.dem2 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), 
nn.PReLU()) 45 | self.dem3 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 46 | self.dem4 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 47 | self.dem5 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU()) 48 | ################################FPN branch####################################### 49 | self.output1 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 50 | self.output2 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 51 | self.output3 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 52 | self.output4 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, padding=1),nn.BatchNorm2d(32), nn.PReLU()) 53 | self.output5 = nn.Sequential(nn.Conv2d(32, 1, kernel_size=3, padding=1)) 54 | ################################Parallel branch####################################### 55 | self.dem1_1 = nn.Sequential(nn.Conv2d(512, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU()) 56 | self.dem2_1 = nn.Sequential(nn.Conv2d(256, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU()) 57 | self.dem3_1 = nn.Sequential(nn.Conv2d(128, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU()) 58 | self.dem4_1 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU()) 59 | self.dem5_1 = nn.Sequential(nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU()) 60 | self.out_res = nn.Sequential(nn.Conv2d(32+32+32+32+32+1, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(), 61 | nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(), 62 | nn.Conv2d(64, 1, kernel_size=3, padding=1)) 63 | ####################################################################### 64 | 65 | 66 | for m in self.modules(): 67 | if isinstance(m, 
nn.ReLU) or isinstance(m, nn.Dropout): 68 | m.inplace = True 69 | 70 | 71 | def forward(self, x): 72 | input = x 73 | B,_,_,_ = input.size() 74 | ################################Encoder block####################################### 75 | E1 = self.conv1(x) 76 | E2 = self.conv2(E1) 77 | E3 = self.conv3(E2) 78 | E4 = self.conv4(E3) 79 | E5 = self.conv5(E4) 80 | ################################Transition Layer####################################### 81 | T5 = self.dem1(E5) 82 | T4 = self.dem2(E4) 83 | T3 = self.dem3(E3) 84 | T2 = self.dem4(E2) 85 | T1 = self.dem5(E1) 86 | ################################Gated FPN####################################### 87 | G5 = self.attention_feature1(torch.cat((E5,T5),1)) 88 | G5 = F.adaptive_avg_pool2d(F.sigmoid(G5),1) 89 | D5 = self.output1(G5[:, 0,:,:].unsqueeze(1).repeat(1,512,1,1)*T5) 90 | 91 | G4 = self.attention_feature2(torch.cat((E4,F.upsample(D5, size=E4.size()[2:], mode='bilinear')),1)) 92 | G4 = F.adaptive_avg_pool2d(F.sigmoid(G4),1) 93 | D4 = self.output2(F.upsample(D5, size=E4.size()[2:], mode='bilinear')+G4[:, 0,:,:].unsqueeze(1).repeat(1,256,1,1)*T4) 94 | 95 | G3 = self.attention_feature3(torch.cat((E3,F.upsample(D4, size=E3.size()[2:], mode='bilinear')),1)) 96 | G3 = F.adaptive_avg_pool2d(F.sigmoid(G3),1) 97 | D3 = self.output3(F.upsample(D4, size=E3.size()[2:], mode='bilinear')+G3[:, 0,:,:].unsqueeze(1).repeat(1,128,1,1)*T3) 98 | 99 | G2 = self.attention_feature4(torch.cat((E2,F.upsample(D3, size=E2.size()[2:], mode='bilinear')),1)) 100 | G2 = F.adaptive_avg_pool2d(F.sigmoid(G2),1) 101 | D2 = self.output4(F.upsample(D3, size=E2.size()[2:], mode='bilinear')+G2[:, 0,:,:].unsqueeze(1).repeat(1,64,1,1)*T2) 102 | 103 | G1 = self.attention_feature5(torch.cat((E1,F.upsample(D2, size=E1.size()[2:], mode='bilinear')),1)) 104 | G1 = F.adaptive_avg_pool2d(F.sigmoid(G1),1) 105 | D1 = self.output5(F.upsample(D2, size=E1.size()[2:], mode='bilinear')+G1[:, 0,:,:].unsqueeze(1).repeat(1,32,1,1)*T1) 106 | 107 | 108 | 
################################Gated Parallel&Dual branch residual fuse####################################### 109 | R5 = self.dem1_1(T5) 110 | R4 = self.dem2_1(T4) 111 | R3 = self.dem3_1(T3) 112 | R2 = self.dem4_1(T2) 113 | R1 = self.dem5_1(T1) 114 | output_res = self.out_res(torch.cat((D1,F.upsample(G5[:, 1,:,:].unsqueeze(1).repeat(1,32,1,1)*R5,size=E1.size()[2:], mode='bilinear'),F.upsample(G4[:, 1,:,:].unsqueeze(1).repeat(1,32,1,1)*R4,size=E1.size()[2:], mode='bilinear'),F.upsample(G3[:, 1,:,:].unsqueeze(1).repeat(1,32,1,1)*R3,size=E1.size()[2:], mode='bilinear'),F.upsample(G2[:, 1,:,:].unsqueeze(1).repeat(1,32,1,1)*R2,size=E1.size()[2:], mode='bilinear'),F.upsample(G1[:, 1,:,:].unsqueeze(1).repeat(1,32,1,1)*R1,size=E1.size()[2:], mode='bilinear')),1)) 115 | output_res = F.upsample(output_res,size=input.size()[2:], mode='bilinear') 116 | output_fpn = F.upsample(D1, size=input.size()[2:], mode='bilinear') 117 | pre_sal = output_fpn+output_res 118 | # print(pre_sal.shape) 119 | ####################################################################### 120 | # if self.training: 121 | # return output_fpn, pre_sal 122 | return output_fpn, pre_sal 123 | 124 | if __name__ == "__main__": 125 | model = GateNet() 126 | input = torch.autograd.Variable(torch.randn(4, 3, 384, 384)) 127 | output = model(input) 128 | print(output.shape) 129 | -------------------------------------------------------------------------------- /trian_CGR_XS.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | # from losses import * 3 | from data.dataloader import XSDataset, XSDatatest 4 | import torch.nn as nn 5 | import torch 6 | 7 | from models.unet import UNet 8 | 9 | from CGRmodes.CGR import CGRNet 10 | # from models.vggunet import VGGUNet 11 | # from utils.metric import * 12 | from torchvision.transforms import transforms 13 | from evaluation import * 14 | # from utils.metric import * 15 | import torch.nn.functional as F 16 | # 
from models.newnet import FastSal 17 | import tqdm 18 | 19 | def iou_loss(pred, mask): 20 | weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask) 21 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 22 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3)) 23 | 24 | pred = torch.sigmoid(pred) 25 | inter = ((pred*mask)*weit).sum(dim=(2,3)) 26 | union = ((pred+mask)*weit).sum(dim=(2,3)) 27 | wiou = 1-(inter+1)/(union-inter+1) 28 | loss_total= (wbce+wiou).mean()/wiou.size(0) 29 | return loss_total 30 | 31 | 32 | class FocalLoss(nn.Module): 33 | def __init__(self, alpha=0.3, gamma=2, logits=True, reduce=True): 34 | super(FocalLoss, self).__init__() 35 | self.alpha = alpha 36 | self.gamma = gamma 37 | self.logits = logits 38 | self.reduce = reduce 39 | 40 | def forward(self, inputs, targets): 41 | if self.logits: 42 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False) 43 | else: 44 | BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False) 45 | pt = torch.exp(-BCE_loss) 46 | F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss 47 | 48 | if self.reduce: 49 | return torch.mean(F_loss) 50 | else: 51 | return F_loss 52 | 53 | 54 | 55 | def test(testLoader,fold, net, device): 56 | net.to(device) 57 | sig = torch.nn.Sigmoid() 58 | print('start test!') 59 | net.eval() 60 | with torch.no_grad(): 61 | # when in test stage, no grad 62 | acc = 0. # Accuracy 63 | SE = 0. # Sensitivity (Recall) 64 | SP = 0. # Specificity 65 | PC = 0. # Precision 66 | F1 = 0. # F1 Score 67 | JS = 0. # Jaccard Similarity 68 | DC = 0. 
# Dice Coefficient 69 | count = 0 70 | for image, label, path in tqdm.tqdm(testLoader): 71 | image = image.to(device=device, dtype=torch.float32) 72 | label = label.to(device=device, dtype=torch.float32) 73 | # pred,p1,p2,p3,p4,e= net(image) 74 | e1,p2= net(image) 75 | pred = sig(p2) 76 | # print(pred.shape) 77 | acc += get_accuracy(pred,label) 78 | SE += get_sensitivity(pred,label) 79 | SP += get_specificity(pred,label) 80 | PC += get_precision(pred,label) 81 | F1 += get_F1(pred,label) 82 | JS += get_JS(pred,label) 83 | DC += get_DC(pred,label) 84 | count+=1 85 | acc = acc/count 86 | SE = SE/count 87 | SP = SP/count 88 | PC = PC/count 89 | F1 = F1/count 90 | JS = JS/count 91 | DC = DC/count 92 | score = JS + DC 93 | print("\tacc: {:.4f}\tSE: {:.4f}\tSP: {:.4f}\tPC: {:.4f}\tF1: {:.4f} \tJS: {:.4f}".format(acc, SE, SP, PC, F1, JS)) 94 | return acc, SE, SP, PC, F1, JS, DC, score 95 | 96 | 97 | 98 | 99 | def train_net(net, device, data_path,test_data_path, fold, epochs=40, batch_size=4, lr=0.00001): 100 | # 加载训练集 101 | isbi_dataset = XSDataset(data_path) 102 | train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset, 103 | batch_size=batch_size, 104 | shuffle=True, 105 | drop_last=True) 106 | 107 | test_dataset = XSDatatest(test_data_path) 108 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 109 | batch_size=1, 110 | shuffle=False) 111 | # 定义RMSprop算法11 112 | optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9) 113 | # optimizer = optim.Adam(net.parameters(), lr=lr) 114 | 115 | # 定义Loss算法 116 | 117 | # criterion2 = nn.BCEWithLogitsLoss() 118 | # criterion3 = nn.BCELoss() 119 | criterion2 = FocalLoss() 120 | 121 | # best_loss统计,初始化为正无穷 122 | best_loss = float('inf') 123 | result = 0 124 | # 训练epochs次 125 | # f = open('./segmentation2/UNet.csv', 'w') 126 | # f.write('epoch,loss'+'\n') 127 | for epoch in range(epochs): 128 | # 训练模式 129 | net.train() 130 | # 按照batch_size开始训练 131 | for image, label, edge in 
train_loader: 132 | 133 | optimizer.zero_grad() 134 | # 将数据拷贝到device中 135 | image = image.to(device=device, dtype=torch.float32) 136 | label = label.to(device=device, dtype=torch.float32) 137 | edge = edge.to(device=device, dtype=torch.float32) 138 | # 使用网络参数,输出预测结果 139 | # pred, p1,p2,p3,p4,e= net(image) 140 | # # 计算loss 141 | e1,p2= net(image) 142 | loss = iou_loss(p2, label)+ criterion2(e1, edge) 143 | 144 | print('Train Epoch:{}'.format(epoch)) 145 | print('Loss/train', loss.item()) 146 | # print('Loss/edge', loss1.item()) 147 | 148 | # 保存loss值最小的网络参数 149 | if loss < best_loss: 150 | best_loss = loss 151 | # torch.save(net.state_dict(), './LUNG/fff'+str(fold)+'.pth') 152 | # 更新参数 153 | loss.backward() 154 | optimizer.step() 155 | # f.write(str(epoch)+","+str(best_loss.item())+"\n") 156 | if epoch>0: 157 | acc, SE, SP, PC, F1, JS, DC, score=test(test_loader,fold, net, device) 158 | if result < score: 159 | result = score 160 | # best_epoch = epoch 161 | torch.save(net.state_dict(), '/home/wangkun/BPGL/result/EPGNet/XS/EPNet_best_'+str(fold)+'.pth') 162 | with open("/home/wangkun/BPGL/result/EPGNet/XS/EPNet_"+str(fold)+".csv", "a") as w: 163 | w.write("epoch="+str(epoch)+",acc="+str(acc)+", SE="+str(SE)+",SP="+str(SP)+",PC="+str(PC)+",F1="+str(F1)+",JS="+str(JS)+",DC="+str(DC)+",Score="+str(score)+"\n") 164 | 165 | 166 | if __name__ == "__main__": 167 | import os 168 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 169 | def weights_init(m): 170 | classname = m.__class__.__name__ 171 | # print(classname) 172 | if classname.find('Conv') != -1: 173 | torch.nn.init.xavier_uniform_(m.weight.data) 174 | if m.bias is not None: 175 | torch.nn.init.constant_(m.bias.data, 0.0) 176 | seed=1234 177 | torch.manual_seed(seed) 178 | torch.cuda.manual_seed_all(seed) 179 | 180 | # 选择设备,有cuda用cuda,没有就用cpu 181 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 182 | # 加载网络,图片单通道1,分类为1。 183 | # net = UNet(n_channels=1, n_classes=1) 184 | # net = MyMGNet(n_channels=1, 
n_classes=1) 185 | net = CGRNet(n_channels=3, n_classes=1) 186 | # net = VGG19UNet_without_boudary(n_channels=1, n_classes=1) 187 | # net = R2U_Net(img_ch=1, output_ch=1) 188 | # net = CE_Net(num_classes=1, num_channels=1) 189 | # net = VGG19UNet(n_channels=1, n_classes=1) 190 | # net = AttU_Net(img_ch=1, output_ch=1) 191 | # net =get_fcn8s(n_class=1) 192 | # net = UNet_2Plus(in_channels=1, n_classes=1) 193 | # net = DenseUnet(in_ch=1, num_classes=1) 194 | # net = CPFNet() 195 | net.to(device=device) 196 | # 指定训练集地址,开始训练 197 | fold=1 198 | # data_path = "/home/wangkun/shape-attentive-unet/data/train_96" 199 | # data_path = "/home/wangkun/shape-attentive-unet/data/ISIC/train" 200 | # test_data_path = "/home/wangkun/shape-attentive-unet/data/ISIC/test" 201 | # data_path = "/home/wangkun/shape-attentive-unet/data/LUNG/train" 202 | # test_data_path = "/home/wangkun/shape-attentive-unet/data/LUNG/test" 203 | data_path = "/home/wangkun/data/dataset/XS/Train/" 204 | test_data_path = "/home/wangkun/data/dataset/XS/val/" 205 | 206 | 207 | 208 | 209 | train_net(net, device, data_path,test_data_path, fold) 210 | 211 | 212 | 213 | 214 | 215 | #by kun wang --------------------------------------------------------------------------------