├── CGRmodes
│   ├── CGR.py
│   ├── __pycache__
│   │   ├── CGR.cpython-36.pyc
│   │   ├── EPG.cpython-35.pyc
│   │   ├── EPG.cpython-36.pyc
│   │   ├── EPG2.cpython-36.pyc
│   │   ├── EPGbaseline.cpython-35.pyc
│   │   ├── layers.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── layers.py
│   └── utils.py
├── README.md
├── __pycache__
│   ├── evaluation.cpython-36.pyc
│   └── transform.cpython-36.pyc
├── data
│   ├── __pycache__
│   │   ├── dataloader.cpython-36.pyc
│   │   └── dataloader2.cpython-36.pyc
│   ├── dataloader.py
│   ├── dataloader2.py
│   └── utils
│       ├── comm.py
│       ├── dice3D.py
│       ├── direct_field
│       │   ├── __pycache__
│       │   │   ├── df_cardia.cpython-36.pyc
│       │   │   ├── df_cardia.cpython-37.pyc
│       │   │   ├── utils_df.cpython-36.pyc
│       │   │   └── utils_df.cpython-37.pyc
│       │   ├── df_cardia.py
│       │   └── utils_df.py
│       ├── image_list.py
│       ├── init_net.py
│       ├── metrics.py
│       ├── utils_loss.py
│       └── vis_utils.py
├── evaluation.py
├── fig
│   ├── 1.png
│   ├── 2.png
│   └── 3.png
├── models
│   ├── ._models.py
│   ├── AttU_Net_model.py
│   ├── BaseNet.py
│   ├── F3net.py
│   ├── GSConv.py
│   ├── InfNet_Res2Net.py
│   ├── LDF.py
│   ├── LDunet.py
│   ├── PraNet_Res2Net.py
│   ├── PraNet_ResNet.py
│   ├── R2U_Net_model.py
│   ├── Res2Net_v1b.py
│   ├── UNet_2Plus.py
│   ├── __pycache__
│   │   ├── AttU_Net_model.cpython-35.pyc
│   │   ├── BaseNet.cpython-35.pyc
│   │   ├── BaseNet.cpython-36.pyc
│   │   ├── F3net.cpython-35.pyc
│   │   ├── GSConv.cpython-35.pyc
│   │   ├── GSConv.cpython-36.pyc
│   │   ├── InfNet_Res2Net.cpython-35.pyc
│   │   ├── InfNet_Res2Net.cpython-36.pyc
│   │   ├── LDF.cpython-35.pyc
│   │   ├── LDunet.cpython-35.pyc
│   │   ├── PraNet_Res2Net.cpython-35.pyc
│   │   ├── R2U_Net_model.cpython-35.pyc
│   │   ├── Res2Net_v1b.cpython-35.pyc
│   │   ├── UNet_2Plus.cpython-35.pyc
│   │   ├── UNet_2Plus.cpython-36.pyc
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── attention_blocks.cpython-35.pyc
│   │   ├── attention_blocks.cpython-36.pyc
│   │   ├── cenet.cpython-35.pyc
│   │   ├── cenet.cpython-36.pyc
│   │   ├── custom_functions.cpython-35.pyc
│   │   ├── custom_functions.cpython-36.pyc
│   │   ├── deeplab_v3p.cpython-35.pyc
│   │   ├── denseunet_model.cpython-35.pyc
│   │   ├── fcn.cpython-35.pyc
│   │   ├── fcn.cpython-36.pyc
│   │   ├── init_weights.cpython-35.pyc
│   │   ├── init_weights.cpython-36.pyc
│   │   ├── layers.cpython-35.pyc
│   │   ├── layers.cpython-36.pyc
│   │   ├── models.cpython-35.pyc
│   │   ├── models.cpython-36.pyc
│   │   ├── multi_scale.cpython-35.pyc
│   │   ├── multi_scale.cpython-36.pyc
│   │   ├── multi_scale_module.cpython-35.pyc
│   │   ├── newnet.cpython-35.pyc
│   │   ├── norm.cpython-35.pyc
│   │   ├── norm.cpython-36.pyc
│   │   ├── resnet.cpython-35.pyc
│   │   ├── resnet.cpython-36.pyc
│   │   ├── unet.cpython-35.pyc
│   │   ├── unet.cpython-36.pyc
│   │   ├── vggunet.cpython-35.pyc
│   │   ├── wassp.cpython-35.pyc
│   │   ├── wassp.cpython-36.pyc
│   │   └── wnet.cpython-35.pyc
│   ├── adaptive_avgmax_pool.py
│   ├── attention_blocks.py
│   ├── backbone
│   │   ├── DenseNet.py
│   │   ├── Res2Net.py
│   │   ├── ResNet.py
│   │   ├── VGGNet.py
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       ├── Res2Net.cpython-35.pyc
│   │       ├── Res2Net.cpython-36.pyc
│   │       ├── __init__.cpython-35.pyc
│   │       └── __init__.cpython-36.pyc
│   ├── cenet.py
│   ├── custom_functions.py
│   ├── deeplab
│   │   ├── LDF.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── aspp.cpython-35.pyc
│   │   │   ├── aspp.cpython-36.pyc
│   │   │   ├── decoder.cpython-35.pyc
│   │   │   └── decoder.cpython-36.pyc
│   │   ├── aspp.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-35.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── drn.cpython-35.pyc
│   │   │   │   ├── drn.cpython-36.pyc
│   │   │   │   ├── mobilenet.cpython-35.pyc
│   │   │   │   ├── mobilenet.cpython-36.pyc
│   │   │   │   ├── resnet.cpython-35.pyc
│   │   │   │   ├── resnet.cpython-36.pyc
│   │   │   │   ├── xception.cpython-35.pyc
│   │   │   │   └── xception.cpython-36.pyc
│   │   │   ├── drn.py
│   │   │   ├── mobilenet.py
│   │   │   ├── resnet.py
│   │   │   └── xception.py
│   │   ├── decoder.py
│   │   └── sync_batchnorm
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-35.pyc
│   │       │   ├── __init__.cpython-36.pyc
│   │       │   ├── batchnorm.cpython-35.pyc
│   │       │   ├── batchnorm.cpython-36.pyc
│   │       │   ├── comm.cpython-35.pyc
│   │       │   ├── comm.cpython-36.pyc
│   │       │   ├── replicate.cpython-35.pyc
│   │       │   └── replicate.cpython-36.pyc
│   │       ├── batchnorm.py
│   │       ├── comm.py
│   │       ├── replicate.py
│   │       └── unittest.py
│   ├── deeplab_v3p.py
│   ├── denseunet_model.py
│   ├── fcn.py
│   ├── init_weights.py
│   ├── layers.py
│   ├── models.py
│   ├── multi_scale.py
│   ├── multi_scale_module.py
│   ├── mynn.py
│   ├── net.py
│   ├── new2net.py
│   ├── newnet.py
│   ├── norm.py
│   ├── pretrain
│   │   ├── SAMNet_with_ImageNet_pretrain.pth
│   │   └── __init__.py
│   ├── resnet.py
│   ├── segnet.py
│   ├── test.py
│   ├── unet.py
│   ├── vggunet.py
│   ├── wassp.py
│   └── wnet.py
├── test_XS2021.py
├── transform.py
└── trian_CGR_XS.py
/CGRmodes/__pycache__/CGR.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/CGR.cpython-36.pyc
--------------------------------------------------------------------------------
/CGRmodes/__pycache__/EPG.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/EPG.cpython-35.pyc
--------------------------------------------------------------------------------
/CGRmodes/__pycache__/EPG.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/EPG.cpython-36.pyc
--------------------------------------------------------------------------------
/CGRmodes/__pycache__/EPG2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/EPG2.cpython-36.pyc
--------------------------------------------------------------------------------
/CGRmodes/__pycache__/EPGbaseline.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/EPGbaseline.cpython-35.pyc
--------------------------------------------------------------------------------
/CGRmodes/__pycache__/layers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/layers.cpython-36.pyc
--------------------------------------------------------------------------------
/CGRmodes/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/CGRmodes/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/CGRmodes/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class conv(nn.Module):
5 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, padding='same', bias=False, bn=True, relu=False):
6 | super(conv, self).__init__()
7 | if '__iter__' not in dir(kernel_size):
8 | kernel_size = (kernel_size, kernel_size)
9 | if '__iter__' not in dir(stride):
10 | stride = (stride, stride)
11 | if '__iter__' not in dir(dilation):
12 | dilation = (dilation, dilation)
13 |
14 | if padding == 'same':
15 | width_pad_size = kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1)
16 | height_pad_size = kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1)
17 | elif padding == 'valid':
18 | width_pad_size = 0
19 | height_pad_size = 0
20 | else:
21 | if '__iter__' in dir(padding):
22 | width_pad_size = padding[0] * 2
23 | height_pad_size = padding[1] * 2
24 | else:
25 | width_pad_size = padding * 2
26 | height_pad_size = padding * 2
27 |
28 | width_pad_size = width_pad_size // 2 + (width_pad_size % 2 - 1)
29 | height_pad_size = height_pad_size // 2 + (height_pad_size % 2 - 1)
30 | pad_size = (width_pad_size, height_pad_size)
31 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_size, dilation, groups, bias=bias)
32 | self.reset_parameters()
33 |
34 | if bn is True:
35 | self.bn = nn.BatchNorm2d(out_channels)
36 | else:
37 | self.bn = None
38 |
39 | if relu is True:
40 | self.relu = nn.ReLU(inplace=True)
41 | else:
42 | self.relu = None
43 |
44 | def forward(self, x):
45 | x = self.conv(x)
46 | if self.bn is not None:
47 | x = self.bn(x)
48 | if self.relu is not None:
49 | x = self.relu(x)
50 | return x
51 |
52 | def reset_parameters(self):
53 | nn.init.kaiming_normal_(self.conv.weight)
54 |
55 |
56 | class self_attn(nn.Module):
57 | def __init__(self, in_channels, mode='hw'):
58 | super(self_attn, self).__init__()
59 |
60 | self.mode = mode
61 |
62 | self.query_conv = conv(in_channels, in_channels // 8, kernel_size=(1, 1))
63 | self.key_conv = conv(in_channels, in_channels // 8, kernel_size=(1, 1))
64 | self.value_conv = conv(in_channels, in_channels, kernel_size=(1, 1))
65 |
66 | self.gamma = nn.Parameter(torch.zeros(1))
67 | self.softmax = nn.Softmax(dim=-1)
68 |
69 | def forward(self, x):
70 | batch_size, channel, height, width = x.size()
71 |
72 | axis = 1
73 | if 'h' in self.mode:
74 | axis *= height
75 | if 'w' in self.mode:
76 | axis *= width
77 |
78 | view = (batch_size, -1, axis)
79 |
80 | projected_query = self.query_conv(x).view(*view).permute(0, 2, 1)
81 | projected_key = self.key_conv(x).view(*view)
82 |
83 | attention_map = torch.bmm(projected_query, projected_key)
84 | attention = self.softmax(attention_map)
85 | projected_value = self.value_conv(x).view(*view)
86 |
87 | out = torch.bmm(projected_value, attention.permute(0, 2, 1))
88 | out = out.view(batch_size, channel, height, width)
89 |
90 | out = self.gamma * out + x
91 | return out
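92 | 
93 | 
94 | # --- Added usage sketch (not in the original file): a quick shape check for
95 | # `conv` and `self_attn`, assuming only a CPU PyTorch install.
96 | if __name__ == '__main__':
97 |     x = torch.randn(2, 32, 16, 16)
98 |     y = conv(32, 64, kernel_size=3, relu=True)(x)   # 'same' padding keeps H and W
99 |     z = self_attn(32, mode='hw')(x)                 # residual attention keeps the input shape
100 |     print(y.shape, z.shape)  # torch.Size([2, 64, 16, 16]) torch.Size([2, 32, 16, 16])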
--------------------------------------------------------------------------------
/CGRmodes/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from thop import profile
4 | from thop import clever_format
5 | from scipy.ndimage import map_coordinates
6 |
7 | from torch.optim.lr_scheduler import _LRScheduler
8 |
9 |
10 | class PolyLr(_LRScheduler):
11 | def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1):
12 | self.gamma = gamma
13 | self.max_iteration = max_iteration
14 | self.minimum_lr = minimum_lr
15 | self.warmup_iteration = warmup_iteration
16 |
17 | super(PolyLr, self).__init__(optimizer, last_epoch)
18 |
19 | def poly_lr(self, base_lr, step):
20 | return (base_lr - self.minimum_lr) * ((1 - (step / self.max_iteration)) ** self.gamma) + self.minimum_lr
21 |
22 | def warmup_lr(self, base_lr, alpha):
23 | return base_lr * (1 / 10.0 * (1 - alpha) + alpha)
24 |
25 | def get_lr(self):
26 | if self.last_epoch < self.warmup_iteration:
27 | alpha = self.last_epoch / self.warmup_iteration
28 | lrs = [min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, self.last_epoch)) for base_lr in
29 | self.base_lrs]
30 | else:
31 | lrs = [self.poly_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
32 |
33 | return lrs
34 |
35 | def clip_gradient(optimizer, grad_clip):
36 | for group in optimizer.param_groups:
37 | for param in group['params']:
38 | if param.grad is not None:
39 | param.grad.data.clamp_(-grad_clip, grad_clip)
40 |
41 |
42 | class AvgMeter(object):
43 | def __init__(self, num=40):
44 | self.num = num
45 | self.reset()
46 |
47 | def reset(self):
48 | self.val = 0
49 | self.avg = 0
50 | self.sum = 0
51 | self.count = 0
52 | self.losses = []
53 |
54 | def update(self, val, n=1):
55 | self.val = val
56 | self.sum += val * n
57 | self.count += n
58 | self.avg = self.sum / self.count
59 | self.losses.append(val)
60 |
61 | def show(self):
62 | return torch.mean(torch.stack(self.losses[np.maximum(len(self.losses)-self.num, 0):]))
63 |
64 |
65 | def CalParams(model, input_tensor):
66 | flops, params = profile(model, inputs=(input_tensor,))
67 | flops, params = clever_format([flops, params], "%.3f")
68 | print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(flops, params))
69 |
70 |
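71 | # --- Added usage sketch (not in the original file): wiring PolyLr and
72 | # clip_gradient into a training loop; the toy model and optimizer below are
73 | # assumptions, not part of CGRNet.
74 | #
75 | #   model = torch.nn.Linear(4, 2)
76 | #   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
77 | #   scheduler = PolyLr(optimizer, gamma=0.9, max_iteration=1000, warmup_iteration=100)
78 | #   for it in range(1000):
79 | #       ...                                   # forward pass, loss.backward()
80 | #       clip_gradient(optimizer, grad_clip=0.5)
81 | #       optimizer.step()
82 | #       scheduler.step()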
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CGRNet: Contour-Guided Graph Reasoning Network for Ambiguous Biomedical Image Segmentation
2 | [License: MIT](https://opensource.org/licenses/MIT) [PRs Welcome](http://makeapullrequest.com)
3 |
4 | > **Authors:**
5 | > Kun Wang,
6 | > Xiaohong Zhang,
7 | > Yuting Lu,
8 | > Xiangbo Zhang,
9 | > Wei Zhang.
10 |
11 |
12 |
13 | ### 1.1. 🔥NEWS🔥
14 | - [2021/10/30] :fire: Released the inference code!
15 | - [2021/10/28] Created the repository.
16 |
17 |
18 | ## Prerequisites
19 | - [Python 3.5](https://www.python.org/)
20 | - [PyTorch 1.1](http://pytorch.org/)
21 | - [OpenCV 4.0](https://opencv.org/)
22 | - [Numpy 1.15](https://numpy.org/)
23 | - [TensorboardX](https://github.com/lanpa/tensorboardX)
24 |
25 | ## Clone repository
26 | ```shell
27 | git clone https://github.com/DLWK/CGRNet.git
28 | cd CGRNet/
29 | ```
30 | ## Download dataset
31 | Download the datasets and unzip them into the `data` folder:
32 | - [COVID-19](https://medicalsegmentation.com/covid19/)
33 | - Download the dataset from the following [URL](https://drive.google.com/file/d/17Cs2JhKOKwt4usiAYJVJMnXfyZWySn3s/view?usp=sharing)
34 | - You can use our `data/dataloader2.py` to load the datasets.
35 | ## Training & Evaluation
36 | ```shell
37 | cd CGRNet/
38 | python3 train.py
39 | ################
40 | python3 test.py
41 |
42 | ```
43 |
44 | ## Demo
45 | ```python
46 | import torch
47 | from CGRmodes.CGR import CGRNet
48 | 
49 | if __name__ == '__main__':
50 |     ras = CGRNet(n_channels=3, n_classes=1).cuda()
51 |     input_tensor = torch.randn(4, 3, 352, 352).cuda()
52 |     out, out1 = ras(input_tensor)
53 |     print(out.shape)
54 | ```
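55 | 
56 | To report model complexity, a minimal sketch built on `CalParams` from `CGRmodes/utils.py` (it wraps [thop](https://github.com/Lyken17/pytorch-OpCounter)); the 352×352 input mirrors the demo above, and the batch size of 1 is an assumption:
57 | ```python
58 | import torch
59 | from CGRmodes.CGR import CGRNet
60 | from CGRmodes.utils import CalParams
61 | 
62 | model = CGRNet(n_channels=3, n_classes=1).cuda()
63 | CalParams(model, torch.randn(1, 3, 352, 352).cuda())  # prints FLOPs and Params via thop
64 | ```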
65 | ### 2.1 Overview framework
66 | 
67 | ### 2.2 Visualization Results
68 | 
69 | ## Citation
70 | - If you find this work helpful, please cite our paper:
71 | ```
72 | @article{wang2022cgrnet,
73 |   title={CGRNet: Contour-guided graph reasoning network for ambiguous biomedical image segmentation},
74 |   author={Wang, Kun and Zhang, Xiaohong and Lu, Yuting and Zhang, Xiangbo and Zhang, Wei},
75 |   journal={Biomedical Signal Processing and Control},
76 |   volume={75},
77 |   pages={103621},
78 |   year={2022},
79 |   publisher={Elsevier}
80 | }
81 | ```
82 | 
83 | # Tips
84 | :fire: If you have any questions about our work, please do not hesitate to contact us by email.
85 | **[⬆ back to top](#0-preface)**
--------------------------------------------------------------------------------
/__pycache__/evaluation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/__pycache__/evaluation.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/transform.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/__pycache__/transform.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataloader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/__pycache__/dataloader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataloader2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/__pycache__/dataloader2.cpython-36.pyc
--------------------------------------------------------------------------------
/data/utils/comm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import pickle
4 |
5 | def get_world_size():
6 | if not dist.is_available():
7 | return 1
8 | if not dist.is_initialized():
9 | return 1
10 | return dist.get_world_size()
11 |
12 | def get_rank():
13 | if not dist.is_available():
14 | return 0
15 | if not dist.is_initialized():
16 | return 0
17 | return dist.get_rank()
18 |
19 | def synchronize():
20 | """
21 | Helper function to synchronize (barrier) among all processes when
22 | using distributed training
23 | """
24 | if not dist.is_available():
25 | return
26 | if not dist.is_initialized():
27 | return
28 | world_size = dist.get_world_size()
29 | if world_size == 1:
30 | return
31 | dist.barrier()
32 |
33 | def all_gather(data):
34 | """
35 | Run all_gather on arbitrary picklable data (not necessarily tensors)
36 | Args:
37 | data: any picklable object
38 | Returns:
39 | list[data]: list of data gathered from each rank
40 | """
41 | world_size = get_world_size()
42 | if world_size == 1:
43 | return [data]
44 |
45 | # serialized to a Tensor
46 | buffer = pickle.dumps(data)
47 | storage = torch.ByteStorage.from_buffer(buffer)
48 | tensor = torch.ByteTensor(storage).to("cuda")
49 |
50 | # obtain Tensor size of each rank
51 | local_size = torch.IntTensor([tensor.numel()]).to("cuda")
52 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
53 | dist.all_gather(size_list, local_size)
54 | size_list = [int(size.item()) for size in size_list]
55 | max_size = max(size_list)
56 |
57 | # receiving Tensor from all ranks
58 | # we pad the tensor because torch all_gather does not support
59 | # gathering tensors of different shapes
60 | tensor_list = []
61 | for _ in size_list:
62 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
63 | if local_size != max_size:
64 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
65 | tensor = torch.cat((tensor, padding), dim=0)
66 | dist.all_gather(tensor_list, tensor)
67 |
68 | data_list = []
69 | for size, tensor in zip(size_list, tensor_list):
70 | buffer = tensor.cpu().numpy().tobytes()[:size]
71 | data_list.append(pickle.loads(buffer))
72 |
73 | return data_list
--------------------------------------------------------------------------------
/data/utils/dice3D.py:
--------------------------------------------------------------------------------
1 | """
2 | author: Clément Zotti (clement.zotti@usherbrooke.ca)
3 | date: April 2017
4 |
5 | DESCRIPTION :
6 | The script provides helper functions to handle the nifti image format:
7 | - load_nii()
8 | - save_nii()
9 |
10 | to generate metrics for two images:
11 | - metrics()
12 |
13 | And it is callable from the command line (see below).
14 | Each function provided in this script has comments explaining
15 | how it works.
16 |
17 | HOW-TO:
18 |
19 | This script was tested for python 3.4.
20 |
21 | First, you need to install the required packages with
22 | pip install -r requirements.txt
23 |
24 | After the installation, you have two ways of running this script:
25 | 1) python metrics.py ground_truth/patient001_ED.nii.gz prediction/patient001_ED.nii.gz
26 | 2) python metrics.py ground_truth/ prediction/
27 |
28 | The first option will print in the console the dice and volume of each class for the given image.
29 | The second option will output a csv file where each image will have the dice and volume of each class.
30 |
31 |
32 | Link: http://acdc.creatis.insa-lyon.fr
33 |
34 | """
35 |
36 | import os
37 | from glob import glob
38 | import time
39 | import re
40 | import argparse
41 | import nibabel as nib
42 | import pandas as pd  # needed by compute_metrics_on_directories
43 | from medpy.metric.binary import hd, dc
44 | import numpy as np
45 |
46 |
47 |
48 | HEADER = ["Name", "Dice LV", "Volume LV", "Err LV(ml)",
49 | "Dice RV", "Volume RV", "Err RV(ml)",
50 | "Dice MYO", "Volume MYO", "Err MYO(ml)"]
51 |
52 | #
53 | # Utils functions used to sort strings into a natural order
54 | #
55 | def conv_int(i):
56 | return int(i) if i.isdigit() else i
57 |
58 |
59 | def natural_order(sord):
60 | """
61 | Sort a (list,tuple) of strings into natural order.
62 |
63 | Ex:
64 |
65 | ['1','10','2'] -> ['1','2','10']
66 |
67 | ['abc1def','ab10d','b2c','ab1d'] -> ['ab1d','ab10d', 'abc1def', 'b2c']
68 |
69 | """
70 | if isinstance(sord, tuple):
71 | sord = sord[0]
72 | return [conv_int(c) for c in re.split(r'(\d+)', sord)]
73 |
74 |
75 | #
76 | # Utils function to load and save nifti files with the nibabel package
77 | #
78 | def load_nii(img_path):
79 | """
80 |     Function to load a 'nii' or 'nii.gz' file. The function returns
81 |     everything needed to save another 'nii' or 'nii.gz'
82 | in the same dimensional space, i.e. the affine matrix and the header
83 |
84 | Parameters
85 | ----------
86 |
87 | img_path: string
88 | String with the path of the 'nii' or 'nii.gz' image file name.
89 |
90 | Returns
91 | -------
92 |     Three elements: the first is a numpy array of the image values,
93 | the second is the affine transformation of the image, and the
94 | last one is the header of the image.
95 | """
96 | nimg = nib.load(img_path)
97 | return nimg.get_data(), nimg.affine, nimg.header
98 |
99 |
100 | def save_nii(img_path, data, affine, header):
101 | """
102 | Function to save a 'nii' or 'nii.gz' file.
103 |
104 | Parameters
105 | ----------
106 |
107 | img_path: string
108 | Path to save the image should be ending with '.nii' or '.nii.gz'.
109 |
110 | data: np.array
111 | Numpy array of the image data.
112 |
113 | affine: list of list or np.array
114 | The affine transformation to save with the image.
115 |
116 | header: nib.Nifti1Header
117 |         The header that defines everything about the data
118 |         (please check the nibabel documentation).
119 | """
120 | nimg = nib.Nifti1Image(data, affine=affine, header=header)
121 | nimg.to_filename(img_path)
122 |
123 |
124 | #
125 | # Functions to process files, directories and metrics
126 | #
127 | def metrics(img_gt, img_pred, voxel_size):
128 | """
129 | Function to compute the metrics between two segmentation maps given as input.
130 |
131 | Parameters
132 | ----------
133 | img_gt: np.array
134 | Array of the ground truth segmentation map.
135 |
136 | img_pred: np.array
137 | Array of the predicted segmentation map.
138 |
139 | voxel_size: list, tuple or np.array
140 | The size of a voxel of the images used to compute the volumes.
141 |
142 | Return
143 | ------
144 |     A list of Dice scores in the order [Dice LV, Dice RV, Dice MYO]
145 |     (the volume computation below is currently commented out).
146 | """
147 |
148 | if img_gt.ndim != img_pred.ndim:
149 | raise ValueError("The arrays 'img_gt' and 'img_pred' should have the "
150 | "same dimension, {} against {}".format(img_gt.ndim,
151 | img_pred.ndim))
152 |
153 | res = []
154 |     # Loop over each class of the input images
155 |     for c in [3, 1, 2]:
156 |         # Copy the gt image so as not to alter the input
157 |         gt_c_i = np.copy(img_gt)
158 |         gt_c_i[gt_c_i != c] = 0
159 | 
160 |         # Copy the pred image so as not to alter the input
161 | pred_c_i = np.copy(img_pred)
162 | pred_c_i[pred_c_i != c] = 0
163 |
164 | # Clip the value to compute the volumes
165 | gt_c_i = np.clip(gt_c_i, 0, 1)
166 | pred_c_i = np.clip(pred_c_i, 0, 1)
167 |
168 | # Compute the Dice
169 | dice = dc(gt_c_i, pred_c_i)
170 |
171 | # Compute volume
172 | volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000.
173 | volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000.
174 |
175 | # res += [dice, volpred, volpred-volgt]
176 | res += [dice]
177 |
178 | return res
179 |
180 |
181 | def compute_metrics_on_files(path_gt, path_pred):
182 | """
183 | Function to give the metrics for two files
184 |
185 | Parameters
186 | ----------
187 |
188 | path_gt: string
189 | Path of the ground truth image.
190 |
191 | path_pred: string
192 | Path of the predicted image.
193 | """
194 | gt, _, header = load_nii(path_gt)
195 | pred, _, _ = load_nii(path_pred)
196 | zooms = header.get_zooms()
197 |
198 | name = os.path.basename(path_gt)
199 | name = name.split('.')[0]
200 | res = metrics(gt, pred, zooms)
201 | res = ["{:.3f}".format(r) for r in res]
202 |
203 | formatting = "{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}"
204 | print(formatting.format(*HEADER))
205 | print(formatting.format(name, *res))
206 |
207 |
208 | def compute_metrics_on_directories(dir_gt, dir_pred):
209 | """
210 |     Function to generate a csv file for each image of the two directories.
211 |
212 | Parameters
213 | ----------
214 |
215 | path_gt: string
216 | Directory of the ground truth segmentation maps.
217 |
218 | path_pred: string
219 | Directory of the predicted segmentation maps.
220 | """
221 | lst_gt = sorted(glob(os.path.join(dir_gt, '*')), key=natural_order)
222 | lst_pred = sorted(glob(os.path.join(dir_pred, '*')), key=natural_order)
223 |
224 | res = []
225 | for p_gt, p_pred in zip(lst_gt, lst_pred):
226 | if os.path.basename(p_gt) != os.path.basename(p_pred):
227 | raise ValueError("The two files don't have the same name"
228 | " {}, {}.".format(os.path.basename(p_gt),
229 | os.path.basename(p_pred)))
230 |
231 | gt, _, header = load_nii(p_gt)
232 | pred, _, _ = load_nii(p_pred)
233 | zooms = header.get_zooms()
234 | res.append(metrics(gt, pred, zooms))
235 |
236 | lst_name_gt = [os.path.basename(gt).split(".")[0] for gt in lst_gt]
237 | res = [[n,] + r for r, n in zip(res, lst_name_gt)]
238 | df = pd.DataFrame(res, columns=HEADER)
239 | df.to_csv("results_{}.csv".format(time.strftime("%Y%m%d_%H%M%S")), index=False)
240 |
241 | def main(path_gt, path_pred):
242 | """
243 | Main function to select which method to apply on the input parameters.
244 | """
245 | if os.path.isfile(path_gt) and os.path.isfile(path_pred):
246 | compute_metrics_on_files(path_gt, path_pred)
247 | elif os.path.isdir(path_gt) and os.path.isdir(path_pred):
248 | compute_metrics_on_directories(path_gt, path_pred)
249 | else:
250 | raise ValueError(
251 | "The paths given needs to be two directories or two files.")
252 |
253 |
254 | if __name__ == "__main__":
255 | # parser = argparse.ArgumentParser(
256 | # description="Script to compute ACDC challenge metrics.")
257 | # parser.add_argument("GT_IMG", type=str, help="Ground Truth image")
258 | # parser.add_argument("PRED_IMG", type=str, help="Predicted image")
259 | # args = parser.parse_args()
260 | # main(args.GT_IMG, args.PRED_IMG)
261 |
262 | ##############################################################
263 | gt = np.random.randint(0, 4, size=(224, 224, 100))
264 | print(np.unique(gt))
265 | pred = np.array(gt)
266 | pred[pred==2] = 3
267 | result = metrics(gt, pred, voxel_size=(224, 224, 100))
268 | print(result)
--------------------------------------------------------------------------------
/data/utils/direct_field/__pycache__/df_cardia.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/utils/direct_field/__pycache__/df_cardia.cpython-36.pyc
--------------------------------------------------------------------------------
/data/utils/direct_field/__pycache__/df_cardia.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/utils/direct_field/__pycache__/df_cardia.cpython-37.pyc
--------------------------------------------------------------------------------
/data/utils/direct_field/__pycache__/utils_df.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/utils/direct_field/__pycache__/utils_df.cpython-36.pyc
--------------------------------------------------------------------------------
/data/utils/direct_field/__pycache__/utils_df.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/data/utils/direct_field/__pycache__/utils_df.cpython-37.pyc
--------------------------------------------------------------------------------
/data/utils/direct_field/df_cardia.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import ndimage
3 | import math
4 | import cv2
5 | from PIL import Image
6 |
7 | def direct_field(a, norm=True):
8 | """ a: np.ndarray, (h, w)
9 | """
10 | if a.ndim == 3:
11 | a = np.squeeze(a)
12 |
13 | h, w = a.shape
14 |
15 | a_Image = Image.fromarray(a)
16 | a = a_Image.resize((w, h), Image.NEAREST)
17 | a = np.array(a)
18 |
19 | accumulation = np.zeros((2, h, w), dtype=np.float32)
20 | for i in np.unique(a)[1:]:
21 | # b, ind = ndimage.distance_transform_edt(a==i, return_indices=True)
22 | # c = np.indices((h, w))
23 | # diff = c - ind
24 | # dr = np.sqrt(np.sum(diff ** 2, axis=0))
25 |
26 | img = (a == i).astype(np.uint8)
27 | dst, labels = cv2.distanceTransformWithLabels(img, cv2.DIST_L2, cv2.DIST_MASK_PRECISE, labelType=cv2.DIST_LABEL_PIXEL)
28 | index = np.copy(labels)
29 | index[img > 0] = 0
30 | place = np.argwhere(index > 0)
31 | nearCord = place[labels-1,:]
32 | x = nearCord[:, :, 0]
33 | y = nearCord[:, :, 1]
34 | nearPixel = np.zeros((2, h, w))
35 | nearPixel[0,:,:] = x
36 | nearPixel[1,:,:] = y
37 | grid = np.indices(img.shape)
38 | grid = grid.astype(float)
39 | diff = grid - nearPixel
40 | if norm:
41 | dr = np.sqrt(np.sum(diff**2, axis = 0))
42 | else:
43 | dr = np.ones_like(img)
44 |
45 | # direction = np.zeros((2, h, w), dtype=np.float32)
46 | # direction[0, b>0] = np.divide(diff[0, b>0], dr[b>0])
47 | # direction[1, b>0] = np.divide(diff[1, b>0], dr[b>0])
48 |
49 | direction = np.zeros((2, h, w), dtype=np.float32)
50 | direction[0, img>0] = np.divide(diff[0, img>0], dr[img>0])
51 | direction[1, img>0] = np.divide(diff[1, img>0], dr[img>0])
52 |
53 | accumulation[:, img>0] = 0
54 | accumulation = accumulation + direction
55 |
56 | # mag, angle = cv2.cartToPolar(accumulation[0, ...], accumulation[1, ...])
57 | # for l in np.unique(a)[1:]:
58 | # mag_i = mag[a==l].astype(float)
59 | # t = 1 / mag_i * mag_i.max()
60 | # mag[a==l] = t
61 | # x, y = cv2.polarToCart(mag, angle)
62 | # accumulation = np.stack([x, y], axis=0)
63 |
64 | return accumulation
65 |
66 |
67 | if __name__ == "__main__":
68 | import matplotlib.pyplot as plt
69 | # gt_p = "/home/ffbian/chencheng/XieheCardiac/npydata/dianfen/16100000/gts/16100000_CINE_segmented_SAX_b3.npy"
70 | # gt = np.load(gt_p)[..., 9] # uint8
71 | # print(gt.shape)
72 |
73 | # a_Image = Image.fromarray(gt)
74 | # a = a_Image.resize((224, 224), Image.NEAREST)
75 | # a = np.array(a) # uint8
76 | # print(a.shape, np.unique(a))
77 |
78 | # # plt.imshow(a)
79 | # # plt.show()
80 |
81 | # ############################################################
82 | # direction = direct_field(gt)
83 |
84 | # theta = np.arctan2(direction[1,...], direction[0,...])
85 | # degree = theta * 180 / math.pi
86 | # degree = (degree + 180) / 360
87 |
88 | # plt.imshow(degree)
89 | # plt.show()
90 |
91 | ########################################################
92 | import json, time, pdb, h5py
93 | data_list = "/home/ffbian/chencheng/XieheCardiac/2DUNet/UNet/libs/datasets/train_new.json"
94 | data_list = "/root/chengfeng/Cardiac/source_code/libs/datasets/jsonLists/acdcList/Dense_TestList.json"
95 | with open(data_list, 'r') as f:
96 | data_infos = json.load(f)
97 |
98 | mag_stat = []
99 | st = time.time()
100 | for i, di in enumerate(data_infos):
101 | # img_p, times_idx = di
102 | # gt_p = img_p.replace("/imgs/", "/gts/")
103 | # gt = np.load(gt_p)[..., times_idx]
104 |
105 | img = h5py.File(di,'r')['image']
106 | gt = h5py.File(di,'r')['label']
107 | gt = np.array(gt).astype(np.float32)
108 |
109 | print(gt.shape)
110 | direction = direct_field(gt, False)
111 | # theta = np.arctan2(direction[1,...], direction[0,...])
112 | mag, angle = cv2.cartToPolar(direction[0, ...], direction[1, ...])
113 | # degree = theta * 180 / math.pi
114 | # degree = (degree + 180) / 360
115 | degree = angle / (2 * math.pi) * 255
116 | # degree = (theta - theta.min()) / (theta.max() - theta.min()) * 255
117 | # mag = np.sqrt(np.sum(direction ** 2, axis=0, keepdims=False))
118 |
119 |
120 |         # normalization (magnitude experiments below, commented out)
121 | # for l in np.unique(gt)[1:]:
122 | # mag_i = mag[gt==l].astype(float)
123 | # # mag[gt==l] = 1. - mag[gt==l] / np.max(mag[gt==l])
124 | # t = (mag_i - mag_i.min()) / (mag_i.max() - mag_i.min())
125 | # mag[gt==l] = np.exp(-10*t)
126 | # print(mag_i.max(), mag_i.min())
127 |
128 | # for l in np.unique(gt)[1:]:
129 | # mag_i = mag[gt==l].astype(float)
130 | # t = 1 / (mag_i) * mag_i.max()
131 | # # t = np.exp(-0.8*mag_i) * mag_i.max()
132 | # # t = 1 / np.sqrt(mag_i+1) * mag_i.max()
133 | # mag[gt==l] = t
134 | # # print(mag_i.max(), mag_i.min())
135 |
136 | # mag[mag>0] = 2 * np.exp(-0.8*(mag[mag>0]-1))
137 | # mag[mag>0] = 2 * np.exp(0.8*(mag[mag>0]-1))
138 |
139 |
140 | mag_stat.append(mag.max())
141 | # pdb.set_trace()
142 |
143 | # plt.imshow(degree)
144 | # plt.show()
145 |
146 | ######################
147 | fig, axs = plt.subplots(1, 3)
148 | axs[0].imshow(degree)
149 | axs[1].imshow(gt)
150 | axs[2].imshow(mag)
151 | plt.show()
152 |
153 | ######################
154 | if i % 100 == 0:
155 | print("\r\r{}/{} {:.4}s".format(i+1, len(data_infos), time.time()-st))
156 | print()
157 |
158 | print("total time: ", time.time()-st)
159 | print("Average time: ", (time.time()-st) / len(data_infos))
160 | # total time: 865.811030626297
161 | # Average time: 0.012969593759126428
162 |
163 | plt.hist(mag_stat)
164 | plt.show()
165 | print(mag_stat)
--------------------------------------------------------------------------------
/data/utils/direct_field/utils_df.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.ndimage import distance_transform_edt as distance
4 | from utils.utils_loss import one_hot, simplex, class2one_hot
5 |
6 | def one_hot2dist(seg: np.ndarray) -> np.ndarray:
7 | assert one_hot(torch.Tensor(seg), axis=0)
8 | C: int = len(seg)
9 |
10 | res = np.zeros_like(seg)
11 | for c in range(C):
12 |         posmask = seg[c].astype(bool)  # np.bool was removed in NumPy 1.24; use the builtin bool
13 |
14 | if posmask.any():
15 | negmask = ~posmask
16 | res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
17 | return res
18 |
19 | def class2dist(seg: np.ndarray, C=4) -> np.ndarray:
20 | """ res: (C, H, W)
21 | """
22 | if seg.ndim == 2:
23 | seg_tensor = torch.Tensor(seg)
24 | elif seg.ndim == 3:
25 | seg_tensor = torch.Tensor(seg[0])
26 | elif seg.ndim == 4:
27 | seg_tensor = torch.Tensor(seg[0, ..., 0])
28 |
29 | seg_onehot = class2one_hot(seg_tensor, C).to(torch.float32)
30 |
31 | assert simplex(seg_onehot)
32 | res = one_hot2dist(seg_onehot[0].numpy())
33 | return res
34 |
35 |
--------------------------------------------------------------------------------
/data/utils/image_list.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from __future__ import division
3 |
4 | import torch
5 | import numpy as np
6 |
7 | class ImageList(object):
8 | """
9 | Structure that holds a list of images (of possibly
10 | varying sizes) as a single tensor.
11 | This works by padding the images to the same size,
12 | and storing in a field the original sizes of each image
13 | """
14 |
15 | def __init__(self, tensors, image_sizes):
16 | """
17 | Arguments:
18 | tensors (tensor)
19 | image_sizes (list[tuple[int, int]])
20 | """
21 | self.tensors = tensors
22 | self.image_sizes = image_sizes
23 |
24 | def to(self, *args, **kwargs):
25 | cast_tensor = self.tensors.to(*args, **kwargs)
26 | return ImageList(cast_tensor, self.image_sizes)
27 |
28 |
29 | def to_image_list(tensors, size_divisible=0, return_size=False):
30 | """
31 | tensors can be an ImageList, a torch.Tensor or
32 | an iterable of Tensors. It can't be a numpy array.
33 | When tensors is an iterable of Tensors, it pads
34 | the Tensors with zeros so that they have the same
35 | shape
36 | """
37 | if isinstance(tensors, torch.Tensor) and size_divisible > 0:
38 | tensors = [tensors]
39 |
40 | if isinstance(tensors, ImageList):
41 | return tensors
42 | elif isinstance(tensors, torch.Tensor):
43 | # single tensor shape can be inferred
44 | if tensors.dim() == 3:
45 | tensors = tensors[None]
46 | assert tensors.dim() == 4
47 | image_sizes = [tensor.shape[-2:] for tensor in tensors]
48 | return ImageList(tensors, image_sizes)
49 | elif isinstance(tensors, (tuple, list, np.ndarray)):
50 | max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
51 |
52 |         # TODO Ideally, just remove this and let the model handle arbitrary
53 |         # input sizes
54 | if size_divisible > 0:
55 | import math
56 |
57 | stride = size_divisible
58 | max_size = list(max_size)
59 | max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
60 | max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
61 | max_size = tuple(max_size)
62 |
63 | batch_shape = (len(tensors),) + max_size
64 | batched_imgs = tensors[0].new(*batch_shape).zero_()
65 | for img, pad_img in zip(tensors, batched_imgs):
66 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
67 |
68 | image_sizes = [im.shape[-2:] for im in tensors]
69 |
70 | # return ImageList(batched_imgs, image_sizes)
71 | if return_size:
72 | return batched_imgs, image_sizes
73 | else:
74 | return batched_imgs
75 | else:
76 | raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))
77 |
78 |
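79 | # --- Added usage sketch (not in the original file): batch two differently
80 | # sized images, padding each to a stride-32-divisible canvas.
81 | #
82 | #   imgs = [torch.rand(3, 100, 120), torch.rand(3, 96, 128)]
83 | #   batched, sizes = to_image_list(imgs, size_divisible=32, return_size=True)
84 | #   print(batched.shape, sizes)   # torch.Size([2, 3, 128, 128]) and the two original sizes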
--------------------------------------------------------------------------------
/data/utils/init_net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def init_weights(net, init_type='normal', gain=0.02):
4 | def init_func(m):
5 | classname = m.__class__.__name__
6 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
7 | if init_type == 'normal':
8 | nn.init.normal_(m.weight.data, 0.0, gain)
9 | elif init_type == 'xavier':
10 | nn.init.xavier_normal_(m.weight.data, gain=gain)
11 | elif init_type == 'kaiming':
12 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
13 | elif init_type == 'orthogonal':
14 | nn.init.orthogonal_(m.weight.data, gain=gain)
15 | else:
16 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
17 | if hasattr(m, 'bias') and m.bias is not None:
18 | nn.init.constant_(m.bias.data, 0.0)
19 | elif classname.find('BatchNorm2d') != -1:
20 | nn.init.normal_(m.weight.data, 1.0, gain)
21 | nn.init.constant_(m.bias.data, 0.0)
22 |
23 | print('initialize network with %s' % init_type)
24 | net.apply(init_func)
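25 | 
26 | # --- Added usage sketch (not in the original file):
27 | #   net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
28 | #   init_weights(net, init_type='kaiming')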
--------------------------------------------------------------------------------
/data/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from hausdorff import hausdorff_distance
4 | from medpy.metric.binary import hd, dc
5 |
6 | def dice(pred, target):
7 | pred = pred.contiguous()
8 | target = target.contiguous()
9 | smooth = 0.00001
10 |
11 | # intersection = (pred * target).sum(dim=2).sum(dim=2)
12 | pred_flat = pred.view(1, -1)
13 | target_flat = target.view(1, -1)
14 |
15 | intersection = (pred_flat * target_flat).sum().item()
16 |
17 | # loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
18 | dice = (2 * intersection + smooth) / (pred_flat.sum().item() + target_flat.sum().item() + smooth)
19 | return dice
20 |
21 | def dice3D(img_gt, img_pred, voxel_size):
22 | """
23 | Function to compute the metrics between two segmentation maps given as input.
24 |
25 | Parameters
26 | ----------
27 | img_gt: np.array
28 | Array of the ground truth segmentation map.
29 |
30 | img_pred: np.array
31 | Array of the predicted segmentation map.
32 |
33 | voxel_size: list, tuple or np.array
34 | The size of a voxel of the images used to compute the volumes.
35 |
36 | Return
37 | ------
38 |     A list of Dice scores in the order [Dice LV, Dice RV, Dice MYO]
39 |     (the volume computation below is currently commented out).
40 | """
41 |
42 | if img_gt.ndim != img_pred.ndim:
43 | raise ValueError("The arrays 'img_gt' and 'img_pred' should have the "
44 | "same dimension, {} against {}".format(img_gt.ndim,
45 | img_pred.ndim))
46 |
47 | res = []
48 |     # Loop over each class of the input images
49 |     for c in [3, 1, 2]:
50 |         # Copy the gt image so as not to alter the input
51 |         gt_c_i = np.copy(img_gt)
52 |         gt_c_i[gt_c_i != c] = 0
53 | 
54 |         # Copy the pred image so as not to alter the input
55 | pred_c_i = np.copy(img_pred)
56 | pred_c_i[pred_c_i != c] = 0
57 |
58 | # Clip the value to compute the volumes
59 | gt_c_i = np.clip(gt_c_i, 0, 1)
60 | pred_c_i = np.clip(pred_c_i, 0, 1)
61 |
62 | # Compute the Dice
63 | dice = dc(gt_c_i, pred_c_i)
64 |
65 | # Compute volume
66 | # volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000.
67 | # volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000.
68 |
69 | # res += [dice, volpred, volpred-volgt]
70 | res += [dice]
71 |
72 | return res
73 |
74 | def hd_3D(img_pred, img_gt, labels=[3, 1, 2]):
75 | res = []
76 | for c in labels:
77 | gt_c_i = np.copy(img_gt)
78 | gt_c_i[gt_c_i != c] = 0
79 |
80 | pred_c_i = np.copy(img_pred)
81 | pred_c_i[pred_c_i != c] = 0
82 |
83 | gt_c_i = np.clip(gt_c_i, 0, 1)
84 | pred_c_i = np.clip(pred_c_i, 0, 1)
85 |
86 | if np.sum(pred_c_i) == 0 or np.sum(gt_c_i) == 0:
87 | hausdorff = 0
88 | else:
89 | hausdorff = hd(pred_c_i, gt_c_i)
90 |
91 | res += [hausdorff]
92 |
93 | return res
94 |
95 | def cal_hausdorff_distance(pred,target):
96 |
97 | pred = np.array(pred.contiguous())
98 | target = np.array(target.contiguous())
99 | result = hausdorff_distance(pred,target,distance="euclidean")
100 |
101 | return result
102 |
103 | def make_one_hot(input, num_classes):
104 | """Convert class index tensor to one hot encoding tensor.
105 | Args:
106 | input: A tensor of shape [N, 1, *]
107 | num_classes: An int of number of class
108 | Returns:
109 | A tensor of shape [N, num_classes, *]
110 | """
111 | shape = np.array(input.shape)
112 | shape[1] = num_classes
113 | shape = tuple(shape)
114 | result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1)
115 | # result = result.scatter_(1, input.cpu(), 1)
116 |
117 | return result
118 |
119 | def match_pred_gt(pred, gt):
120 | """ pred: (1, C, H, W)
121 | gt: (1, C, H, W)
122 | """
123 | gt_labels = torch.unique(gt, sorted=True)[1:]
124 | pred_labels = torch.unique(pred, sorted=True)[1:]
125 |
126 | if len(gt_labels) != 0 and len(pred_labels) != 0:
127 | dice_Matrix = torch.zeros((len(pred_labels), len(gt_labels)))
128 | for i, pl in enumerate(pred_labels):
129 | pred_i = torch.tensor(pred==pl, dtype=torch.float)
130 | for j, gl in enumerate(gt_labels):
131 | dice_Matrix[i, j] = dice(make_one_hot(pred_i, 2)[0], make_one_hot(gt==gl, 2)[0])
132 |
133 | # max_axis0 = np.max(dice_Matrix, axis=0)
134 | max_arg0 = np.argmax(dice_Matrix, axis=0)
135 | else:
136 | return torch.zeros_like(pred)
137 |
138 | pred_match = torch.zeros_like(pred)
139 | for i, arg in enumerate(max_arg0):
140 | pred_match[pred==pred_labels[arg]] = i + 1
141 | return pred_match
142 |
143 | if __name__ == "__main__":
144 | npy_path = "/home/fcheng/Cardia/source_code/logs/logs_df_50000/eval_pp_test/200.npy"
145 |     pred_df, gt_df = np.load(npy_path)
146 |
--------------------------------------------------------------------------------
/data/utils/utils_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from typing import List, Set, Iterable
4 |
5 | def uniq(a: Tensor) -> Set:
6 | return set(torch.unique(a.cpu()).numpy())
7 |
8 | def sset(a: Tensor, sub: Iterable) -> bool:
9 | return uniq(a).issubset(sub)
10 |
11 | def simplex(t: Tensor, axis=1) -> bool:
12 | _sum = t.sum(axis).type(torch.float32)
13 | _ones = torch.ones_like(_sum, dtype=torch.float32)
14 | return torch.allclose(_sum, _ones)
15 |
16 | def one_hot(t: Tensor, axis=1) -> bool:
17 | return simplex(t, axis) and sset(t, [0, 1])
18 |
19 | # switch between representations
20 | def probs2class(probs: Tensor) -> Tensor:
21 | b, _, w, h = probs.shape # type: Tuple[int, int, int, int]
22 | assert simplex(probs)
23 |
24 | res = probs.argmax(dim=1)
25 | assert res.shape == (b, w, h)
26 |
27 | return res
28 |
29 | def class2one_hot(seg: Tensor, C: int) -> Tensor:
30 | if len(seg.shape) == 2: # Only w, h, used by the dataloader
31 | seg = seg.unsqueeze(dim=0)
32 | assert sset(seg, list(range(C)))
33 |
34 | b, w, h = seg.shape # type: Tuple[int, int, int]
35 |
36 | res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32)
37 | assert res.shape == (b, C, w, h)
38 | assert one_hot(res)
39 |
40 | return res
41 |
42 | def probs2one_hot(probs: Tensor) -> Tensor:
43 | _, C, _, _ = probs.shape
44 | assert simplex(probs)
45 |
46 | res = class2one_hot(probs2class(probs), C)
47 | assert res.shape == probs.shape
48 | assert one_hot(res)
49 |
50 | return res
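51 | 
52 | # --- Added sketch (not in the original file): a round-trip check of the
53 | # representation helpers on a toy 3x3 segmentation map.
54 | if __name__ == "__main__":
55 |     seg = torch.tensor([[0, 1, 2], [2, 1, 0], [0, 0, 3]])
56 |     oh = class2one_hot(seg, C=4)                 # (1, 4, 3, 3) one-hot map
57 |     assert one_hot(oh, axis=1)
58 |     assert torch.equal(probs2class(oh.float()), seg.unsqueeze(0))
59 |     print(oh.shape)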
--------------------------------------------------------------------------------
/data/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import colorsys
3 | import random
4 | import cv2
5 |
6 | def get_n_hls_colors(num):
7 | hls_colors = []
8 | i = 0
9 | step = 360.0 / num
10 | while i < 360:
11 | h = i
12 | s = 90 + random.random() * 10
13 | l = 50 + random.random() * 10
14 | _hlsc = [h / 360.0, l / 100.0, s / 100.0]
15 | hls_colors.append(_hlsc)
16 | i += step
17 |
18 | return hls_colors
19 |
20 | def ncolors(num):
21 | rgb_colors = []
22 | if num < 1:
23 | return np.array(rgb_colors)
24 | hls_colors = get_n_hls_colors(num)
25 | for hlsc in hls_colors:
26 | _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2])
27 | rgb_colors.append([_r, _g, _b])
28 | # r, g, b = [int(x * 255.0) for x in (_r, _g, _b)]
29 | # rgb_colors.append([r, g, b])
30 |
31 | return np.array(rgb_colors)
32 |
33 | def random_colors(N, bright=True):
34 | """
35 | Generate random colors.
36 | To get visually distinct colors, generate them in HSV space then
37 | convert to RGB.
38 | """
39 | brightness = 1.0 if bright else 0.7
40 | hsv = [(i / N, 1, brightness) for i in range(N)]
41 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
42 | # random.shuffle(colors)
43 | return colors
44 |
45 | def mask2png(mask, file_name=None, suffix="png"):
46 | """ mask: (w, h)
47 | img_rgb: (w, h, rgb)
48 | """
49 | nums = np.unique(mask)[1:]
50 | if len(nums) < 1:
51 | colors = np.array([[0,0,0]])
52 | else:
53 | # colors = ncolors(len(nums))
54 | colors = (np.array(random_colors(len(nums))) * 255).astype(int)
55 | colors = np.insert(colors, 0, [0,0,0], 0)
56 |
57 |     # ensure the values in mask are consecutive integers 1..N
58 | mask_ordered = np.zeros_like(mask)
59 | for cnt, l in enumerate(nums, 1):
60 | mask_ordered[mask==l] = cnt
61 |
62 | im_rgb = colors[mask_ordered.astype(int)]
63 | if file_name is not None:
64 | cv2.imwrite(file_name+"."+suffix, im_rgb[:, :, ::-1])
65 | return im_rgb
66 |
67 | def apply_mask(image, mask, color, alpha=0.5, scale=1):
68 | """Apply the given mask to the image.
69 | """
70 | for c in range(3):
71 | image[:, :, c] = np.where(mask == 1,
72 | image[:, :, c] *
73 | (1 - alpha) + alpha * color[c] * scale,
74 | image[:, :, c])
75 | return image
76 |
77 | def img_mask_png(image, mask, file_name=None, alpha=0.5, suffix="png"):
78 | """ mask: (h, w)
79 | image: (h, w, rgb)
80 | """
81 | nums = np.unique(mask)[1:]
82 | if len(nums) < 1:
83 | colors = np.array([[0,0,0]])
84 | else:
85 | colors = ncolors(len(nums))
86 | colors = np.insert(colors, 0, [0,0,0], 0)
87 |
88 |     # ensure the values in mask are consecutive integers 1..N
89 | mask_ordered = np.zeros_like(mask)
90 | for cnt, l in enumerate(nums, 1):
91 | mask_ordered[mask==l] = cnt
92 |
93 | # mask_rgb = colors[mask_ordered.astype(int)]
94 | mix_im = image.copy()
95 | for i in np.unique(mask_ordered)[1:]:
96 | mix_im = apply_mask(mix_im, mask_ordered==i, colors[int(i)], alpha=alpha, scale=255)
97 |
98 | if file_name is not None:
99 | cv2.imwrite(file_name+"."+suffix, mix_im[:, :, ::-1])
100 | return mix_im
101 |
102 | def _find_contour(mask):
103 |     # contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)  # vertices only
104 |     contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)  # OpenCV 4 returns (contours, hierarchy)
105 |
106 | cont = np.zeros_like(mask)
107 | for contour in contours:
108 | cont[contour[:,:,1], contour[:,:,0]] = 1
109 | return cont
110 |
111 | def masks_to_contours(masks):
112 |     # the mask may contain multiple regions
113 | nums = np.unique(masks)[1:]
114 | cont_mask = np.zeros_like(masks)
115 | for i in nums:
116 | cont_mask += _find_contour(masks==i)
117 | return (cont_mask>0).astype(int)
118 |
119 | def batchToColorImg(batch, minv=None, maxv=None, scale=255.):
120 | """ batch: (N, H, W, C)
121 | """
122 | if batch.ndim == 3:
123 | N, H, W = batch.shape
124 | elif batch.ndim == 4:
125 | N, H, W, _ = batch.shape
126 | colorImg = np.zeros(shape=(N, H, W, 3))
127 | for i in range(N):
128 | if minv is None:
129 | a = (batch[i] - batch[i].min()) / (batch[i].max() - batch[i].min()) * 255
130 | else:
131 | a = (batch[i] - minv) / (maxv - minv) * scale
132 | a = cv2.applyColorMap(a.astype(np.uint8), cv2.COLORMAP_JET)
133 | colorImg[i, ...] = a[..., ::-1] / 255.
134 | return colorImg
135 |
136 | if __name__ == "__main__":
137 | a = np.zeros((100, 100))
138 | a[0:5, 3:8] = 1
139 | a[75:85, 85:95] = 2
140 |
141 | # colors = ncolors(2)[::-1]
142 | colors = np.array(random_colors(2)) * 255
143 | colors = np.insert(colors, 0, [0, 0, 0], 0)
144 |
145 | b = colors[a.astype(int)].astype(np.uint8)
146 | import cv2, skimage
147 | # skimage.io.imsave("test_io.png", b)
148 | # # cv2.imwrite("test.jpg", b[:, :, ::-1])
149 | # print()
150 |
151 | # mask2png(a, "test")
152 |
153 | ################################################
154 | # img_mask_png(b, a, "test")
155 |
156 | #############################################
157 | # cont_mask = find_contours(a)
158 | # print()
159 | # # skimage.io.imsave("test_cont.png", cont_mask)
160 | # b[cont_mask>0, :] = [255, 255, 255]
161 | # skimage.io.imsave("test_cont.png", b)
162 |
163 | gt0 = skimage.io.imread("gt0.png", as_gray=False)
164 | print()
165 | gt0[gt0==54] = 0
166 | # cont_mask = find_contours(gt0==237) # Array([ 18, 54, 73, 237], dtype=uint8)
167 | # cont_mask += find_contours(gt0==18)
168 | # cont_mask += find_contours(gt0==73)
169 | cont_mask = masks_to_contours(gt0)
170 |
171 | colors = np.array(random_colors(1)) * 255
172 | colors = np.insert(colors, 0, [0, 0, 0], 0)
173 | cont_mask = colors[cont_mask.astype(int)].astype(np.uint8)
174 |
175 | skimage.io.imsave("test_cont.png", cont_mask)
176 |
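177 | # --- Added usage sketch (not in the original file):
178 | #   m = np.zeros((64, 64)); m[8:24, 8:24] = 1; m[40:56, 40:56] = 2
179 | #   mask2png(m, file_name="mask_vis")    # writes mask_vis.png with one color per region
180 | #   cont = masks_to_contours(m)          # binary contour map of all regions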
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # SR : Segmentation Result
4 | # GT : Ground Truth
5 |
6 | def get_accuracy(SR,GT,threshold=0.5):
7 | SR = SR > threshold
8 | GT = GT == torch.max(GT)
9 | corr = torch.sum(SR==GT)
10 | tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3)
11 | acc = float(corr)/float(tensor_size)
12 |
13 | return acc
14 |
15 | def get_sensitivity(SR,GT,threshold=0.5):
16 | # Sensitivity == Recall
17 | SR = SR > threshold
18 | GT = GT == torch.max(GT)
19 |
20 | # TP : True Positive
21 | # FN : False Negative
22 |     TP = (SR == 1) & (GT == 1)   # boolean '&'/'|' instead of '+' keeps this correct on modern PyTorch
23 |     FN = (SR == 0) & (GT == 1)
24 | 
25 |     SE = float(torch.sum(TP)) / (float(torch.sum(TP | FN)) + 1e-6)
26 |
27 | return SE
28 |
29 | def get_specificity(SR,GT,threshold=0.5):
30 | SR = SR > threshold
31 | GT = GT == torch.max(GT)
32 |
33 | # TN : True Negative
34 | # FP : False Positive
35 |     TN = (SR == 0) & (GT == 0)
36 |     FP = (SR == 1) & (GT == 0)
37 | 
38 |     SP = float(torch.sum(TN)) / (float(torch.sum(TN | FP)) + 1e-6)
39 |
40 | return SP
41 |
42 | def get_precision(SR,GT,threshold=0.5):
43 | SR = SR > threshold
44 | GT = GT == torch.max(GT)
45 |
46 | # TP : True Positive
47 | # FP : False Positive
48 |     TP = (SR == 1) & (GT == 1)
49 |     FP = (SR == 1) & (GT == 0)
50 | 
51 |     PC = float(torch.sum(TP)) / (float(torch.sum(TP | FP)) + 1e-6)
52 |
53 | return PC
54 |
55 | def get_F1(SR,GT,threshold=0.5):
56 | # Sensitivity == Recall
57 | SE = get_sensitivity(SR,GT,threshold=threshold)
58 | PC = get_precision(SR,GT,threshold=threshold)
59 |
60 | F1 = 2*SE*PC/(SE+PC + 1e-6)
61 |
62 | return F1
63 |
64 | def get_JS(SR,GT,threshold=0.5):
65 | # JS : Jaccard similarity
66 | SR = SR > threshold
67 | GT = GT == torch.max(GT)
68 |
69 |     Inter = torch.sum(SR & GT)
70 |     Union = torch.sum(SR | GT)
71 |
72 | JS = float(Inter)/(float(Union) + 1e-6)
73 |
74 | return JS
75 |
76 | def get_DC(SR,GT,threshold=0.5):
77 |
78 | # DC : Dice Coefficient
79 | SR = SR > threshold
80 | GT = GT == torch.max(GT)
81 |
82 |     Inter = torch.sum(SR & GT)
83 | DC = float(2*Inter)/(float(torch.sum(SR)+torch.sum(GT)) + 1e-6)
84 |
85 | return DC
86 |
87 |
88 |
89 |
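90 | # --- Added usage sketch (not in the original file): scoring a random
91 | # prediction map against a random binary ground truth.
92 | #
93 | #   SR = torch.rand(2, 1, 64, 64)                   # e.g. sigmoid outputs in [0, 1]
94 | #   GT = (torch.rand(2, 1, 64, 64) > 0.5).float()
95 | #   print(get_accuracy(SR, GT), get_DC(SR, GT), get_JS(SR, GT))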
--------------------------------------------------------------------------------
/fig/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/fig/1.png
--------------------------------------------------------------------------------
/fig/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/fig/2.png
--------------------------------------------------------------------------------
/fig/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/fig/3.png
--------------------------------------------------------------------------------
/models/._models.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/._models.py
--------------------------------------------------------------------------------
/models/GSConv.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch.nn as nn
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.nn.modules.conv import _ConvNd
10 | from torch.nn.modules.utils import _pair
11 | import numpy as np
12 | import math
13 | from . import norm as mynn
14 | from . import custom_functions as myF
15 |
16 | class GatedSpatialConv2d(_ConvNd):
17 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
18 | padding=0, dilation=1, groups=1, bias=False):
19 | """
20 |
21 | :param in_channels:
22 | :param out_channels:
23 | :param kernel_size:
24 | :param stride:
25 | :param padding:
26 | :param dilation:
27 | :param groups:
28 | :param bias:
29 | """
30 |
31 | kernel_size = _pair(kernel_size)
32 | stride = _pair(stride)
33 | padding = _pair(padding)
34 | dilation = _pair(dilation)
35 | super(GatedSpatialConv2d, self).__init__(
36 | in_channels, out_channels, kernel_size, stride, padding, dilation,
37 | False, _pair(0), groups, bias, 'zeros')
38 | self._gate_conv = nn.Sequential(
39 | mynn.Norm2d(in_channels+1),
40 | nn.Conv2d(in_channels+1, in_channels+1, 1),
41 | nn.ReLU(),
42 | nn.Conv2d(in_channels+1, 1, 1),
43 | mynn.Norm2d(1),
44 | nn.Sigmoid()
45 | )
46 |
47 | def forward(self, input_features, gating_features):
48 | """
49 |
50 |         :param input_features: [NxCxHxW] features coming from the shape branch (canny branch).
51 |         :param gating_features: [Nx1xHxW] features coming from the texture branch (resnet). Only a one-channel feature map.
52 | :return:
53 | """
54 | alphas = self._gate_conv(torch.cat([input_features, gating_features], dim=1))
55 | input_features = (input_features * (alphas + 1))
56 | return F.conv2d(input_features, self.weight, self.bias, self.stride,
57 | self.padding, self.dilation, self.groups)
58 |
59 | def reset_parameters(self):
60 | nn.init.xavier_normal_(self.weight)
61 | if self.bias is not None:
62 | nn.init.zeros_(self.bias)
63 |
64 |
65 | class Conv2dPad(nn.Conv2d):
66 | def forward(self, input):
67 | return myF.conv2d_same(input,self.weight,self.groups)
68 |
69 | class HighFrequencyGatedSpatialConv2d(_ConvNd):
70 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
71 | padding=0, dilation=1, groups=1, bias=False):
72 | """
73 |
74 | :param in_channels:
75 | :param out_channels:
76 | :param kernel_size:
77 | :param stride:
78 | :param padding:
79 | :param dilation:
80 | :param groups:
81 | :param bias:
82 | """
83 |
84 | kernel_size = _pair(kernel_size)
85 | stride = _pair(stride)
86 | padding = _pair(padding)
87 | dilation = _pair(dilation)
88 | super(HighFrequencyGatedSpatialConv2d, self).__init__(
89 | in_channels, out_channels, kernel_size, stride, padding, dilation,
90 |             False, _pair(0), groups, bias, 'zeros')  # padding_mode, matching GatedSpatialConv2d above
91 |
92 | self._gate_conv = nn.Sequential(
93 | mynn.Norm2d(in_channels+1),
94 | nn.Conv2d(in_channels+1, in_channels+1, 1),
95 | nn.ReLU(),
96 | nn.Conv2d(in_channels+1, 1, 1),
97 | mynn.Norm2d(1),
98 | nn.Sigmoid()
99 | )
100 |
101 | kernel_size = 7
102 | sigma = 3
103 |
104 | x_cord = torch.arange(kernel_size).float()
105 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size).float()
106 | y_grid = x_grid.t().float()
107 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
108 |
109 | mean = (kernel_size - 1)/2.
110 | variance = sigma**2.
111 | gaussian_kernel = (1./(2.*math.pi*variance)) *\
112 | torch.exp(
113 | -torch.sum((xy_grid - mean)**2., dim=-1) /\
114 | (2*variance)
115 | )
116 |
117 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
118 |
119 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
120 | gaussian_kernel = gaussian_kernel.repeat(in_channels, 1, 1, 1)
121 |
122 | self.gaussian_filter = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, padding=3,
123 | kernel_size=kernel_size, groups=in_channels, bias=False)
124 |
125 | self.gaussian_filter.weight.data = gaussian_kernel
126 | self.gaussian_filter.weight.requires_grad = False
127 |
128 | self.cw = nn.Conv2d(in_channels * 2, in_channels, 1)
129 |
130 | self.procdog = nn.Sequential(
131 | nn.Conv2d(in_channels, in_channels, 1),
132 | mynn.Norm2d(in_channels),
133 | nn.Sigmoid()
134 | )
135 |
136 | def forward(self, input_features, gating_features):
137 | """
138 |
139 |         :param input_features: [NxCxHxW] features coming from the shape branch (canny branch).
140 |         :param gating_features: [Nx1xHxW] features coming from the texture branch (resnet); a single-channel feature map.
141 |         :return: the gated difference-of-Gaussian features after the convolution
142 | """
143 | n, c, h, w = input_features.size()
144 | smooth_features = self.gaussian_filter(input_features)
145 | dog_features = input_features - smooth_features
146 | dog_features = self.cw(torch.cat((dog_features, input_features), dim=1))
147 |
148 | alphas = self._gate_conv(torch.cat([input_features, gating_features], dim=1))
149 |
150 | dog_features = dog_features * (alphas + 1)
151 |
152 | return F.conv2d(dog_features, self.weight, self.bias, self.stride,
153 | self.padding, self.dilation, self.groups)
154 |
155 | def reset_parameters(self):
156 | nn.init.xavier_normal_(self.weight)
157 | if self.bias is not None:
158 | nn.init.zeros_(self.bias)
159 |
160 | def t():
161 | import matplotlib.pyplot as plt
162 |
163 | canny_map_filters_in = 8
164 | canny_map = np.random.normal(size=(1, canny_map_filters_in, 10, 10)) # NxCxHxW
165 | resnet_map = np.random.normal(size=(1, 1, 10, 10)) # NxCxHxW
166 | plt.imshow(canny_map[0, 0])
167 | plt.show()
168 |
169 | canny_map = torch.from_numpy(canny_map).float()
170 | resnet_map = torch.from_numpy(resnet_map).float()
171 |
172 | gconv = GatedSpatialConv2d(canny_map_filters_in, canny_map_filters_in,
173 | kernel_size=3, stride=1, padding=1)
174 | output_map = gconv(canny_map, resnet_map)
175 | print('done')
176 |
177 |
178 | if __name__ == "__main__":
179 | t()
180 |
181 |
--------------------------------------------------------------------------------
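HighFrequencyGatedSpatialConv2d isolates high-frequency content as a difference of Gaussians: the constructor bakes a fixed 7x7 Gaussian kernel (sigma=3) into a frozen depthwise convolution, and forward subtracts the blurred map from the input. A condensed sketch of that construction, with an arbitrary channel count c chosen for illustration:

    import torch
    import torch.nn as nn

    k, sigma, c = 7, 3.0, 8
    ax = torch.arange(k).float()
    xx = ax.repeat(k).view(k, k)
    grid = torch.stack([xx, xx.t()], dim=-1)
    mean, var = (k - 1) / 2., sigma ** 2
    g = torch.exp(-torch.sum((grid - mean) ** 2., dim=-1) / (2 * var))
    g = (g / g.sum()).view(1, 1, k, k).repeat(c, 1, 1, 1)  # one normalized kernel per channel

    blur = nn.Conv2d(c, c, k, padding=3, groups=c, bias=False)
    blur.weight.data = g
    blur.weight.requires_grad = False

    x = torch.randn(1, c, 16, 16)
    dog = x - blur(x)  # high-frequency residual used by the gated branch
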
/models/PraNet_Res2Net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .Res2Net_v1b import res2net50_v1b_26w_4s
5 |
6 |
7 | class BasicConv2d(nn.Module):
8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
9 | super(BasicConv2d, self).__init__()
10 | self.conv = nn.Conv2d(in_planes, out_planes,
11 | kernel_size=kernel_size, stride=stride,
12 | padding=padding, dilation=dilation, bias=False)
13 | self.bn = nn.BatchNorm2d(out_planes)
14 | self.relu = nn.ReLU(inplace=True)
15 |
16 | def forward(self, x):
17 | x = self.conv(x)
18 | x = self.bn(x)
19 |         return x  # relu is defined above but not applied here, matching the upstream PraNet implementation
20 |
21 |
22 | class RFB_modified(nn.Module):
23 | def __init__(self, in_channel, out_channel):
24 | super(RFB_modified, self).__init__()
25 | self.relu = nn.ReLU(True)
26 | self.branch0 = nn.Sequential(
27 | BasicConv2d(in_channel, out_channel, 1),
28 | )
29 | self.branch1 = nn.Sequential(
30 | BasicConv2d(in_channel, out_channel, 1),
31 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
32 | BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
33 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
34 | )
35 | self.branch2 = nn.Sequential(
36 | BasicConv2d(in_channel, out_channel, 1),
37 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
38 | BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
39 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
40 | )
41 | self.branch3 = nn.Sequential(
42 | BasicConv2d(in_channel, out_channel, 1),
43 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
44 | BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
45 | BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
46 | )
47 | self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1)
48 | self.conv_res = BasicConv2d(in_channel, out_channel, 1)
49 |
50 | def forward(self, x):
51 | x0 = self.branch0(x)
52 | x1 = self.branch1(x)
53 | x2 = self.branch2(x)
54 | x3 = self.branch3(x)
55 | x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
56 |
57 | x = self.relu(x_cat + self.conv_res(x))
58 | return x
59 |
60 |
61 | class aggregation(nn.Module):
62 |     # dense aggregation; it can be replaced by aggregation modules from prior work, such as DSS, Amulet, and so on.
63 | # used after MSF
64 | def __init__(self, channel):
65 | super(aggregation, self).__init__()
66 | self.relu = nn.ReLU(True)
67 |
68 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
69 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
70 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
71 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
72 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
73 | self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
74 |
75 | self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
76 | self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
77 | self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
78 | self.conv5 = nn.Conv2d(3*channel, 1, 1)
79 |
80 | def forward(self, x1, x2, x3):
81 | x1_1 = x1
82 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
83 | x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
84 | * self.conv_upsample3(self.upsample(x2)) * x3
85 |
86 | x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
87 | x2_2 = self.conv_concat2(x2_2)
88 |
89 | x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
90 | x3_2 = self.conv_concat3(x3_2)
91 |
92 | x = self.conv4(x3_2)
93 | x = self.conv5(x)
94 |
95 | return x
96 |
97 |
98 | class PraNet(nn.Module):
99 |     # Res2Net-based encoder-decoder
100 | def __init__(self, channel=32):
101 | super(PraNet, self).__init__()
102 | # ---- ResNet Backbone ----
103 | self.resnet = res2net50_v1b_26w_4s(pretrained=False)
104 |         self.resnet.conv1[0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)  # adapt the stem to 1-channel input
105 | # ---- Receptive Field Block like module ----
106 | self.rfb2_1 = RFB_modified(512, channel)
107 | self.rfb3_1 = RFB_modified(1024, channel)
108 | self.rfb4_1 = RFB_modified(2048, channel)
109 | # ---- Partial Decoder ----
110 | self.agg1 = aggregation(channel)
111 | # ---- reverse attention branch 4 ----
112 | self.ra4_conv1 = BasicConv2d(2048, 256, kernel_size=1)
113 | self.ra4_conv2 = BasicConv2d(256, 256, kernel_size=5, padding=2)
114 | self.ra4_conv3 = BasicConv2d(256, 256, kernel_size=5, padding=2)
115 | self.ra4_conv4 = BasicConv2d(256, 256, kernel_size=5, padding=2)
116 | self.ra4_conv5 = BasicConv2d(256, 1, kernel_size=1)
117 | # ---- reverse attention branch 3 ----
118 | self.ra3_conv1 = BasicConv2d(1024, 64, kernel_size=1)
119 | self.ra3_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
120 | self.ra3_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
121 | self.ra3_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
122 | # ---- reverse attention branch 2 ----
123 | self.ra2_conv1 = BasicConv2d(512, 64, kernel_size=1)
124 | self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
125 | self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
126 | self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
127 |
128 | def forward(self, x):
129 | x = self.resnet.conv1(x)
130 | x = self.resnet.bn1(x)
131 | x = self.resnet.relu(x)
132 | x = self.resnet.maxpool(x) # bs, 64, 88, 88
133 | # ---- low-level features ----
134 | x1 = self.resnet.layer1(x) # bs, 256, 88, 88
135 | x2 = self.resnet.layer2(x1) # bs, 512, 44, 44
136 |
137 | x3 = self.resnet.layer3(x2) # bs, 1024, 22, 22
138 | x4 = self.resnet.layer4(x3) # bs, 2048, 11, 11
139 | x2_rfb = self.rfb2_1(x2) # channel -> 32
140 | x3_rfb = self.rfb3_1(x3) # channel -> 32
141 | x4_rfb = self.rfb4_1(x4) # channel -> 32
142 |
143 | ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb)
144 | lateral_map_5 = F.interpolate(ra5_feat, scale_factor=8, mode='bilinear') # NOTES: Sup-1 (bs, 1, 44, 44) -> (bs, 1, 352, 352)
145 |
146 | # ---- reverse attention branch_4 ----
147 | crop_4 = F.interpolate(ra5_feat, scale_factor=0.25, mode='bilinear')
148 | x = -1*(torch.sigmoid(crop_4)) + 1
149 | x = x.expand(-1, 2048, -1, -1).mul(x4)
150 | x = self.ra4_conv1(x)
151 | x = F.relu(self.ra4_conv2(x))
152 | x = F.relu(self.ra4_conv3(x))
153 | x = F.relu(self.ra4_conv4(x))
154 | ra4_feat = self.ra4_conv5(x)
155 | x = ra4_feat + crop_4
156 | lateral_map_4 = F.interpolate(x, scale_factor=32, mode='bilinear') # NOTES: Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352)
157 |
158 | # ---- reverse attention branch_3 ----
159 | crop_3 = F.interpolate(x, scale_factor=2, mode='bilinear')
160 | x = -1*(torch.sigmoid(crop_3)) + 1
161 | x = x.expand(-1, 1024, -1, -1).mul(x3)
162 | x = self.ra3_conv1(x)
163 | x = F.relu(self.ra3_conv2(x))
164 | x = F.relu(self.ra3_conv3(x))
165 | ra3_feat = self.ra3_conv4(x)
166 | x = ra3_feat + crop_3
167 | lateral_map_3 = F.interpolate(x, scale_factor=16, mode='bilinear') # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352)
168 |
169 | # ---- reverse attention branch_2 ----
170 | crop_2 = F.interpolate(x, scale_factor=2, mode='bilinear')
171 | x = -1*(torch.sigmoid(crop_2)) + 1
172 | x = x.expand(-1, 512, -1, -1).mul(x2)
173 | x = self.ra2_conv1(x)
174 | x = F.relu(self.ra2_conv2(x))
175 | x = F.relu(self.ra2_conv3(x))
176 | ra2_feat = self.ra2_conv4(x)
177 | x = ra2_feat + crop_2
178 | lateral_map_2 = F.interpolate(x, scale_factor=8, mode='bilinear') # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352)
179 |
180 | return lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2
181 |
182 |
183 | if __name__ == '__main__':
184 | ras = PraNet().cuda()
185 | input_tensor = torch.randn(1, 1, 96, 96).cuda()
186 |
187 | out = ras(input_tensor)
188 | print(out[0].shape)
--------------------------------------------------------------------------------
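The three reverse-attention branches share one idea: weight the backbone features by 1 - sigmoid(coarse prediction), so each refinement stage attends to regions the current prediction does not yet mark as foreground, and the branch output is added back onto the coarse map. The erasing step alone, as a minimal sketch:

    import torch

    crop = torch.randn(1, 1, 11, 11)   # coarse side-output logits (e.g. crop_4)
    x4 = torch.randn(1, 2048, 11, 11)  # matching backbone features
    w = -1 * torch.sigmoid(crop) + 1   # near 1 where the prediction is still background
    erased = w.expand(-1, 2048, -1, -1).mul(x4)
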
/models/R2U_Net_model.py:
--------------------------------------------------------------------------------
1 | # """ Full assembly of the parts to form the complete network """
2 | # """Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import init
9 |
10 | def init_weights(net, init_type='normal', gain=0.02):
11 | def init_func(m):
12 | classname = m.__class__.__name__
13 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
14 | if init_type == 'normal':
15 | init.normal_(m.weight.data, 0.0, gain)
16 | elif init_type == 'xavier':
17 | init.xavier_normal_(m.weight.data, gain=gain)
18 | elif init_type == 'kaiming':
19 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
20 | elif init_type == 'orthogonal':
21 | init.orthogonal_(m.weight.data, gain=gain)
22 | else:
23 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
24 | if hasattr(m, 'bias') and m.bias is not None:
25 | init.constant_(m.bias.data, 0.0)
26 | elif classname.find('BatchNorm2d') != -1:
27 | init.normal_(m.weight.data, 1.0, gain)
28 | init.constant_(m.bias.data, 0.0)
29 |
30 | print('initialize network with %s' % init_type)
31 | net.apply(init_func)
32 | # class UNet(nn.Module):
33 | # def __init__(self, n_channels, n_classes, bilinear=True):
34 | # super(UNet, self).__init__()
35 | # self.n_channels = n_channels
36 | # self.n_classes = n_classes
37 | # self.bilinear = bilinear
38 |
39 | # self.inc = DoubleConv(n_channels, 64)
40 | # self.down1 = Down(64, 128)
41 | # self.down2 = Down(128, 256)
42 | # self.down3 = Down(256, 512)
43 | # self.down4 = Down(512, 512)
44 | # self.up1 = Up(1024, 256, bilinear)
45 | # self.up2 = Up(512, 128, bilinear)
46 | # self.up3 = Up(256, 64, bilinear)
47 | # self.up4 = Up(128, 64, bilinear)
48 | # self.outc = OutConv(64, n_classes)
49 |
50 | # def forward(self, x):
51 | # x1 = self.inc(x)
52 | # x2 = self.down1(x1)
53 | # x3 = self.down2(x2)
54 | # x4 = self.down3(x3)
55 | # x5 = self.down4(x4)
56 | # x = self.up1(x5, x4)
57 | # x = self.up2(x, x3)
58 | # x = self.up3(x, x2)
59 | # x = self.up4(x, x1)
60 | # logits = self.outc(x)
61 | # return logits
62 |
63 | class up_conv(nn.Module):
64 | def __init__(self,ch_in,ch_out):
65 | super(up_conv,self).__init__()
66 | self.up = nn.Sequential(
67 | nn.Upsample(scale_factor=2),
68 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
69 | nn.BatchNorm2d(ch_out),
70 | nn.ReLU(inplace=True)
71 | )
72 |
73 | def forward(self, x1, x2):
74 | x1 = self.up(x1)
75 | # input is CHW
76 | diffY = x2.size()[2] - x1.size()[2]
77 | diffX = x2.size()[3] - x1.size()[3]
78 |
79 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
80 | diffY // 2, diffY - diffY // 2])
81 | x = torch.cat([x2, x1], dim = 1)
82 | return x
83 |
84 | class Recurrent_block(nn.Module):
85 | def __init__(self,ch_out,t=2):
86 | super(Recurrent_block,self).__init__()
87 | self.t = t
88 | self.ch_out = ch_out
89 | self.conv = nn.Sequential(
90 | nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
91 | nn.BatchNorm2d(ch_out),
92 | nn.ReLU(inplace=True)
93 | )
94 |
95 | def forward(self,x):
96 | for i in range(self.t):
97 |
98 | if i==0:
99 | x1 = self.conv(x)
100 |
101 | x1 = self.conv(x+x1)
102 | return x1
103 |
104 | class RRCNN_block(nn.Module):
105 | def __init__(self,ch_in,ch_out,t=2):
106 | super(RRCNN_block,self).__init__()
107 | self.RCNN = nn.Sequential(
108 | Recurrent_block(ch_out,t=t),
109 | Recurrent_block(ch_out,t=t)
110 | )
111 | self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
112 |
113 | def forward(self,x):
114 | x = self.Conv_1x1(x)
115 | x1 = self.RCNN(x)
116 | return x+x1
117 | class R2U_Net(nn.Module):
118 | def __init__(self,img_ch,output_ch,t=2):
119 | super(R2U_Net,self).__init__()
120 |
121 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
122 | self.Upsample = nn.Upsample(scale_factor=2)
123 |
124 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)
125 |
126 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
127 |
128 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
129 |
130 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
131 |
132 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
133 |
134 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
135 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
136 |
137 | self.Up4 = up_conv(ch_in=512,ch_out=256)
138 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
139 |
140 | self.Up3 = up_conv(ch_in=256,ch_out=128)
141 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
142 |
143 | self.Up2 = up_conv(ch_in=128,ch_out=64)
144 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)
145 |
146 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
147 |
148 |
149 | def forward(self,x):
150 | # encoding path
151 | x1 = self.RRCNN1(x)
152 |
153 | x2 = self.Maxpool(x1)
154 | x2 = self.RRCNN2(x2)
155 |
156 | x3 = self.Maxpool(x2)
157 | x3 = self.RRCNN3(x3)
158 |
159 | x4 = self.Maxpool(x3)
160 | x4 = self.RRCNN4(x4)
161 |
162 | x5 = self.Maxpool(x4)
163 | x5 = self.RRCNN5(x5)
164 |
165 | # decoding + concat path
166 | d5 = self.Up5(x5,x4)
167 | d5 = self.Up_RRCNN5(d5)
168 | d4 = self.Up4(d5,x3)
169 |
170 | d4 = self.Up_RRCNN4(d4)
171 |
172 | d3 = self.Up3(d4,x2)
173 |
174 | d3 = self.Up_RRCNN3(d3)
175 |
176 | d2 = self.Up2(d3,x1)
177 |
178 | d2 = self.Up_RRCNN2(d2)
179 | d1 = self.Conv_1x1(d2)
180 |
181 | return d1
182 |
183 | if __name__ == '__main__':
184 | net = R2U_Net(img_ch=1, output_ch=1)
185 | print(net)
186 |
187 |
188 |
--------------------------------------------------------------------------------
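With the default t=2, Recurrent_block.forward applies the same convolution three times, feeding the running activation back in additively. Unrolled explicitly (a sketch with toy shapes, not repository code):

    import torch
    import torch.nn as nn

    conv = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1),
                         nn.BatchNorm2d(8), nn.ReLU(inplace=True))
    x = torch.randn(2, 8, 16, 16)

    x1 = conv(x)       # i == 0: initialize the recurrent state
    x1 = conv(x + x1)  # i == 0: first recurrent update
    x1 = conv(x + x1)  # i == 1: second recurrent update
    # Recurrent_block(8, t=2)(x) performs exactly this sequence with shared weights
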
/models/UNet_2Plus.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from models.layers import unetConv2, unetUp_origin
6 | from models.init_weights import init_weights
7 | import numpy as np
8 | from torchvision import models
9 | class UNet_2Plus(nn.Module):
10 |
11 | def __init__(self, in_channels=1, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True, is_ds=True):
12 | super(UNet_2Plus, self).__init__()
13 | self.is_deconv = is_deconv
14 | self.in_channels = in_channels
15 | self.is_batchnorm = is_batchnorm
16 | self.is_ds = is_ds
17 | self.feature_scale = feature_scale
18 |
19 | # filters = [32, 64, 128, 256, 512]
20 | filters = [64, 128, 256, 512, 1024]
21 | # filters = [int(x / self.feature_scale) for x in filters]
22 |
23 | # downsampling
24 | self.conv00 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
25 | self.maxpool0 = nn.MaxPool2d(kernel_size=2)
26 | self.conv10 = unetConv2(filters[0], filters[1], self.is_batchnorm)
27 | self.maxpool1 = nn.MaxPool2d(kernel_size=2)
28 | self.conv20 = unetConv2(filters[1], filters[2], self.is_batchnorm)
29 | self.maxpool2 = nn.MaxPool2d(kernel_size=2)
30 | self.conv30 = unetConv2(filters[2], filters[3], self.is_batchnorm)
31 | self.maxpool3 = nn.MaxPool2d(kernel_size=2)
32 | self.conv40 = unetConv2(filters[3], filters[4], self.is_batchnorm)
33 |
34 |
35 | # upsampling
36 | self.up_concat01 = unetUp_origin(filters[1], filters[0], self.is_deconv)
37 | self.up_concat11 = unetUp_origin(filters[2], filters[1], self.is_deconv)
38 | self.up_concat21 = unetUp_origin(filters[3], filters[2], self.is_deconv)
39 | self.up_concat31 = unetUp_origin(filters[4], filters[3], self.is_deconv)
40 |
41 | self.up_concat02 = unetUp_origin(filters[1], filters[0], self.is_deconv, 3)
42 | self.up_concat12 = unetUp_origin(filters[2], filters[1], self.is_deconv, 3)
43 | self.up_concat22 = unetUp_origin(filters[3], filters[2], self.is_deconv, 3)
44 |
45 | self.up_concat03 = unetUp_origin(filters[1], filters[0], self.is_deconv, 4)
46 | self.up_concat13 = unetUp_origin(filters[2], filters[1], self.is_deconv, 4)
47 |
48 | self.up_concat04 = unetUp_origin(filters[1], filters[0], self.is_deconv, 5)
49 |
50 | # final conv (without any concat)
51 | self.final_1 = nn.Conv2d(filters[0], n_classes, 1)
52 | self.final_2 = nn.Conv2d(filters[0], n_classes, 1)
53 | self.final_3 = nn.Conv2d(filters[0], n_classes, 1)
54 | self.final_4 = nn.Conv2d(filters[0], n_classes, 1)
55 |
56 | # initialise weights
57 | for m in self.modules():
58 | if isinstance(m, nn.Conv2d):
59 | init_weights(m, init_type='kaiming')
60 | elif isinstance(m, nn.BatchNorm2d):
61 | init_weights(m, init_type='kaiming')
62 |
63 | def forward(self, inputs):
64 | # column : 0
65 | X_00 = self.conv00(inputs)
66 | maxpool0 = self.maxpool0(X_00)
67 | X_10 = self.conv10(maxpool0)
68 | maxpool1 = self.maxpool1(X_10)
69 | X_20 = self.conv20(maxpool1)
70 | maxpool2 = self.maxpool2(X_20)
71 | X_30 = self.conv30(maxpool2)
72 | maxpool3 = self.maxpool3(X_30)
73 | X_40 = self.conv40(maxpool3)
74 |
75 | # column : 1
76 | X_01 = self.up_concat01(X_10, X_00)
77 | X_11 = self.up_concat11(X_20, X_10)
78 | X_21 = self.up_concat21(X_30, X_20)
79 | X_31 = self.up_concat31(X_40, X_30)
80 | # column : 2
81 | X_02 = self.up_concat02(X_11, X_00, X_01)
82 | X_12 = self.up_concat12(X_21, X_10, X_11)
83 | X_22 = self.up_concat22(X_31, X_20, X_21)
84 | # column : 3
85 | X_03 = self.up_concat03(X_12, X_00, X_01, X_02)
86 | X_13 = self.up_concat13(X_22, X_10, X_11, X_12)
87 | # column : 4
88 | X_04 = self.up_concat04(X_13, X_00, X_01, X_02, X_03)
89 |
90 | # final layer
91 | final_1 = self.final_1(X_01)
92 | final_2 = self.final_2(X_02)
93 | final_3 = self.final_3(X_03)
94 | final_4 = self.final_4(X_04)
95 |
96 | final = (final_1 + final_2 + final_3 + final_4) / 4
97 |
98 | if self.is_ds:
99 | return final
100 | else:
101 | return final_4
102 |
103 | # if self.is_ds:
104 | # return F.sigmoid(final)
105 | # else:
106 | # return F.sigmoid(final_4)
107 |
108 | # model = UNet_2Plus()
109 | # print('# generator parameters:', 1.0 * sum(param.numel() for param in model.parameters())/1000000)
110 | # params = list(model.named_parameters())
111 | # for i in range(len(params)):
112 | # (name, param) = params[i]
113 | # print(name)
114 | # print(param.shape)
115 |
--------------------------------------------------------------------------------
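With is_ds enabled, the forward pass returns the mean of the four 1x1-conv side outputs rather than only the deepest head, which is the usual deep-supervision choice for UNet++. The averaging itself is just:

    import torch

    final_1, final_2, final_3, final_4 = (torch.randn(1, 1, 96, 96) for _ in range(4))
    final = (final_1 + final_2 + final_3 + final_4) / 4  # deep-supervised output
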
/models/adaptive_avgmax_pool.py:
--------------------------------------------------------------------------------
1 | """ PyTorch selectable adaptive pooling
2 | Adaptive pooling with the ability to select the type of pooling from:
3 | * 'avg' - Average pooling
4 | * 'max' - Max pooling
5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
6 | * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim
7 | Both a functional and a nn.Module version of the pooling is provided.
8 | Author: Ross Wightman (rwightman)
9 | """
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 |
15 | def pooling_factor(pool_type='avg'):
16 | return 2 if pool_type == 'avgmaxc' else 1
17 |
18 |
19 | def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
20 | """Selectable global pooling function with dynamic input kernel size
21 | """
22 | if pool_type == 'avgmaxc':
23 | x = torch.cat([
24 | F.avg_pool2d(
25 | x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad),
26 | F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
27 | ], dim=1)
28 | elif pool_type == 'avgmax':
29 | x_avg = F.avg_pool2d(
30 | x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
31 | x_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
32 | x = 0.5 * (x_avg + x_max)
33 | elif pool_type == 'max':
34 | x = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding)
35 | else:
36 | if pool_type != 'avg':
37 | print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
38 | x = F.avg_pool2d(
39 | x, kernel_size=(x.size(2), x.size(3)), padding=padding, count_include_pad=count_include_pad)
40 | return x
41 |
42 |
43 | class AdaptiveAvgMaxPool2d(torch.nn.Module):
44 | """Selectable global pooling layer with dynamic input kernel size
45 | """
46 | def __init__(self, output_size=1, pool_type='avg'):
47 | super(AdaptiveAvgMaxPool2d, self).__init__()
48 | self.output_size = output_size
49 | self.pool_type = pool_type
50 | if pool_type == 'avgmaxc' or pool_type == 'avgmax':
51 | self.pool = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size), nn.AdaptiveMaxPool2d(output_size)])
52 | elif pool_type == 'max':
53 | self.pool = nn.AdaptiveMaxPool2d(output_size)
54 | else:
55 | if pool_type != 'avg':
56 | print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
57 | self.pool = nn.AdaptiveAvgPool2d(output_size)
58 |
59 | def forward(self, x):
60 | if self.pool_type == 'avgmaxc':
61 | x = torch.cat([p(x) for p in self.pool], dim=1)
62 | elif self.pool_type == 'avgmax':
63 | x = 0.5 * torch.sum(torch.stack([p(x) for p in self.pool]), 0).squeeze(dim=0)
64 | else:
65 | x = self.pool(x)
66 | return x
67 |
68 | def factor(self):
69 | return pooling_factor(self.pool_type)
70 |
71 | def __repr__(self):
72 | return self.__class__.__name__ + ' (' \
73 | + 'output_size=' + str(self.output_size) \
74 | + ', pool_type=' + self.pool_type + ')'
75 |
--------------------------------------------------------------------------------
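A short usage sketch of the functional form (assuming the repository root is on sys.path): 'avgmax' keeps the channel count by averaging the two pooled maps, while 'avgmaxc' concatenates them and doubles the channels, which is what pooling_factor reports:

    import torch
    from models.adaptive_avgmax_pool import adaptive_avgmax_pool2d, pooling_factor

    x = torch.randn(2, 64, 7, 7)
    print(adaptive_avgmax_pool2d(x, 'avgmax').shape)   # torch.Size([2, 64, 1, 1])
    print(adaptive_avgmax_pool2d(x, 'avgmaxc').shape)  # torch.Size([2, 128, 1, 1])
    print(pooling_factor('avgmaxc'))                   # 2
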
/models/backbone/ResNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | """
7 | 3x3 convolution with padding
8 | """
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 |
12 |
13 | class BasicBlock(nn.Module):
14 | expansion = 1
15 |
16 | def __init__(self, inplanes, planes, stride=1, downsample=None):
17 | super(BasicBlock, self).__init__()
18 | self.conv1 = conv3x3(inplanes, planes, stride)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 | self.relu = nn.ReLU(inplace=True)
21 | self.conv2 = conv3x3(planes, planes)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.downsample = downsample
24 | self.stride = stride
25 |
26 | def forward(self, x):
27 | residual = x
28 |
29 | out = self.conv1(x)
30 | out = self.bn1(out)
31 | out = self.relu(out)
32 |
33 | out = self.conv2(out)
34 | out = self.bn2(out)
35 |
36 | if self.downsample is not None:
37 | residual = self.downsample(x)
38 |
39 | out += residual
40 | out = self.relu(out)
41 |
42 | return out
43 |
44 |
45 | class Bottleneck(nn.Module):
46 | expansion = 4
47 |
48 | def __init__(self, inplanes, planes, stride=1, downsample=None):
49 | super(Bottleneck, self).__init__()
50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
51 | self.bn1 = nn.BatchNorm2d(planes)
52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
53 | padding=1, bias=False)
54 | self.bn2 = nn.BatchNorm2d(planes)
55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
56 | self.bn3 = nn.BatchNorm2d(planes * 4)
57 | self.relu = nn.ReLU(inplace=True)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x):
62 | residual = x
63 |
64 | out = self.conv1(x)
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 |
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv3(out)
73 | out = self.bn3(out)
74 |
75 | if self.downsample is not None:
76 | residual = self.downsample(x)
77 |
78 | out += residual
79 | out = self.relu(out)
80 |
81 | return out
82 |
83 |
84 | class ResNet(nn.Module):
85 | # ResNet50 with two branches
86 | def __init__(self):
87 | # self.inplanes = 128
88 | self.inplanes = 64
89 | super(ResNet, self).__init__()
90 |
91 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
92 | bias=False)
93 | self.bn1 = nn.BatchNorm2d(64)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
96 | self.layer1 = self._make_layer(Bottleneck, 64, 3)
97 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
98 |         self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2)
99 |         self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2)
100 |         # second branch restarts from layer2's 512-channel output
101 |         self.inplanes = 512
102 |         self.layer3_2 = self._make_layer(Bottleneck, 256, 6, stride=2)
103 |         self.layer4_2 = self._make_layer(Bottleneck, 512, 3, stride=2)
104 | 
105 |         for m in self.modules():
106 |             if isinstance(m, nn.Conv2d):
107 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
108 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
109 |             elif isinstance(m, nn.BatchNorm2d):
110 |                 m.weight.data.fill_(1)
111 |                 m.bias.data.zero_()
112 | 
113 |     def _make_layer(self, block, planes, blocks, stride=1):
114 |         downsample = None
115 |         if stride != 1 or self.inplanes != planes * block.expansion:
116 |             downsample = nn.Sequential(
117 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
118 |                           kernel_size=1, stride=stride, bias=False),
119 |                 nn.BatchNorm2d(planes * block.expansion),
120 |             )
121 | 
122 |         layers = []
123 |         layers.append(block(self.inplanes, planes, stride, downsample))
124 |         self.inplanes = planes * block.expansion
125 |         for i in range(1, blocks):
126 |             layers.append(block(self.inplanes, planes))
127 | 
128 |         return nn.Sequential(*layers)
129 | 
130 |     def forward(self, x):
131 |         x = self.conv1(x)
132 |         x = self.bn1(x)
133 |         x = self.relu(x)
134 |         x = self.maxpool(x)
135 | 
136 |         x = self.layer1(x)
137 |         x = self.layer2(x)
138 |         x1 = self.layer3_1(x)
139 |         x1 = self.layer4_1(x1)
140 | 
141 |         x2 = self.layer3_2(x)
142 |         x2 = self.layer4_2(x2)
143 | 
144 |         return x1, x2
145 | 
--------------------------------------------------------------------------------
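As the class comment notes, the backbone forks after layer2 into two Bottleneck branches, so both outputs are 2048-channel maps at 1/32 of the input resolution. A shape-check sketch (assuming the class is importable as models.backbone.ResNet.ResNet):

    import torch
    # from models.backbone.ResNet import ResNet

    net = ResNet()
    x1, x2 = net(torch.randn(1, 3, 224, 224))
    print(x1.shape, x2.shape)  # torch.Size([1, 2048, 7, 7]) for both branches
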
/models/backbone/VGGNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class B2_VGG(nn.Module):
6 | def __init__(self):
7 | super(B2_VGG, self).__init__()
8 | conv1 = nn.Sequential()
9 | conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1))
10 | conv1.add_module('relu1_1', nn.ReLU(inplace=True))
11 | conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1))
12 | conv1.add_module('relu1_2', nn.ReLU(inplace=True))
13 |
14 | self.conv1 = conv1
15 | conv2 = nn.Sequential()
16 | conv2.add_module('pool1', nn.AvgPool2d(2, stride=2))
17 | conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1))
18 | conv2.add_module('relu2_1', nn.ReLU())
19 | conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1))
20 | conv2.add_module('relu2_2', nn.ReLU())
21 | self.conv2 = conv2
22 |
23 | conv3 = nn.Sequential()
24 | conv3.add_module('pool2', nn.AvgPool2d(2, stride=2))
25 | conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
26 | conv3.add_module('relu3_1', nn.ReLU())
27 | conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
28 | conv3.add_module('relu3_2', nn.ReLU())
29 | conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
30 | conv3.add_module('relu3_3', nn.ReLU())
31 | self.conv3 = conv3
32 |
33 | conv4_1 = nn.Sequential()
34 | conv4_1.add_module('pool3_1', nn.AvgPool2d(2, stride=2))
35 | conv4_1.add_module('conv4_1_1', nn.Conv2d(256, 512, 3, 1, 1))
36 | conv4_1.add_module('relu4_1_1', nn.ReLU())
37 | conv4_1.add_module('conv4_2_1', nn.Conv2d(512, 512, 3, 1, 1))
38 | conv4_1.add_module('relu4_2_1', nn.ReLU())
39 | conv4_1.add_module('conv4_3_1', nn.Conv2d(512, 512, 3, 1, 1))
40 | conv4_1.add_module('relu4_3_1', nn.ReLU())
41 | self.conv4_1 = conv4_1
42 |
43 | conv5_1 = nn.Sequential()
44 | conv5_1.add_module('pool4_1', nn.AvgPool2d(2, stride=2))
45 | conv5_1.add_module('conv5_1_1', nn.Conv2d(512, 512, 3, 1, 1))
46 | conv5_1.add_module('relu5_1_1', nn.ReLU())
47 | conv5_1.add_module('conv5_2_1', nn.Conv2d(512, 512, 3, 1, 1))
48 | conv5_1.add_module('relu5_2_1', nn.ReLU())
49 | conv5_1.add_module('conv5_3_1', nn.Conv2d(512, 512, 3, 1, 1))
50 | conv5_1.add_module('relu5_3_1', nn.ReLU())
51 | self.conv5_1 = conv5_1
52 |
53 | conv4_2 = nn.Sequential()
54 | conv4_2.add_module('pool3_2', nn.AvgPool2d(2, stride=2))
55 | conv4_2.add_module('conv4_1_2', nn.Conv2d(256, 512, 3, 1, 1))
56 | conv4_2.add_module('relu4_1_2', nn.ReLU())
57 | conv4_2.add_module('conv4_2_2', nn.Conv2d(512, 512, 3, 1, 1))
58 | conv4_2.add_module('relu4_2_2', nn.ReLU())
59 | conv4_2.add_module('conv4_3_2', nn.Conv2d(512, 512, 3, 1, 1))
60 | conv4_2.add_module('relu4_3_2', nn.ReLU())
61 | self.conv4_2 = conv4_2
62 |
63 | conv5_2 = nn.Sequential()
64 | conv5_2.add_module('pool4_2', nn.AvgPool2d(2, stride=2))
65 | conv5_2.add_module('conv5_1_2', nn.Conv2d(512, 512, 3, 1, 1))
66 | conv5_2.add_module('relu5_1_2', nn.ReLU())
67 | conv5_2.add_module('conv5_2_2', nn.Conv2d(512, 512, 3, 1, 1))
68 | conv5_2.add_module('relu5_2_2', nn.ReLU())
69 | conv5_2.add_module('conv5_3_2', nn.Conv2d(512, 512, 3, 1, 1))
70 | conv5_2.add_module('relu5_3_2', nn.ReLU())
71 | self.conv5_2 = conv5_2
72 |
73 | pre_train = torch.load('./Snapshots/pre_trained/vgg16-397923af.pth')
74 | self._initialize_weights(pre_train)
75 |
76 | def forward(self, x):
77 | x = self.conv1(x)
78 | x = self.conv2(x)
79 | x = self.conv3(x)
80 | x1 = self.conv4_1(x)
81 | x1 = self.conv5_1(x1)
82 | x2 = self.conv4_2(x)
83 | x2 = self.conv5_2(x2)
84 | return x1, x2
85 |
86 | def _initialize_weights(self, pre_train):
87 | keys = list(pre_train.keys())
88 | self.conv1.conv1_1.weight.data.copy_(pre_train[keys[0]])
89 | self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]])
90 | self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]])
91 | self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]])
92 | self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]])
93 | self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]])
94 | self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]])
95 | self.conv4_1.conv4_1_1.weight.data.copy_(pre_train[keys[14]])
96 | self.conv4_1.conv4_2_1.weight.data.copy_(pre_train[keys[16]])
97 | self.conv4_1.conv4_3_1.weight.data.copy_(pre_train[keys[18]])
98 | self.conv5_1.conv5_1_1.weight.data.copy_(pre_train[keys[20]])
99 | self.conv5_1.conv5_2_1.weight.data.copy_(pre_train[keys[22]])
100 | self.conv5_1.conv5_3_1.weight.data.copy_(pre_train[keys[24]])
101 | self.conv4_2.conv4_1_2.weight.data.copy_(pre_train[keys[14]])
102 | self.conv4_2.conv4_2_2.weight.data.copy_(pre_train[keys[16]])
103 | self.conv4_2.conv4_3_2.weight.data.copy_(pre_train[keys[18]])
104 | self.conv5_2.conv5_1_2.weight.data.copy_(pre_train[keys[20]])
105 | self.conv5_2.conv5_2_2.weight.data.copy_(pre_train[keys[22]])
106 | self.conv5_2.conv5_3_2.weight.data.copy_(pre_train[keys[24]])
107 |
108 | self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]])
109 | self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]])
110 | self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]])
111 | self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]])
112 | self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]])
113 | self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]])
114 | self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]])
115 | self.conv4_1.conv4_1_1.bias.data.copy_(pre_train[keys[15]])
116 | self.conv4_1.conv4_2_1.bias.data.copy_(pre_train[keys[17]])
117 | self.conv4_1.conv4_3_1.bias.data.copy_(pre_train[keys[19]])
118 | self.conv5_1.conv5_1_1.bias.data.copy_(pre_train[keys[21]])
119 | self.conv5_1.conv5_2_1.bias.data.copy_(pre_train[keys[23]])
120 | self.conv5_1.conv5_3_1.bias.data.copy_(pre_train[keys[25]])
121 | self.conv4_2.conv4_1_2.bias.data.copy_(pre_train[keys[15]])
122 | self.conv4_2.conv4_2_2.bias.data.copy_(pre_train[keys[17]])
123 | self.conv4_2.conv4_3_2.bias.data.copy_(pre_train[keys[19]])
124 | self.conv5_2.conv5_1_2.bias.data.copy_(pre_train[keys[21]])
125 | self.conv5_2.conv5_2_2.bias.data.copy_(pre_train[keys[23]])
126 | self.conv5_2.conv5_3_2.bias.data.copy_(pre_train[keys[25]])
127 |
--------------------------------------------------------------------------------
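The pretrained checkpoint path is hardcoded; vgg16-397923af.pth is the standard torchvision VGG16 checkpoint, so the same state dict can be fetched through torchvision when the local file is missing (a sketch; the pretrained= keyword is the older torchvision API):

    import torch
    from torchvision import models

    try:
        pre_train = torch.load('./Snapshots/pre_trained/vgg16-397923af.pth')
    except FileNotFoundError:
        pre_train = models.vgg16(pretrained=True).state_dict()  # downloads the same weights
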
/models/backbone/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/backbone/__init__.py
--------------------------------------------------------------------------------
/models/custom_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torchvision.transforms.functional import pad
9 | import numpy as np
10 |
11 |
12 | def calc_pad_same(in_siz, out_siz, stride, ksize):
13 | """Calculate same padding width.
14 | Args:
15 | ksize: kernel size [I, J].
16 | Returns:
17 | pad_: Actual padding width.
18 | """
19 | return (out_siz - 1) * stride + ksize - in_siz
20 |
21 |
22 | def conv2d_same(input, kernel, groups,bias=None,stride=1,padding=0,dilation=1):
23 | n, c, h, w = input.shape
24 | kout, ki_c_g, kh, kw = kernel.shape
25 | pw = calc_pad_same(w, w, 1, kw)
26 | ph = calc_pad_same(h, h, 1, kh)
27 | pw_l = pw // 2
28 | pw_r = pw - pw_l
29 | ph_t = ph // 2
30 | ph_b = ph - ph_t
31 |
32 | input_ = F.pad(input, (pw_l, pw_r, ph_t, ph_b))
33 | result = F.conv2d(input_, kernel, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
34 | assert result.shape == input.shape
35 | return result
36 |
37 |
38 | def gradient_central_diff(input, cuda):
39 |     # central differences via a depthwise [1, 0, -1] correlation (scaled by 0.5)
40 | kernel = [[1, 0, -1]]
41 | kernel_t = 0.5 * torch.Tensor(kernel) * -1. # pytorch implements correlation instead of conv
42 | if type(cuda) is int:
43 | if cuda != -1:
44 | kernel_t = kernel_t.cuda(device=cuda)
45 | else:
46 | if cuda is True:
47 | kernel_t = kernel_t.cuda()
48 | n, c, h, w = input.shape
49 |
50 | x = conv2d_same(input, kernel_t.unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), c)
51 | y = conv2d_same(input, kernel_t.t().unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), c)
52 | return x, y
53 |
54 |
55 | def compute_single_sided_diferences(o_x, o_y, input):
56 | # n,c,h,w
57 | #input = input.clone()
58 | o_y[:, :, 0, :] = input[:, :, 1, :].clone() - input[:, :, 0, :].clone()
59 | o_x[:, :, :, 0] = input[:, :, :, 1].clone() - input[:, :, :, 0].clone()
60 | # --
61 | o_y[:, :, -1, :] = input[:, :, -1, :].clone() - input[:, :, -2, :].clone()
62 | o_x[:, :, :, -1] = input[:, :, :, -1].clone() - input[:, :, :, -2].clone()
63 | return o_x, o_y
64 |
65 |
66 | def numerical_gradients_2d(input, cuda=False):
67 | """
68 |     Numerical gradients over batches using the torch grouped-conv operator.
69 |     The single-sided differences at the image borders are recomputed afterwards.
70 |     It matches np.gradient(image), except that the output order here is x, y while np.gradient returns y, x.
71 | :param input: N,C,H,W
72 | :param cuda: whether or not use cuda
73 | :return: X,Y
74 | """
75 | n, c, h, w = input.shape
76 | assert h > 1 and w > 1
77 | x, y = gradient_central_diff(input, cuda)
78 |     return compute_single_sided_diferences(x, y, input)  # fix up the border rows/columns
79 |
80 |
81 | def convTri(input, r, cuda=False):
82 | """
83 | Convolves an image by a 2D triangle filter (the 1D triangle filter f is
84 | [1:r r+1 r:-1:1]/(r+1)^2, the 2D version is simply conv2(f,f'))
85 | :param input:
86 | :param r: integer filter radius
87 | :param cuda: move the kernel to gpu
88 | :return:
89 | """
90 | if (r <= 1):
91 | raise ValueError()
92 | n, c, h, w = input.shape
93 |     # apply the 1-D triangle kernel separably: along w first, then along h
94 | f = list(range(1, r + 1)) + [r + 1] + list(reversed(range(1, r + 1)))
95 | kernel = torch.Tensor([f]) / (r + 1) ** 2
96 | if type(cuda) is int:
97 | if cuda != -1:
98 | kernel = kernel.cuda(device=cuda)
99 | else:
100 | if cuda is True:
101 | kernel = kernel.cuda()
102 |
103 | # padding w
104 | input_ = F.pad(input, (1, 1, 0, 0), mode='replicate')
105 | input_ = F.pad(input_, (r, r, 0, 0), mode='reflect')
106 | input_ = [input_[:, :, :, :r], input, input_[:, :, :, -r:]]
107 | input_ = torch.cat(input_, 3)
108 | t = input_
109 |
110 | # padding h
111 | input_ = F.pad(input_, (0, 0, 1, 1), mode='replicate')
112 | input_ = F.pad(input_, (0, 0, r, r), mode='reflect')
113 | input_ = [input_[:, :, :r, :], t, input_[:, :, -r:, :]]
114 | input_ = torch.cat(input_, 2)
115 |
116 | output = F.conv2d(input_,
117 | kernel.unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]),
118 | padding=0, groups=c)
119 | output = F.conv2d(output,
120 | kernel.t().unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]),
121 | padding=0, groups=c)
122 | return output
123 |
124 |
125 | def compute_normal(E, cuda=False):
126 | if torch.sum(torch.isnan(E)) != 0:
127 | print('nans found here')
128 | import ipdb;
129 | ipdb.set_trace()
130 | E_ = convTri(E, 4, cuda)
131 | Ox, Oy = numerical_gradients_2d(E_, cuda)
132 | Oxx, _ = numerical_gradients_2d(Ox, cuda)
133 | Oxy, Oyy = numerical_gradients_2d(Oy, cuda)
134 |
135 | aa = Oyy * torch.sign(-(Oxy + 1e-5)) / (Oxx + 1e-5)
136 | t = torch.atan(aa)
137 | O = torch.remainder(t, np.pi)
138 |
139 | if torch.sum(torch.isnan(O)) != 0:
140 | print('nans found here')
141 | import ipdb;
142 | ipdb.set_trace()
143 |
144 | return O
145 |
146 |
147 | def compute_normal_2(E, cuda=False):
148 | if torch.sum(torch.isnan(E)) != 0:
149 | print('nans found here')
150 | import ipdb;
151 | ipdb.set_trace()
152 | E_ = convTri(E, 4, cuda)
153 | Ox, Oy = numerical_gradients_2d(E_, cuda)
154 | Oxx, _ = numerical_gradients_2d(Ox, cuda)
155 | Oxy, Oyy = numerical_gradients_2d(Oy, cuda)
156 |
157 | aa = Oyy * torch.sign(-(Oxy + 1e-5)) / (Oxx + 1e-5)
158 | t = torch.atan(aa)
159 | O = torch.remainder(t, np.pi)
160 |
161 | if torch.sum(torch.isnan(O)) != 0:
162 | print('nans found here')
163 | import ipdb;
164 | ipdb.set_trace()
165 |
166 | return O, (Oyy, Oxx)
167 |
168 |
169 | def compute_grad_mag(E, cuda=False):
170 | E_ = convTri(E, 4, cuda)
171 | Ox, Oy = numerical_gradients_2d(E_, cuda)
172 |     mag = torch.sqrt(torch.mul(Ox, Ox) + torch.mul(Oy, Oy) + 1e-6)
173 |     mag = mag / mag.max()
174 |
175 | return mag
176 |
--------------------------------------------------------------------------------
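A minimal usage sketch for the file above (an assumption, not part of the repository: it presumes the repo root is on PYTHONPATH). It shows the intended pipeline of compute_grad_mag; note that, as shipped, convTri and gradient_central_diff return their inputs unchanged because of the early returns flagged above, so the result reduces to a normalized magnitude of E itself:

    import torch
    from models.custom_functions import compute_grad_mag

    E = torch.rand(2, 1, 64, 64)           # illustrative soft edge map, N,C,H,W
    mag = compute_grad_mag(E, cuda=False)  # smooth -> gradients -> normalized magnitude
    print(mag.shape)                       # torch.Size([2, 1, 64, 64])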
/models/deeplab/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/__pycache__/aspp.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/aspp.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/__pycache__/aspp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/aspp.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/__pycache__/decoder.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/decoder.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/__pycache__/decoder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/__pycache__/decoder.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/aspp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class _ASPPModule(nn.Module):
8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9 | super(_ASPPModule, self).__init__()
10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11 | stride=1, padding=padding, dilation=dilation, bias=False)
12 | self.bn = BatchNorm(planes)
13 | self.relu = nn.ReLU()
14 |
15 | self._init_weight()
16 |
17 | def forward(self, x):
18 | x = self.atrous_conv(x)
19 | x = self.bn(x)
20 |
21 | return self.relu(x)
22 |
23 | def _init_weight(self):
24 | for m in self.modules():
25 | if isinstance(m, nn.Conv2d):
26 | torch.nn.init.kaiming_normal_(m.weight)
27 | elif isinstance(m, SynchronizedBatchNorm2d):
28 | m.weight.data.fill_(1)
29 | m.bias.data.zero_()
30 | elif isinstance(m, nn.BatchNorm2d):
31 | m.weight.data.fill_(1)
32 | m.bias.data.zero_()
33 |
34 | class ASPP(nn.Module):
35 | def __init__(self, backbone, output_stride, BatchNorm):
36 | super(ASPP, self).__init__()
37 | if backbone == 'drn':
38 | inplanes = 512
39 | elif backbone == 'mobilenet':
40 | inplanes = 320
41 | else:
42 | inplanes = 2048
43 | if output_stride == 16:
44 | dilations = [1, 6, 12, 18]
45 | elif output_stride == 8:
46 | dilations = [1, 12, 24, 36]
47 | else:
48 | raise NotImplementedError
49 |
50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
54 |
55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
57 | BatchNorm(256),
58 | nn.ReLU())
59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
60 | self.bn1 = BatchNorm(256)
61 | self.relu = nn.ReLU()
62 | self.dropout = nn.Dropout(0.5)
63 | self._init_weight()
64 |
65 | def forward(self, x):
66 | x1 = self.aspp1(x)
67 | x2 = self.aspp2(x)
68 | x3 = self.aspp3(x)
69 | x4 = self.aspp4(x)
70 | x5 = self.global_avg_pool(x)
71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
72 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
73 |
74 | x = self.conv1(x)
75 | x = self.bn1(x)
76 | x = self.relu(x)
77 |
78 | return self.dropout(x)
79 |
80 | def _init_weight(self):
81 | for m in self.modules():
82 | if isinstance(m, nn.Conv2d):
83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
84 | # m.weight.data.normal_(0, math.sqrt(2. / n))
85 | torch.nn.init.kaiming_normal_(m.weight)
86 | elif isinstance(m, SynchronizedBatchNorm2d):
87 | m.weight.data.fill_(1)
88 | m.bias.data.zero_()
89 | elif isinstance(m, nn.BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 |
93 |
94 | def build_aspp(backbone, output_stride, BatchNorm):
95 | return ASPP(backbone, output_stride, BatchNorm)
--------------------------------------------------------------------------------
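A hedged usage sketch for build_aspp above (the feature sizes are illustrative assumptions): with backbone='resnet' the module expects 2048-channel features, and output_stride=16 selects dilations [1, 6, 12, 18]; the five 256-channel branches are concatenated (1280 channels) and projected back to 256:

    import torch
    import torch.nn as nn
    from models.deeplab.aspp import build_aspp

    aspp = build_aspp(backbone='resnet', output_stride=16, BatchNorm=nn.BatchNorm2d)
    feat = torch.rand(1, 2048, 32, 32)   # e.g. a 512x512 input at stride 16
    print(aspp(feat).shape)              # torch.Size([1, 256, 32, 32])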
/models/deeplab/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from models.deeplab.backbone import resnet, xception, drn, mobilenet
2 |
3 | def build_backbone(backbone, output_stride, BatchNorm):
4 | if backbone == 'resnet':
5 | return resnet.ResNet101(output_stride, BatchNorm)
6 | elif backbone == 'xception':
7 | return xception.AlignedXception(output_stride, BatchNorm)
8 | elif backbone == 'drn':
9 | return drn.drn_d_54(BatchNorm)
10 | elif backbone == 'mobilenet':
11 | return mobilenet.MobileNetV2(output_stride, BatchNorm)
12 | else:
13 | raise NotImplementedError
14 |
--------------------------------------------------------------------------------
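A short sketch of the factory above (sizes are assumptions): each branch returns a module whose forward yields (high_level_feat, low_level_feat). The 'resnet' branch builds the grayscale-modified ResNet-101 listed below, so the input here has a single channel:

    import torch
    import torch.nn as nn
    from models.deeplab.backbone import build_backbone

    backbone = build_backbone('resnet', output_stride=16, BatchNorm=nn.BatchNorm2d)
    high, low = backbone(torch.rand(1, 1, 512, 512))
    print(high.shape, low.shape)   # stride-16 and stride-4 feature maps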
/models/deeplab/backbone/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/drn.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/drn.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/drn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/drn.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/mobilenet.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/mobilenet.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/resnet.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/resnet.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/xception.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/xception.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/__pycache__/xception.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/backbone/__pycache__/xception.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/backbone/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 | from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | def conv_bn(inp, oup, stride, BatchNorm):
9 | return nn.Sequential(
10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
11 | BatchNorm(oup),
12 | nn.ReLU6(inplace=True)
13 | )
14 |
15 |
16 | def fixed_padding(inputs, kernel_size, dilation):
17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
18 | pad_total = kernel_size_effective - 1
19 | pad_beg = pad_total // 2
20 | pad_end = pad_total - pad_beg
21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
22 | return padded_inputs
23 |
24 |
25 | class InvertedResidual(nn.Module):
26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
27 | super(InvertedResidual, self).__init__()
28 | self.stride = stride
29 | assert stride in [1, 2]
30 |
31 | hidden_dim = round(inp * expand_ratio)
32 | self.use_res_connect = self.stride == 1 and inp == oup
33 | self.kernel_size = 3
34 | self.dilation = dilation
35 |
36 | if expand_ratio == 1:
37 | self.conv = nn.Sequential(
38 | # dw
39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
40 | BatchNorm(hidden_dim),
41 | nn.ReLU6(inplace=True),
42 | # pw-linear
43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
44 | BatchNorm(oup),
45 | )
46 | else:
47 | self.conv = nn.Sequential(
48 | # pw
49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
50 | BatchNorm(hidden_dim),
51 | nn.ReLU6(inplace=True),
52 | # dw
53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
54 | BatchNorm(hidden_dim),
55 | nn.ReLU6(inplace=True),
56 | # pw-linear
57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
58 | BatchNorm(oup),
59 | )
60 |
61 | def forward(self, x):
62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
63 | if self.use_res_connect:
64 | x = x + self.conv(x_pad)
65 | else:
66 | x = self.conv(x_pad)
67 | return x
68 |
69 |
70 | class MobileNetV2(nn.Module):
71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
72 | super(MobileNetV2, self).__init__()
73 | block = InvertedResidual
74 | input_channel = 32
75 | current_stride = 1
76 | rate = 1
77 |         inverted_residual_setting = [
78 | # t, c, n, s
79 | [1, 16, 1, 1],
80 | [6, 24, 2, 2],
81 | [6, 32, 3, 2],
82 | [6, 64, 4, 2],
83 | [6, 96, 3, 1],
84 | [6, 160, 3, 2],
85 | [6, 320, 1, 1],
86 | ]
87 |
88 | # building first layer
89 | input_channel = int(input_channel * width_mult)
90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
91 | current_stride *= 2
92 | # building inverted residual blocks
93 |         for t, c, n, s in inverted_residual_setting:
94 | if current_stride == output_stride:
95 | stride = 1
96 | dilation = rate
97 | rate *= s
98 | else:
99 | stride = s
100 | dilation = 1
101 | current_stride *= s
102 | output_channel = int(c * width_mult)
103 | for i in range(n):
104 | if i == 0:
105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
106 | else:
107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
108 | input_channel = output_channel
109 | self.features = nn.Sequential(*self.features)
110 | self._initialize_weights()
111 |
112 | if pretrained:
113 | self._load_pretrained_model()
114 |
115 | self.low_level_features = self.features[0:4]
116 | self.high_level_features = self.features[4:]
117 |
118 | def forward(self, x):
119 | low_level_feat = self.low_level_features(x)
120 | x = self.high_level_features(low_level_feat)
121 | return x, low_level_feat
122 |
123 | def _load_pretrained_model(self):
124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
125 | model_dict = {}
126 | state_dict = self.state_dict()
127 | for k, v in pretrain_dict.items():
128 | if k in state_dict:
129 | model_dict[k] = v
130 | state_dict.update(model_dict)
131 | self.load_state_dict(state_dict)
132 |
133 | def _initialize_weights(self):
134 | for m in self.modules():
135 | if isinstance(m, nn.Conv2d):
136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137 | # m.weight.data.normal_(0, math.sqrt(2. / n))
138 | torch.nn.init.kaiming_normal_(m.weight)
139 | elif isinstance(m, SynchronizedBatchNorm2d):
140 | m.weight.data.fill_(1)
141 | m.bias.data.zero_()
142 | elif isinstance(m, nn.BatchNorm2d):
143 | m.weight.data.fill_(1)
144 | m.bias.data.zero_()
145 |
146 | if __name__ == "__main__":
147 | input = torch.rand(1, 3, 512, 512)
148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
149 | output, low_level_feat = model(input)
150 | print(output.size())
151 | print(low_level_feat.size())
152 |
--------------------------------------------------------------------------------
/models/deeplab/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 |
6 | class Bottleneck(nn.Module):
7 | expansion = 4
8 |
9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
10 | super(Bottleneck, self).__init__()
11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
12 | self.bn1 = BatchNorm(planes)
13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
14 | dilation=dilation, padding=dilation, bias=False)
15 | self.bn2 = BatchNorm(planes)
16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
17 | self.bn3 = BatchNorm(planes * 4)
18 | self.relu = nn.ReLU(inplace=True)
19 | self.downsample = downsample
20 | self.stride = stride
21 | self.dilation = dilation
22 |
23 | def forward(self, x):
24 | residual = x
25 |
26 | out = self.conv1(x)
27 | out = self.bn1(out)
28 | out = self.relu(out)
29 |
30 | out = self.conv2(out)
31 | out = self.bn2(out)
32 | out = self.relu(out)
33 |
34 | out = self.conv3(out)
35 | out = self.bn3(out)
36 |
37 | if self.downsample is not None:
38 | residual = self.downsample(x)
39 |
40 | out += residual
41 | out = self.relu(out)
42 |
43 | return out
44 |
45 | class ResNet(nn.Module):
46 |
47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
48 | self.inplanes = 64
49 | super(ResNet, self).__init__()
50 | blocks = [1, 2, 4]
51 | if output_stride == 16:
52 | strides = [1, 2, 2, 1]
53 | dilations = [1, 1, 1, 2]
54 | elif output_stride == 8:
55 | strides = [1, 2, 1, 1]
56 | dilations = [1, 1, 2, 4]
57 | else:
58 | raise NotImplementedError
59 |
60 | # Modules
61 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = BatchNorm(64)
64 | self.relu = nn.ReLU(inplace=True)
65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
66 |
67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
72 | self._init_weight()
73 |
74 | if pretrained:
75 | self._load_pretrained_model()
76 |
77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
78 | downsample = None
79 | if stride != 1 or self.inplanes != planes * block.expansion:
80 | downsample = nn.Sequential(
81 | nn.Conv2d(self.inplanes, planes * block.expansion,
82 | kernel_size=1, stride=stride, bias=False),
83 | BatchNorm(planes * block.expansion),
84 | )
85 |
86 | layers = []
87 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
88 | self.inplanes = planes * block.expansion
89 | for i in range(1, blocks):
90 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
91 |
92 | return nn.Sequential(*layers)
93 |
94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
95 | downsample = None
96 | if stride != 1 or self.inplanes != planes * block.expansion:
97 | downsample = nn.Sequential(
98 | nn.Conv2d(self.inplanes, planes * block.expansion,
99 | kernel_size=1, stride=stride, bias=False),
100 | BatchNorm(planes * block.expansion),
101 | )
102 |
103 | layers = []
104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
105 | downsample=downsample, BatchNorm=BatchNorm))
106 | self.inplanes = planes * block.expansion
107 | for i in range(1, len(blocks)):
108 | layers.append(block(self.inplanes, planes, stride=1,
109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
110 |
111 | return nn.Sequential(*layers)
112 |
113 | def forward(self, input):
114 | x = self.conv1(input)
115 | x = self.bn1(x)
116 | x = self.relu(x)
117 | x = self.maxpool(x)
118 |
119 | x = self.layer1(x)
120 | low_level_feat = x
121 | x = self.layer2(x)
122 | x = self.layer3(x)
123 | x = self.layer4(x)
124 | return x, low_level_feat
125 |
126 | def _init_weight(self):
127 | for m in self.modules():
128 | if isinstance(m, nn.Conv2d):
129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
130 | m.weight.data.normal_(0, math.sqrt(2. / n))
131 | elif isinstance(m, SynchronizedBatchNorm2d):
132 | m.weight.data.fill_(1)
133 | m.bias.data.zero_()
134 | elif isinstance(m, nn.BatchNorm2d):
135 | m.weight.data.fill_(1)
136 | m.bias.data.zero_()
137 |
138 | def _load_pretrained_model(self):
139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
140 | model_dict = {}
141 | state_dict = self.state_dict()
142 | for k, v in pretrain_dict.items():
143 | if k in state_dict:
144 | model_dict[k] = v
145 | state_dict.update(model_dict)
146 | self.load_state_dict(state_dict)
147 |
148 | def ResNet101(output_stride, BatchNorm, pretrained=True):
149 |     """Constructs a ResNet-101 model.
150 |     Args:
151 |         pretrained (bool): currently ignored; ImageNet weights are never loaded here, since conv1 above was changed to 1 input channel and the checkpoint no longer fits.
152 |     """
153 |     model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=False)
154 |     # model.features[0]=nn.Conv2d(1, 64, kernel_size=3, padding=1)
155 |     return model
156 |
157 | if __name__ == "__main__":
158 | import torch
159 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
160 |     input = torch.rand(1, 1, 512, 512)  # conv1 above takes a single input channel
161 | output, low_level_feat = model(input)
162 | print(output.size())
163 | print(low_level_feat.size())
--------------------------------------------------------------------------------
/models/deeplab/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class Decoder(nn.Module):
8 | def __init__(self, num_classes, backbone, BatchNorm):
9 | super(Decoder, self).__init__()
10 | if backbone == 'resnet' or backbone == 'drn':
11 | low_level_inplanes = 256
12 | elif backbone == 'xception':
13 | low_level_inplanes = 128
14 | elif backbone == 'mobilenet':
15 | low_level_inplanes = 24
16 | else:
17 | raise NotImplementedError
18 |
19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
20 | self.bn1 = BatchNorm(48)
21 | self.relu = nn.ReLU()
22 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
23 | BatchNorm(256),
24 | nn.ReLU(),
25 | nn.Dropout(0.5),
26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
27 | BatchNorm(256),
28 | nn.ReLU(),
29 | nn.Dropout(0.1),
30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
31 | self._init_weight()
32 |
33 |
34 | def forward(self, x, low_level_feat):
35 | low_level_feat = self.conv1(low_level_feat)
36 | low_level_feat = self.bn1(low_level_feat)
37 | low_level_feat = self.relu(low_level_feat)
38 |
39 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
40 | x = torch.cat((x, low_level_feat), dim=1)
41 | x = self.last_conv(x)
42 |
43 | return x
44 |
45 | def _init_weight(self):
46 | for m in self.modules():
47 | if isinstance(m, nn.Conv2d):
48 | torch.nn.init.kaiming_normal_(m.weight)
49 | elif isinstance(m, SynchronizedBatchNorm2d):
50 | m.weight.data.fill_(1)
51 | m.bias.data.zero_()
52 | elif isinstance(m, nn.BatchNorm2d):
53 | m.weight.data.fill_(1)
54 | m.bias.data.zero_()
55 |
56 | def build_decoder(num_classes, backbone, BatchNorm):
57 | return Decoder(num_classes, backbone, BatchNorm)
--------------------------------------------------------------------------------
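A minimal sketch for build_decoder above (shapes are assumptions matching the 'resnet' branch): the 256-channel low-level features are reduced to 48 channels, the ASPP output is upsampled to the low-level resolution, and the 304-channel concatenation goes through last_conv:

    import torch
    import torch.nn as nn
    from models.deeplab.decoder import build_decoder

    decoder = build_decoder(num_classes=1, backbone='resnet', BatchNorm=nn.BatchNorm2d)
    aspp_out = torch.rand(1, 256, 32, 32)      # illustrative stride-16 ASPP output
    low_level = torch.rand(1, 256, 128, 128)   # illustrative stride-4 features
    print(decoder(aspp_out, low_level).shape)  # torch.Size([1, 1, 128, 128]), still stride 4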
/models/deeplab/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/__pycache__/comm.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/comm.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/__pycache__/comm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/comm.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/__pycache__/replicate.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/replicate.cpython-35.pyc
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/__pycache__/replicate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/deeplab/sync_batchnorm/__pycache__/replicate.cpython-36.pyc
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 |     - During replication, as data parallel triggers a callback on each module, all slave devices should
59 |       call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices
61 |       are collected and passed to a registered callback.
62 |     - After receiving the messages, the master device gathers the information and determines the message
63 |       to pass back to each slave device.
64 |     """
65 |
66 | def __init__(self, master_callback):
67 | """
68 | Args:
69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """
71 | self._master_callback = master_callback
72 | self._queue = queue.Queue()
73 | self._registry = collections.OrderedDict()
74 | self._activated = False
75 |
76 | def __getstate__(self):
77 | return {'master_callback': self._master_callback}
78 |
79 | def __setstate__(self, state):
80 | self.__init__(state['master_callback'])
81 |
82 | def register_slave(self, identifier):
83 | """
84 |         Register a slave device.
85 |         Args:
86 |             identifier: an identifier, usually the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 | """
89 | if self._activated:
90 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 | self._activated = False
92 | self._registry.clear()
93 | future = FutureResult()
94 | self._registry[identifier] = _MasterRegistry(future)
95 | return SlavePipe(identifier, self._queue, future)
96 |
97 | def run_master(self, master_msg):
98 | """
99 | Main entry for the master device in each forward pass.
100 |         Messages are first collected from every device (including the master device), and then
101 |         a callback is invoked to compute the message to be sent back to each device
102 |         (including the master device).
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | Returns: the message to be sent back to the master device.
107 | """
108 | self._activated = True
109 |
110 | intermediates = [(0, master_msg)]
111 | for i in range(self.nr_slaves):
112 | intermediates.append(self._queue.get())
113 |
114 | results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, 'The first result should belong to the master.'
116 |
117 | for i, res in results:
118 | if i == 0:
119 | continue
120 | self._registry[i].result.put(res)
121 |
122 | for i in range(self.nr_slaves):
123 | assert self._queue.get() is True
124 |
125 | return results[0][1]
126 |
127 | @property
128 | def nr_slaves(self):
129 | return len(self._registry)
130 |
--------------------------------------------------------------------------------
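A minimal sketch of the FutureResult handshake above (the thread setup is illustrative): one thread blocks in get() until another thread put()s a value, which is how SlavePipe.run_slave waits for the master's reply:

    import threading
    from models.deeplab.sync_batchnorm.comm import FutureResult

    future = FutureResult()
    t = threading.Thread(target=lambda: future.put(42))
    t.start()
    print(future.get())   # 42
    t.join()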
/models/deeplab/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 |     """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31 |     Note that, as all replicated modules are isomorphic, we assign each sub-module a context
32 |     (shared among multiple copies of this module on different devices).
33 |     Through this context, different copies can share some information.
34 |     We guarantee that the callback on the master copy (the first copy) will be called before the callback
35 |     of any slave copy.
36 |     """
37 | master_copy = modules[0]
38 | nr_modules = len(list(master_copy.modules()))
39 | ctxs = [CallbackContext() for _ in range(nr_modules)]
40 |
41 | for i, module in enumerate(modules):
42 | for j, m in enumerate(module.modules()):
43 | if hasattr(m, '__data_parallel_replicate__'):
44 | m.__data_parallel_replicate__(ctxs[j], i)
45 |
46 |
47 | class DataParallelWithCallback(DataParallel):
48 | """
49 | Data Parallel with a replication callback.
50 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is created
51 |     by the original `replicate` function.
52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53 | Examples:
54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56 | # sync_bn.__data_parallel_replicate__ will be invoked.
57 | """
58 |
59 | def replicate(self, module, device_ids):
60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61 | execute_replication_callbacks(modules)
62 | return modules
63 |
64 |
65 | def patch_replication_callback(data_parallel):
66 | """
67 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
68 | Useful when you have customized `DataParallel` implementation.
69 | Examples:
70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72 | > patch_replication_callback(sync_bn)
73 | # this is equivalent to
74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76 | """
77 |
78 | assert isinstance(data_parallel, DataParallel)
79 |
80 | old_replicate = data_parallel.replicate
81 |
82 | @functools.wraps(old_replicate)
83 | def new_replicate(module, device_ids):
84 | modules = old_replicate(module, device_ids)
85 | execute_replication_callbacks(modules)
86 | return modules
87 |
88 | data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/models/deeplab/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
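A small sketch of how TorchTestCase above is meant to be used (the test class is an illustrative assumption): tensors are compared element-wise within atol after conversion to numpy:

    import torch
    import unittest
    from models.deeplab.sync_batchnorm.unittest import TorchTestCase

    class DummyTest(TorchTestCase):
        def test_close(self):
            a = torch.ones(4)
            self.assertTensorClose(a, a + 1e-4)   # within the 1e-3 default atol

    # unittest.main() would run it as usual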
/models/deeplab_v3p.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import sys
5 | sys.path.append('..')
6 | from models.deeplab.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
7 | from models.deeplab.aspp import build_aspp
8 | from models.deeplab.decoder import build_decoder
9 | from models.deeplab.backbone import build_backbone
10 |
11 |
12 | class DeepLab(nn.Module):
13 | def __init__(self, num_classes, backbone='resnet', output_stride=16,
14 | sync_bn=True, freeze_bn=False):
15 | super(DeepLab, self).__init__()
16 | if backbone == 'drn':
17 | output_stride = 8
18 |
19 |         if sync_bn:
20 | BatchNorm = SynchronizedBatchNorm2d
21 | else:
22 | BatchNorm = nn.BatchNorm2d
23 |
24 | self.backbone = build_backbone(backbone, output_stride, BatchNorm)
25 | self.aspp = build_aspp(backbone, output_stride, BatchNorm)
26 | self.decoder = build_decoder(num_classes, backbone, BatchNorm)
27 |
28 | if freeze_bn:
29 | self.freeze_bn()
30 |
31 | def forward(self, input):
32 | x, low_level_feat = self.backbone(input)
33 | x = self.aspp(x)
34 | x = self.decoder(x, low_level_feat)
35 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
36 |
37 | return x
38 |
39 | def freeze_bn(self):
40 | for m in self.modules():
41 | if isinstance(m, SynchronizedBatchNorm2d):
42 | m.eval()
43 | elif isinstance(m, nn.BatchNorm2d):
44 | m.eval()
45 |
46 | def get_1x_lr_params(self):
47 | modules = [self.backbone]
48 | for i in range(len(modules)):
49 | for m in modules[i].named_modules():
50 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
51 | or isinstance(m[1], nn.BatchNorm2d):
52 | for p in m[1].parameters():
53 | if p.requires_grad:
54 | yield p
55 |
56 | def get_10x_lr_params(self):
57 | modules = [self.aspp, self.decoder]
58 | for i in range(len(modules)):
59 | for m in modules[i].named_modules():
60 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
61 | or isinstance(m[1], nn.BatchNorm2d):
62 | for p in m[1].parameters():
63 | if p.requires_grad:
64 | yield p
65 |
66 |
67 | if __name__ == "__main__":
68 | import os
69 | os.environ['CUDA_VISIBLE_DEVICES'] = '5'
70 |
71 | # model = DeepLab(backbone='mobilenet', output_stride=16, num_classes=1).cuda()
72 | # model.eval()
73 | # input = torch.rand(1, 3, 512, 512).cuda()
74 | # output = model(input)
75 | # print(output.size())
76 |
77 |
78 |
--------------------------------------------------------------------------------
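A hedged training-setup sketch for DeepLab above (the hyperparameters are assumptions): the backbone is trained at the base learning rate while ASPP and the decoder get 10x, using the two parameter generators defined in the class. sync_bn=False keeps plain nn.BatchNorm2d so this runs on a single device, and the ResNet backbone in this fork expects 1-channel input:

    import torch
    from models.deeplab_v3p import DeepLab

    model = DeepLab(num_classes=1, backbone='resnet', output_stride=16, sync_bn=False)
    optimizer = torch.optim.SGD(
        [{'params': model.get_1x_lr_params(), 'lr': 1e-3},    # backbone
         {'params': model.get_10x_lr_params(), 'lr': 1e-2}],  # ASPP + decoder
        momentum=0.9, weight_decay=5e-4)
    pred = model(torch.rand(1, 1, 256, 256))
    print(pred.shape)   # torch.Size([1, 1, 256, 256]), logits at input resolution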
/models/denseunet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | from collections import OrderedDict
6 |
7 |
8 | class DenseUnet(nn.Module):
9 | def __init__(self, in_ch=3, num_classes=3, hybrid=False):
10 | super().__init__()
11 | self.hybrid = hybrid
12 | num_init_features = 96
13 | backbone = torchvision.models.densenet161(pretrained=False)
14 | self.first_convblock = nn.Sequential(OrderedDict([
15 | ('conv0', nn.Conv2d(in_ch, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
16 | ('norm0', nn.BatchNorm2d(num_init_features)),
17 | ('relu0', nn.ReLU(inplace=True)),
18 | # ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
19 | ]))
20 | self.pool0 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
21 |
22 | self.denseblock1 = backbone.features.denseblock1
23 | self.transition1 = backbone.features.transition1
24 | self.denseblock2 = backbone.features.denseblock2
25 | self.transition2 = backbone.features.transition2
26 | self.denseblock3 = backbone.features.denseblock3
27 | self.transition3 = backbone.features.transition3
28 | self.denseblock4 = backbone.features.denseblock4
29 | self.bn5 = backbone.features.norm5
30 | self.relu = nn.ReLU(inplace=True)
31 |
32 | self.conv3 = nn.Conv2d(2112, 2208, kernel_size=1, stride=1)
33 |
34 | self.convblock43 = ConvBlock(2208, 768)
35 | self.convblock32 = ConvBlock(768, 384, kernel_size=3, stride=1, padding=1)
36 | self.convblock21 = ConvBlock(384, 96, kernel_size=3, stride=1, padding=1)
37 | self.convblock10 = ConvBlock(96, 96, kernel_size=3, stride=1, padding=1)
38 | self.convblock00 = ConvBlock(96, 64, kernel_size=3, stride=1, padding=1)
39 | self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1, stride=1)
40 |
41 | def forward(self, x):
42 | db0 = self.first_convblock(x)
43 | x = self.pool0(db0)
44 | db1 = self.denseblock1(x)
45 | x = self.transition1(db1)
46 | db2 = self.denseblock2(x)
47 | x = self.transition2(db2)
48 | db3 = self.denseblock3(x)
49 | x = self.transition3(db3)
50 | x = self.denseblock4(x)
51 | x = self.bn5(x)
52 | db4 = self.relu(x)
53 |
54 | up4 = torch.nn.functional.interpolate(db4, scale_factor=2, mode='bilinear', align_corners=True)
55 | db3 = self.conv3(db3)
56 | db43 = torch.add(up4, db3)
57 | db43 = self.convblock43(db43)
58 |
59 | up43 = torch.nn.functional.interpolate(db43, scale_factor=2, mode='bilinear', align_corners=True)
60 | db432 = torch.add(up43, db2)
61 | db432 = self.convblock32(db432)
62 |
63 | up432 = torch.nn.functional.interpolate(db432, scale_factor=2, mode='bilinear', align_corners=True)
64 | db4321 = torch.add(up432, db1)
65 | db4321 = self.convblock21(db4321)
66 |
67 | up4321 = torch.nn.functional.interpolate(db4321, scale_factor=2, mode='bilinear', align_corners=True)
68 | db43210 = torch.add(up4321, db0)
69 | db43210 = self.convblock10(db43210)
70 |
71 | up43210 = torch.nn.functional.interpolate(db43210, scale_factor=2, mode='bilinear', align_corners=True)
72 | db43210 = self.convblock00(up43210)
73 |
74 | out = self.final_conv(db43210)
75 |
76 | if self.hybrid:
77 | return db43210, out
78 | else:
79 | return out
80 |
81 |
82 | class ConvBlock(torch.nn.Module):
83 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
84 | super().__init__()
85 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
86 | self.bn = nn.BatchNorm2d(out_channels)
87 | self.relu = nn.ReLU()
88 |
89 | def forward(self, input):
90 | x = self.conv1(input)
91 | x = self.bn(x)
92 | x = self.relu(x)
93 | return x
94 |
95 |
96 |
97 | class Dense_Block(nn.Module):
98 | ''' Build a dense_block where the output of each conv_block is fed to subsequent ones
99 | # Arguments
100 | x: input tensor
101 | stage: index for dense block
102 | nb_layers: the number of layers of conv_block to append to the model.
103 | nb_filter: number of filters
104 | growth_rate: growth rate
105 | dropout_rate: dropout rate
106 | weight_decay: weight decay factor
107 | grow_nb_filters: flag to decide to allow number of filters to grow
108 | '''
109 | def __init__(self, x, stage, nb_layers, nb_filter, growth_rate, dropout_rate=None, weight_decay=1e-4,
110 | grow_nb_filters=True):
111 | super().__init__()
112 |         torchvision.models.densenet161()  # placeholder call; the result is discarded
113 |
114 |     def forward(self, x):
115 |         pass  # not implemented: this class is an unused stub
116 |
117 |
118 | if __name__ == '__main__':
119 | # model = torchvision.models.densenet161()
120 | # for name, param in model.named_parameters():
121 | # print(name)
122 |
123 | import os
124 | os.environ['CUDA_VISIBLE_DEVICES'] = '4'
125 | # model_bb = torchvision.models.densenet161(pretrained=True).cuda()
126 |
127 | # model = DenseUnet().cuda()
128 | # data = torch.randn((2, 3, 224, 224)).cuda()
129 | # pred = model(data)
130 | # print(pred.shape)
131 |
132 |
133 |
--------------------------------------------------------------------------------
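A minimal sketch for DenseUnet above (the input size is an assumption; it must be divisible by 32): a DenseNet-161 encoder with additive skip connections in the decoder; hybrid=True additionally returns the last 64-channel feature map:

    import torch
    from models.denseunet_model import DenseUnet

    model = DenseUnet(in_ch=3, num_classes=3)
    pred = model(torch.rand(1, 3, 224, 224))
    print(pred.shape)   # torch.Size([1, 3, 224, 224])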
/models/init_weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 |
5 | def weights_init_normal(m):
6 | classname = m.__class__.__name__
7 | #print(classname)
8 | if classname.find('Conv') != -1:
9 | init.normal_(m.weight.data, 0.0, 0.02)
10 | elif classname.find('Linear') != -1:
11 | init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | init.normal_(m.weight.data, 1.0, 0.02)
14 | init.constant_(m.bias.data, 0.0)
15 |
16 |
17 | def weights_init_xavier(m):
18 | classname = m.__class__.__name__
19 | #print(classname)
20 | if classname.find('Conv') != -1:
21 | init.xavier_normal_(m.weight.data, gain=1)
22 | elif classname.find('Linear') != -1:
23 | init.xavier_normal_(m.weight.data, gain=1)
24 | elif classname.find('BatchNorm') != -1:
25 | init.normal_(m.weight.data, 1.0, 0.02)
26 | init.constant_(m.bias.data, 0.0)
27 |
28 |
29 | def weights_init_kaiming(m):
30 | classname = m.__class__.__name__
31 | #print(classname)
32 | if classname.find('Conv') != -1:
33 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
34 | elif classname.find('Linear') != -1:
35 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
36 | elif classname.find('BatchNorm') != -1:
37 | init.normal_(m.weight.data, 1.0, 0.02)
38 | init.constant_(m.bias.data, 0.0)
39 |
40 |
41 | def weights_init_orthogonal(m):
42 | classname = m.__class__.__name__
43 | #print(classname)
44 | if classname.find('Conv') != -1:
45 | init.orthogonal_(m.weight.data, gain=1)
46 | elif classname.find('Linear') != -1:
47 | init.orthogonal_(m.weight.data, gain=1)
48 | elif classname.find('BatchNorm') != -1:
49 | init.normal_(m.weight.data, 1.0, 0.02)
50 | init.constant_(m.bias.data, 0.0)
51 |
52 |
53 | def init_weights(net, init_type='normal'):
54 | #print('initialization method [%s]' % init_type)
55 | if init_type == 'normal':
56 | net.apply(weights_init_normal)
57 | elif init_type == 'xavier':
58 | net.apply(weights_init_xavier)
59 | elif init_type == 'kaiming':
60 | net.apply(weights_init_kaiming)
61 | elif init_type == 'orthogonal':
62 | net.apply(weights_init_orthogonal)
63 | else:
64 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
65 |
--------------------------------------------------------------------------------
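A short usage sketch for init_weights above (the toy network is an assumption): net.apply() visits every sub-module, and the chosen initializer matches on the class name, so Conv, Linear and BatchNorm layers are re-initialized in one call:

    import torch.nn as nn
    from models.init_weights import init_weights

    net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                        nn.BatchNorm2d(16),
                        nn.ReLU(inplace=True))
    init_weights(net, init_type='kaiming')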
/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.init_weights import init_weights
5 |
6 |
7 | class unetConv2(nn.Module):
8 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
9 | super(unetConv2, self).__init__()
10 | self.n = n
11 | self.ks = ks
12 | self.stride = stride
13 | self.padding = padding
14 | s = stride
15 | p = padding
16 | if is_batchnorm:
17 | for i in range(1, n + 1):
18 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
19 | nn.BatchNorm2d(out_size),
20 | nn.ReLU(inplace=True), )
21 | setattr(self, 'conv%d' % i, conv)
22 | in_size = out_size
23 |
24 | else:
25 | for i in range(1, n + 1):
26 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
27 | nn.ReLU(inplace=True), )
28 | setattr(self, 'conv%d' % i, conv)
29 | in_size = out_size
30 |
31 | # initialise the blocks
32 | for m in self.children():
33 | init_weights(m, init_type='kaiming')
34 |
35 | def forward(self, inputs):
36 | x = inputs
37 | for i in range(1, self.n + 1):
38 | conv = getattr(self, 'conv%d' % i)
39 | x = conv(x)
40 |
41 | return x
42 |
43 | class unetUp(nn.Module):
44 | def __init__(self, in_size, out_size, is_deconv, n_concat=2):
45 | super(unetUp, self).__init__()
46 | # self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
47 | self.conv = unetConv2(out_size*2, out_size, False)
48 | if is_deconv:
49 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
50 | else:
51 | self.up = nn.UpsamplingBilinear2d(scale_factor=2)
52 |
53 | # initialise the blocks
54 | for m in self.children():
55 | if m.__class__.__name__.find('unetConv2') != -1: continue
56 | init_weights(m, init_type='kaiming')
57 |
58 | def forward(self, inputs0, *input):
59 | # print(self.n_concat)
60 | # print(input)
61 | outputs0 = self.up(inputs0)
62 | for i in range(len(input)):
63 | outputs0 = torch.cat([outputs0, input[i]], 1)
64 | return self.conv(outputs0)
65 |
66 | class unetUp_origin(nn.Module):
67 | def __init__(self, in_size, out_size, is_deconv, n_concat=2):
68 | super(unetUp_origin, self).__init__()
69 | # self.conv = unetConv2(out_size*2, out_size, False)
70 | if is_deconv:
71 | self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
72 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
73 | else:
74 | self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
75 | self.up = nn.UpsamplingBilinear2d(scale_factor=2)
76 |
77 | # initialise the blocks
78 | for m in self.children():
79 | if m.__class__.__name__.find('unetConv2') != -1: continue
80 | init_weights(m, init_type='kaiming')
81 |
82 | def forward(self, inputs0, *input):
83 | # print(self.n_concat)
84 | # print(input)
85 | outputs0 = self.up(inputs0)
86 | for i in range(len(input)):
87 | outputs0 = torch.cat([outputs0, input[i]], 1)
88 | return self.conv(outputs0)
89 |
--------------------------------------------------------------------------------
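A minimal sketch for the blocks above (shapes are assumptions): unetConv2 is a stack of n conv-(BN)-ReLU layers, and unetUp upsamples the deeper features and concatenates them with the skip connection before convolving. Note unetUp's conv expects out_size*2 input channels, which holds here because the deconv maps 128 channels down to 64 before the concat:

    import torch
    from models.layers import unetConv2, unetUp

    conv = unetConv2(3, 64, is_batchnorm=True)
    skip = conv(torch.rand(1, 3, 64, 64))       # [1, 64, 64, 64]
    up = unetUp(128, 64, is_deconv=True)
    out = up(torch.rand(1, 128, 32, 32), skip)  # upsample deep features, concat skip
    print(out.shape)                            # torch.Size([1, 64, 64, 64])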
/models/mynn.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from config import cfg
7 | import torch.nn as nn
8 | from math import sqrt
9 | import torch
10 | from torch.autograd.function import InplaceFunction
11 | from itertools import repeat
12 | from torch.nn.modules import Module
13 | from torch.utils.checkpoint import checkpoint
14 |
15 |
16 | def Norm2d(in_channels):
17 | """
18 | Custom Norm Function to allow flexible switching
19 | """
20 | layer = getattr(cfg.MODEL,'BNFUNC')
21 | normalizationLayer = layer(in_channels)
22 | return normalizationLayer
23 |
24 |
25 | def initialize_weights(*models):
26 | for model in models:
27 | for module in model.modules():
28 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
29 |                 nn.init.kaiming_normal_(module.weight)  # in-place variant; plain kaiming_normal is deprecated
30 | if module.bias is not None:
31 | module.bias.data.zero_()
32 | elif isinstance(module, nn.BatchNorm2d):
33 | module.weight.data.fill_(1)
34 | module.bias.data.zero_()
35 |
--------------------------------------------------------------------------------
/models/new2net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from backbone.origin.from_origin import Backbone_ResNet50_in3, Backbone_VGG16_in3
5 | from module.BaseBlocks import BasicConv2d
6 | from module.MyModule import AIM, SIM
7 | from utils.tensor_ops import cus_sample, upsample_add
8 |
9 |
10 | class MINet_VGG16(nn.Module):
11 | def __init__(self):
12 | super(MINet_VGG16, self).__init__()
13 | self.upsample_add = upsample_add
14 | self.upsample = cus_sample
15 |
16 | (
17 | self.encoder1,
18 | self.encoder2,
19 | self.encoder4,
20 | self.encoder8,
21 | self.encoder16,
22 | ) = Backbone_VGG16_in3()
23 |
24 | self.trans = AIM((64, 128, 256, 512, 512), (32, 64, 64, 64, 64))
25 |
26 | self.sim16 = SIM(64, 32)
27 | self.sim8 = SIM(64, 32)
28 | self.sim4 = SIM(64, 32)
29 | self.sim2 = SIM(64, 32)
30 | self.sim1 = SIM(32, 16)
31 |
32 | self.upconv16 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
33 | self.upconv8 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
34 | self.upconv4 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
35 | self.upconv2 = BasicConv2d(64, 32, kernel_size=3, stride=1, padding=1)
36 | self.upconv1 = BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
37 |
38 | self.classifier = nn.Conv2d(32, 1, 1)
39 |
40 | def forward(self, in_data):
41 | in_data_1 = self.encoder1(in_data)
42 | in_data_2 = self.encoder2(in_data_1)
43 | in_data_4 = self.encoder4(in_data_2)
44 | in_data_8 = self.encoder8(in_data_4)
45 | in_data_16 = self.encoder16(in_data_8)
46 |
47 | in_data_1, in_data_2, in_data_4, in_data_8, in_data_16 = self.trans(
48 | in_data_1, in_data_2, in_data_4, in_data_8, in_data_16
49 | )
50 |
51 |         out_data_16 = self.upconv16(self.sim16(in_data_16))  # 64 ch
52 |
53 |         out_data_8 = self.upsample_add(out_data_16, in_data_8)
54 |         out_data_8 = self.upconv8(self.sim8(out_data_8))  # 64 ch
55 |
56 |         out_data_4 = self.upsample_add(out_data_8, in_data_4)
57 |         out_data_4 = self.upconv4(self.sim4(out_data_4))  # 64 ch
58 |
59 |         out_data_2 = self.upsample_add(out_data_4, in_data_2)
60 |         out_data_2 = self.upconv2(self.sim2(out_data_2))  # 32 ch
61 |
62 |         out_data_1 = self.upsample_add(out_data_2, in_data_1)
63 |         out_data_1 = self.upconv1(self.sim1(out_data_1))  # 32 ch
64 |
65 | out_data = self.classifier(out_data_1)
66 |
67 | return out_data
68 |
69 |
70 | class MINet_Res50(nn.Module):
71 | def __init__(self):
72 | super(MINet_Res50, self).__init__()
73 | self.div_2, self.div_4, self.div_8, self.div_16, self.div_32 = Backbone_ResNet50_in3()
74 |
75 | self.upsample_add = upsample_add
76 | self.upsample = cus_sample
77 |
78 | self.trans = AIM(iC_list=(64, 256, 512, 1024, 2048), oC_list=(64, 64, 64, 64, 64))
79 |
80 | self.sim32 = SIM(64, 32)
81 | self.sim16 = SIM(64, 32)
82 | self.sim8 = SIM(64, 32)
83 | self.sim4 = SIM(64, 32)
84 | self.sim2 = SIM(64, 32)
85 |
86 | self.upconv32 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
87 | self.upconv16 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
88 | self.upconv8 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
89 | self.upconv4 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)
90 | self.upconv2 = BasicConv2d(64, 32, kernel_size=3, stride=1, padding=1)
91 | self.upconv1 = BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
92 |
93 | self.classifier = nn.Conv2d(32, 1, 1)
94 |
95 | def forward(self, in_data):
96 | in_data_2 = self.div_2(in_data)
97 | in_data_4 = self.div_4(in_data_2)
98 | in_data_8 = self.div_8(in_data_4)
99 | in_data_16 = self.div_16(in_data_8)
100 | in_data_32 = self.div_32(in_data_16)
101 |
102 | in_data_2, in_data_4, in_data_8, in_data_16, in_data_32 = self.trans(
103 | in_data_2, in_data_4, in_data_8, in_data_16, in_data_32
104 | )
105 |
106 |         out_data_32 = self.upconv32(self.sim32(in_data_32))  # 64 ch
107 |
108 |         out_data_16 = self.upsample_add(out_data_32, in_data_16)
109 |         out_data_16 = self.upconv16(self.sim16(out_data_16))  # 64 ch
110 |
111 |         out_data_8 = self.upsample_add(out_data_16, in_data_8)
112 |         out_data_8 = self.upconv8(self.sim8(out_data_8))  # 64 ch
113 |
114 |         out_data_4 = self.upsample_add(out_data_8, in_data_4)
115 |         out_data_4 = self.upconv4(self.sim4(out_data_4))  # 64 ch
116 |
117 |         out_data_2 = self.upsample_add(out_data_4, in_data_2)
118 |         out_data_2 = self.upconv2(self.sim2(out_data_2))  # 32 ch
119 |
120 |         out_data_1 = self.upconv1(self.upsample(out_data_2, scale_factor=2))  # 32 ch
121 | out_data = self.classifier(out_data_1)
122 |
123 | return out_data
124 |
125 |
126 | if __name__ == "__main__":
127 | in_data = torch.randn((1, 3, 320, 320))
128 | net = MINet_VGG16()
129 | print(sum([x.nelement() for x in net.parameters()]))
--------------------------------------------------------------------------------
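A quick smoke test for the two decoders above (a hedged editorial sketch, not code from the repository: it assumes MINet_Res50 and its helpers — Backbone_ResNet50_in3, AIM, SIM, BasicConv2d, upsample_add, cus_sample — are importable from this module):

    import torch

    net = MINet_Res50().eval()
    x = torch.randn(1, 3, 320, 320)
    with torch.no_grad():
        y = net(x)
    # the classifier is nn.Conv2d(32, 1, 1) applied after upsampling back to
    # stride 1, so a single-channel map at input resolution is expected
    print(y.shape)  # torch.Size([1, 1, 320, 320])

--------------------------------------------------------------------------------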
/models/norm.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from config import cfg
7 | import torch.nn as nn
14 |
15 |
16 | def Norm2d(in_channels):
17 | """
18 | Custom Norm Function to allow flexible switching
19 | """
20 | layer = getattr(cfg.MODEL,'BNFUNC')
21 | normalizationLayer = layer(in_channels)
22 | return normalizationLayer
23 |
24 |
25 | def initialize_weights(*models):
26 | for model in models:
27 | for module in model.modules():
28 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
29 | nn.init.kaiming_normal_(module.weight)  # in-place variant; the bare name is deprecated
30 | if module.bias is not None:
31 | module.bias.data.zero_()
32 | elif isinstance(module, nn.BatchNorm2d):
33 | module.weight.data.fill_(1)
34 | module.bias.data.zero_()
35 |
--------------------------------------------------------------------------------
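How Norm2d resolves its layer class, in isolation (a minimal sketch; the real cfg comes from this project's config module, and cfg.MODEL.BNFUNC is assumed to hold a normalization-layer class such as nn.BatchNorm2d):

    import torch.nn as nn

    class _MODEL:              # stand-in for cfg.MODEL
        BNFUNC = nn.BatchNorm2d

    layer = getattr(_MODEL, 'BNFUNC')
    bn = layer(64)             # what Norm2d(64) returns under this config
    print(type(bn).__name__)   # BatchNorm2d

--------------------------------------------------------------------------------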
/models/pretrain/SAMNet_with_ImageNet_pretrain.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLWK/CGRNet/a9a65fa192cc9888e7861755313b8b3ac80fa512/models/pretrain/SAMNet_with_ImageNet_pretrain.pth
--------------------------------------------------------------------------------
/models/pretrain/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import ModelBuilder, SegmentationModule, SAUNet, VGG19UNet, VGG19UNet_without_boudary, VGGUNet
2 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import math
6 | from lib.nn import SynchronizedBatchNorm2d
7 |
8 | try:
9 | from urllib import urlretrieve
10 | except ImportError:
11 | from urllib.request import urlretrieve
12 |
13 |
14 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
19 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
20 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1):
25 | "3x3 convolution with padding"
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 | padding=1, bias=False)
28 |
29 |
30 | class BasicBlock(nn.Module):
31 | expansion = 1
32 |
33 | def __init__(self, inplanes, planes, stride=1, downsample=None):
34 | super(BasicBlock, self).__init__()
35 | self.conv1 = conv3x3(inplanes, planes, stride)
36 | self.bn1 = SynchronizedBatchNorm2d(planes)
37 | self.relu = nn.ReLU(inplace=True)
38 | self.conv2 = conv3x3(planes, planes)
39 | self.bn2 = SynchronizedBatchNorm2d(planes)
40 | self.downsample = downsample
41 | self.stride = stride
42 |
43 | def forward(self, x):
44 | residual = x
45 |
46 | out = self.conv1(x)
47 | out = self.bn1(out)
48 | out = self.relu(out)
49 |
50 | out = self.conv2(out)
51 | out = self.bn2(out)
52 |
53 | if self.downsample is not None:
54 | residual = self.downsample(x)
55 |
56 | out += residual
57 | out = self.relu(out)
58 |
59 | return out
60 |
61 |
62 | class Bottleneck(nn.Module):
63 | expansion = 4
64 |
65 | def __init__(self, inplanes, planes, stride=1, downsample=None):
66 | super(Bottleneck, self).__init__()
67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
68 | self.bn1 = SynchronizedBatchNorm2d(planes)
69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
70 | padding=1, bias=False)
71 | self.bn2 = SynchronizedBatchNorm2d(planes)
72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
73 | self.bn3 = SynchronizedBatchNorm2d(planes * 4)
74 | self.relu = nn.ReLU(inplace=True)
75 | self.downsample = downsample
76 | self.stride = stride
77 |
78 | def forward(self, x):
79 | residual = x
80 |
81 | out = self.conv1(x)
82 | out = self.bn1(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv2(out)
86 | out = self.bn2(out)
87 | out = self.relu(out)
88 |
89 | out = self.conv3(out)
90 | out = self.bn3(out)
91 |
92 | if self.downsample is not None:
93 | residual = self.downsample(x)
94 |
95 | out += residual
96 | out = self.relu(out)
97 |
98 | return out
99 |
100 |
101 | class ResNet(nn.Module):
102 |
103 | def __init__(self, block, layers, num_classes=1000):
104 | self.inplanes = 128
105 | super(ResNet, self).__init__()
106 | self.conv1 = conv3x3(3, 64, stride=2)
107 | self.bn1 = SynchronizedBatchNorm2d(64)
108 | self.relu1 = nn.ReLU(inplace=True)
109 | self.conv2 = conv3x3(64, 64)
110 | self.bn2 = SynchronizedBatchNorm2d(64)
111 | self.relu2 = nn.ReLU(inplace=True)
112 | self.conv3 = conv3x3(64, 128)
113 | self.bn3 = SynchronizedBatchNorm2d(128)
114 | self.relu3 = nn.ReLU(inplace=True)
115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
116 |
117 | self.layer1 = self._make_layer(block, 64, layers[0])
118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
121 | self.avgpool = nn.AvgPool2d(7, stride=1)
122 | self.fc = nn.Linear(512 * block.expansion, num_classes)
123 |
124 | for m in self.modules():
125 | if isinstance(m, nn.Conv2d):
126 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
127 | m.weight.data.normal_(0, math.sqrt(2. / n))
128 | elif isinstance(m, SynchronizedBatchNorm2d):
129 | m.weight.data.fill_(1)
130 | m.bias.data.zero_()
131 |
132 | def _make_layer(self, block, planes, blocks, stride=1):
133 | downsample = None
134 | if stride != 1 or self.inplanes != planes * block.expansion:
135 | downsample = nn.Sequential(
136 | nn.Conv2d(self.inplanes, planes * block.expansion,
137 | kernel_size=1, stride=stride, bias=False),
138 | SynchronizedBatchNorm2d(planes * block.expansion),
139 | )
140 |
141 | layers = []
142 | layers.append(block(self.inplanes, planes, stride, downsample))
143 | self.inplanes = planes * block.expansion
144 | for i in range(1, blocks):
145 | layers.append(block(self.inplanes, planes))
146 |
147 | return nn.Sequential(*layers)
148 |
149 | def forward(self, x):
150 | x = self.relu1(self.bn1(self.conv1(x)))
151 | x = self.relu2(self.bn2(self.conv2(x)))
152 | x = self.relu3(self.bn3(self.conv3(x)))
153 | x = self.maxpool(x)
154 |
155 | x = self.layer1(x)
156 | x = self.layer2(x)
157 | x = self.layer3(x)
158 | x = self.layer4(x)
159 |
160 | x = self.avgpool(x)
161 | x = x.view(x.size(0), -1)
162 | x = self.fc(x)
163 |
164 | return x
165 |
166 | def resnet18(pretrained=False, **kwargs):
167 | """Constructs a ResNet-18 model.
168 | Args:
169 | pretrained (bool): If True, returns a model pre-trained on ImageNet
170 | """
171 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
172 | if pretrained:
173 | model.load_state_dict(load_url(model_urls['resnet18']))
174 | return model
175 |
176 | '''
177 | def resnet34(pretrained=False, **kwargs):
178 | """Constructs a ResNet-34 model.
179 | Args:
180 | pretrained (bool): If True, returns a model pre-trained on ImageNet
181 | """
182 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
183 | if pretrained:
184 | model.load_state_dict(load_url(model_urls['resnet34']))
185 | return model
186 | '''
187 |
188 | def resnet50(pretrained=False, **kwargs):
189 | """Constructs a ResNet-50 model.
190 | Args:
191 | pretrained (bool): If True, returns a model pre-trained on ImageNet
192 | """
193 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
194 | if pretrained:
195 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
196 | return model
197 |
198 |
199 | def resnet101(pretrained=False, **kwargs):
200 | """Constructs a ResNet-101 model.
201 | Args:
202 | pretrained (bool): If True, returns a model pre-trained on ImageNet
203 | """
204 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
205 | if pretrained:
206 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
207 | return model
208 |
209 | # def resnet152(pretrained=False, **kwargs):
210 | # """Constructs a ResNet-152 model.
211 | #
212 | # Args:
213 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
214 | # """
215 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
216 | # if pretrained:
217 | # model.load_state_dict(load_url(model_urls['resnet152']))
218 | # return model
219 |
220 | def load_url(url, model_dir='./pretrained', map_location=None):
221 | if not os.path.exists(model_dir):
222 | os.makedirs(model_dir)
223 | filename = url.split('/')[-1]
224 | cached_file = os.path.join(model_dir, filename)
225 | if not os.path.exists(cached_file):
226 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
227 | urlretrieve(url, cached_file)
228 | return torch.load(cached_file, map_location=map_location)
229 |
--------------------------------------------------------------------------------
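Building the classifier without downloading weights (a hedged sketch; the import path is assumed from this dump's layout, and the blocks above require SynchronizedBatchNorm2d from lib.nn to be importable). With pretrained=True, load_url() caches the checkpoint under ./pretrained/ on first use:

    import torch
    from models.resnet import resnet50  # import path is an assumption

    model = resnet50(pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)     # 224/32 = 7, matching AvgPool2d(7)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # torch.Size([1, 1000])

--------------------------------------------------------------------------------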
/models/segnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
5 |
6 | class SegNet(nn.Module):
7 | def __init__(self,input_nbr,label_nbr):
8 | super(SegNet, self).__init__()
9 |
10 | batchNorm_momentum = 0.1
11 |
12 | self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)
13 | self.bn11 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
14 | self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
15 | self.bn12 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
16 |
17 | self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
18 | self.bn21 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
19 | self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
20 | self.bn22 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
21 |
22 | self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
23 | self.bn31 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
24 | self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
25 | self.bn32 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
26 | self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
27 | self.bn33 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
28 |
29 | self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
30 | self.bn41 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
31 | self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
32 | self.bn42 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
33 | self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
34 | self.bn43 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
35 |
36 | self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
37 | self.bn51 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
38 | self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
39 | self.bn52 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
40 | self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
41 | self.bn53 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
42 |
43 | self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
44 | self.bn53d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
45 | self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
46 | self.bn52d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
47 | self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
48 | self.bn51d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
49 |
50 | self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
51 | self.bn43d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
52 | self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
53 | self.bn42d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
54 | self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
55 | self.bn41d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
56 |
57 | self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
58 | self.bn33d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
59 | self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
60 | self.bn32d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
61 | self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1)
62 | self.bn31d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
63 |
64 | self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
65 | self.bn22d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
66 | self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
67 | self.bn21d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
68 |
69 | self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
70 | self.bn12d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
71 | self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1)
72 | self.sigmoid = nn.Sigmoid()
73 |
74 | def forward(self, x):
75 |
76 | # Stage 1
77 | x11 = F.relu(self.bn11(self.conv11(x)))
78 | x12 = F.relu(self.bn12(self.conv12(x11)))
79 | x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=1,return_indices=True)
80 |
81 | # Stage 2
82 | x21 = F.relu(self.bn21(self.conv21(x1p)))
83 | x22 = F.relu(self.bn22(self.conv22(x21)))
84 | x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=1,return_indices=True)
85 |
86 | # Stage 3
87 | x31 = F.relu(self.bn31(self.conv31(x2p)))
88 | x32 = F.relu(self.bn32(self.conv32(x31)))
89 | x33 = F.relu(self.bn33(self.conv33(x32)))
90 | x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=1,return_indices=True)
91 |
92 | # Stage 4
93 | x41 = F.relu(self.bn41(self.conv41(x3p)))
94 | x42 = F.relu(self.bn42(self.conv42(x41)))
95 | x43 = F.relu(self.bn43(self.conv43(x42)))
96 | x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=1,return_indices=True)
97 |
98 | # Stage 5
99 | x51 = F.relu(self.bn51(self.conv51(x4p)))
100 | x52 = F.relu(self.bn52(self.conv52(x51)))
101 | x53 = F.relu(self.bn53(self.conv53(x52)))
102 | x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=1,return_indices=True)
103 |
104 |
105 | # Stage 5d
106 | x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=1)
107 | x53d = F.relu(self.bn53d(self.conv53d(x5d)))
108 | x52d = F.relu(self.bn52d(self.conv52d(x53d)))
109 | x51d = F.relu(self.bn51d(self.conv51d(x52d)))
110 |
111 | # Stage 4d
112 | x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=1)
113 | x43d = F.relu(self.bn43d(self.conv43d(x4d)))
114 | x42d = F.relu(self.bn42d(self.conv42d(x43d)))
115 | x41d = F.relu(self.bn41d(self.conv41d(x42d)))
116 |
117 | # Stage 3d
118 | x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=1)
119 | x33d = F.relu(self.bn33d(self.conv33d(x3d)))
120 | x32d = F.relu(self.bn32d(self.conv32d(x33d)))
121 | x31d = F.relu(self.bn31d(self.conv31d(x32d)))
122 |
123 | # Stage 2d
124 | x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=1)
125 | x22d = F.relu(self.bn22d(self.conv22d(x2d)))
126 | x21d = F.relu(self.bn21d(self.conv21d(x22d)))
127 |
128 | # Stage 1d
129 | x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=1)
130 | x12d = F.relu(self.bn12d(self.conv12d(x1d)))
131 | x11d = self.conv11d(x12d)
132 | # x11d = self.sigmoid(x11d)
133 | return x11d
--------------------------------------------------------------------------------
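Note the non-standard pooling above: kernel_size=2 with stride=1 shrinks each spatial side by exactly one pixel rather than halving it, and max_unpool2d with the saved indices grows it back, so encoder and decoder stages stay aligned. A worked shape check:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 8, 8)
    p, idx = F.max_pool2d(x, kernel_size=2, stride=1, return_indices=True)
    print(p.shape)  # torch.Size([1, 1, 7, 7]): (8 - 2)/1 + 1 = 7
    u = F.max_unpool2d(p, idx, kernel_size=2, stride=1)
    print(u.shape)  # torch.Size([1, 1, 8, 8]): (7 - 1)*1 + 2 = 8

--------------------------------------------------------------------------------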
/models/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import warnings
5 | warnings.filterwarnings(action='ignore')
6 |
7 |
8 | class double_conv(nn.Module):
9 | '''(conv => BN => ReLU) * 2'''
10 |
11 | def __init__(self, in_ch, out_ch):
12 | super(double_conv, self).__init__()
13 | self.conv = nn.Sequential(
14 | nn.Conv2d(in_ch, out_ch, 3, padding=1),
15 | nn.BatchNorm2d(out_ch),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(out_ch, out_ch, 3, padding=1),
18 | nn.BatchNorm2d(out_ch),
19 | nn.ReLU(inplace=True)
20 | )
21 |
22 | def forward(self, x):
23 | x = self.conv(x)
24 | return x
25 |
26 |
27 | class inconv(nn.Module):
28 | def __init__(self, in_ch, out_ch):
29 | super(inconv, self).__init__()
30 | self.conv = double_conv(in_ch, out_ch)
31 |
32 | def forward(self, x):
33 | x = self.conv(x)
34 | return x
35 |
36 |
37 | class down(nn.Module):
38 | def __init__(self, in_ch, out_ch):
39 | super(down, self).__init__()
40 | self.mpconv = nn.Sequential(
41 | nn.MaxPool2d(2),
42 | double_conv(in_ch, out_ch)
43 | )
44 |
45 | def forward(self, x):
46 | x = self.mpconv(x)
47 | return x
48 |
49 |
50 | class up(nn.Module):
51 | def __init__(self, in_ch, out_ch, bilinear=True):
52 | super(up, self).__init__()
53 |
54 | if bilinear:
55 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
56 | else:
57 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
58 |
59 | self.conv = double_conv(in_ch, out_ch)
60 |
61 | def forward(self,x1, x2):
62 | x1 = self.up(x1)
63 |
64 | diffY = x2.size()[2] - x1.size()[2]
65 | diffX = x2.size()[3] - x1.size()[3]
66 |
67 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
68 | diffY // 2, diffY - diffY // 2))
69 |
70 | x = torch.cat([x2, x1], dim=1)
71 | x = self.conv(x)
72 | return x
73 |
74 |
75 | class outconv(nn.Module):
76 | def __init__(self, in_ch, out_ch):
77 | super(outconv, self).__init__()
78 | self.conv = nn.Conv2d(in_ch, out_ch, 1)
79 |
80 | def forward(self, x):
81 | x = self.conv(x)
82 | return x
83 |
84 |
85 | class GEncoder(nn.Module):
86 | def __init__(self, n_channels, n_classes, deep_supervision = False):
87 | super(GEncoder, self).__init__()
88 | self.deep_supervision = deep_supervision
89 | self.inc = inconv(n_channels, 64)
90 | self.down1 = down(64, 128)
91 | self.down2 = down(128, 256)
92 | self.down3 = down(256, 512)
93 | self.down4 = down(512, 1024)
94 |
95 |
96 | def forward(self, x):
97 | x1 = self.inc(x)
98 | x2 = self.down1(x1)
99 | x3 = self.down2(x2)
100 | x4 = self.down3(x3)
101 | x5 = self.down4(x4)
102 |
103 |
104 | # GEncoder is encoder-only: the dsoutc* heads and decoder tensors
105 | # referenced by the original deep_supervision branch are never defined
106 | # in this class, so that branch could not run. The deepest encoder
107 | # feature map is returned directly; the flag is kept for API parity.
108 | return x5
113 |
114 | if __name__ == '__main__':
115 | ras = GEncoder(n_channels=1, n_classes=1).cuda()
116 | input_tensor = torch.randn(4, 1, 96, 96).cuda()
117 | out = ras(input_tensor)
118 | print(out.shape)  # the encoder returns a single feature tensor
119 |
--------------------------------------------------------------------------------
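Each down stage halves the resolution, so GEncoder's output is 1/16 of the input: 96 / 2^4 = 6. The same check as __main__ above, on CPU (a sketch using the GEncoder class defined in this file):

    import torch

    enc = GEncoder(n_channels=1, n_classes=1).eval()
    with torch.no_grad():
        feat = enc(torch.randn(4, 1, 96, 96))
    print(feat.shape)  # torch.Size([4, 1024, 6, 6])

--------------------------------------------------------------------------------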
/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import warnings
5 | warnings.filterwarnings(action='ignore')
6 |
7 |
8 | class double_conv(nn.Module):
9 | '''(conv => BN => ReLU) * 2'''
10 |
11 | def __init__(self, in_ch, out_ch):
12 | super(double_conv, self).__init__()
13 | self.conv = nn.Sequential(
14 | nn.Conv2d(in_ch, out_ch, 3, padding=1),
15 | nn.BatchNorm2d(out_ch),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(out_ch, out_ch, 3, padding=1),
18 | nn.BatchNorm2d(out_ch),
19 | nn.ReLU(inplace=True)
20 | )
21 |
22 | def forward(self, x):
23 | x = self.conv(x)
24 | return x
25 |
26 |
27 | class inconv(nn.Module):
28 | def __init__(self, in_ch, out_ch):
29 | super(inconv, self).__init__()
30 | self.conv = double_conv(in_ch, out_ch)
31 |
32 | def forward(self, x):
33 | x = self.conv(x)
34 | return x
35 |
36 |
37 | class down(nn.Module):
38 | def __init__(self, in_ch, out_ch):
39 | super(down, self).__init__()
40 | self.mpconv = nn.Sequential(
41 | nn.MaxPool2d(2),
42 | double_conv(in_ch, out_ch)
43 | )
44 |
45 | def forward(self, x):
46 | x = self.mpconv(x)
47 | return x
48 |
49 |
50 | class up(nn.Module):
51 | def __init__(self, in_ch, out_ch, bilinear=True):
52 | super(up, self).__init__()
53 |
54 | if bilinear:
55 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
56 | else:
57 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
58 |
59 | self.conv = double_conv(in_ch, out_ch)
60 |
61 | def forward(self,x1, x2):
62 | x1 = self.up(x1)
63 |
64 | diffY = x2.size()[2] - x1.size()[2]
65 | diffX = x2.size()[3] - x1.size()[3]
66 |
67 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
68 | diffY // 2, diffY - diffY // 2))
69 |
70 | x = torch.cat([x2, x1], dim=1)
71 | x = self.conv(x)
72 | return x
73 |
74 |
75 | class outconv(nn.Module):
76 | def __init__(self, in_ch, out_ch):
77 | super(outconv, self).__init__()
78 | self.conv = nn.Conv2d(in_ch, out_ch, 1)
79 |
80 | def forward(self, x):
81 | x = self.conv(x)
82 | return x
83 |
84 |
85 | class UNet(nn.Module):
86 | def __init__(self, n_channels, n_classes, deep_supervision = False):
87 | super(UNet, self).__init__()
88 | self.deep_supervision = deep_supervision
89 | self.inc = inconv(n_channels, 64)
90 | self.down1 = down(64, 128)
91 | self.down2 = down(128, 256)
92 | self.down3 = down(256, 512)
93 | self.down4 = down(512, 512)
94 | self.up1 = up(1024, 256)
95 | self.up2 = up(512, 128)
96 | self.up3 = up(256, 64)
97 | self.up4 = up(128, 64)
98 | self.outc = outconv(64, n_classes)
99 |
100 | self.dsoutc4 = outconv(256, n_classes)
101 | self.dsoutc3 = outconv(128, n_classes)
102 | self.dsoutc2 = outconv(64, n_classes)
103 | self.dsoutc1 = outconv(64, n_classes)
104 |
105 | def forward(self, x):
106 | x1 = self.inc(x)
107 | # print(x1.shape)
108 | x2 = self.down1(x1)
109 | # print(x2.shape)
110 | x3 = self.down2(x2)
111 | # print(x3.shape)
112 | x4 = self.down3(x3)
113 | # print(x4.shape)
114 | x5 = self.down4(x4)
115 | # print(x5.shape)
116 | x44 = self.up1(x5, x4)
117 | x33 = self.up2(x44, x3)
118 | x22 = self.up3(x33, x2)
119 | x11 = self.up4(x22, x1)
120 | x0 = self.outc(x11)
121 | if self.deep_supervision:
122 | x11 = F.interpolate(self.dsoutc1(x11), x0.shape[2:], mode='bilinear')
123 | x22 = F.interpolate(self.dsoutc2(x22), x0.shape[2:], mode='bilinear')
124 | x33 = F.interpolate(self.dsoutc3(x33), x0.shape[2:], mode='bilinear')
125 | x44 = F.interpolate(self.dsoutc4(x44), x0.shape[2:], mode='bilinear')
126 |
127 | return x0, x11, x22, x33, x44
128 | else:
129 | return x0
130 |
131 |
132 |
133 | if __name__ == '__main__':
134 | ras = UNet(n_channels=1, n_classes=1).cuda()
135 | input_tensor = torch.randn(4, 1, 96, 96).cuda()
136 | out = ras(input_tensor)
137 | print(out.shape)  # with deep_supervision=True the model returns a 5-tuple instead
138 |
139 |
140 |
--------------------------------------------------------------------------------
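The F.pad call in up.forward exists for odd-sized inputs: bilinear upsampling can leave x1 one pixel short of the skip tensor x2, and the pad splits the difference across the two borders before concatenation. Isolated sketch:

    import torch
    import torch.nn.functional as F

    x1 = torch.randn(1, 64, 12, 12)   # upsampled decoder feature
    x2 = torch.randn(1, 64, 13, 13)   # encoder skip from an odd-sized level
    diffY = x2.size(2) - x1.size(2)   # 1
    diffX = x2.size(3) - x1.size(3)   # 1
    x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                    diffY // 2, diffY - diffY // 2))
    print(torch.cat([x2, x1], dim=1).shape)  # torch.Size([1, 128, 13, 13])

--------------------------------------------------------------------------------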
/models/wassp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import models
3 | import torch.nn as nn
4 |
5 | # from resnet import resnet34
6 | # import resnet
7 | from torch.nn import functional as F
8 | class ConvBnRelu(nn.Module):
9 | def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
10 | groups=1, has_bn=True, norm_layer=nn.BatchNorm2d,
11 | has_relu=True, inplace=True, has_bias=False):
12 | super(ConvBnRelu, self).__init__()
13 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize,
14 | stride=stride, padding=pad,
15 | dilation=dilation, groups=groups, bias=has_bias)
16 | self.has_bn = has_bn
17 | if self.has_bn:
18 | self.bn = nn.BatchNorm2d(out_planes)
19 | self.has_relu = has_relu
20 | if self.has_relu:
21 | self.relu = nn.ReLU(inplace=inplace)
22 |
23 | def forward(self, x):
24 | x = self.conv(x)
25 | if self.has_bn:
26 | x = self.bn(x)
27 | if self.has_relu:
28 | x = self.relu(x)
29 |
30 | return x
31 |
32 |
33 |
34 | class SAPP(nn.Module):
35 | def __init__(self, in_channels):
36 | super(SAPP, self).__init__()
37 | self.conv3x3=nn.Conv2d(in_channels=in_channels, out_channels=in_channels,dilation=1,kernel_size=3, padding=1)
38 |
39 | self.bn=nn.ModuleList([nn.BatchNorm2d(in_channels),nn.BatchNorm2d(in_channels),nn.BatchNorm2d(in_channels)])
40 | self.conv1x1=nn.ModuleList([nn.Conv2d(in_channels=2*in_channels, out_channels=in_channels,dilation=1,kernel_size=1, padding=0),
41 | nn.Conv2d(in_channels=2*in_channels, out_channels=in_channels,dilation=1,kernel_size=1, padding=0)])
42 | self.conv3x3_1=nn.ModuleList([nn.Conv2d(in_channels=in_channels, out_channels=in_channels//2,dilation=1,kernel_size=3, padding=1),
43 | nn.Conv2d(in_channels=in_channels, out_channels=in_channels//2,dilation=1,kernel_size=3, padding=1)])
44 | self.conv3x3_2=nn.ModuleList([nn.Conv2d(in_channels=in_channels//2, out_channels=2,dilation=1,kernel_size=3, padding=1),
45 | nn.Conv2d(in_channels=in_channels//2, out_channels=2,dilation=1,kernel_size=3, padding=1)])
46 | self.conv_last=ConvBnRelu(in_planes=in_channels,out_planes=in_channels,ksize=1,stride=1,pad=0,dilation=1)
47 | self.norm = nn.Sigmoid()
48 | self.conv1= nn.Conv2d(in_channels*2, 1, kernel_size=1, padding=0)
49 | self.dconv1=nn.Conv2d(in_channels*2, in_channels, kernel_size=1, padding=0)
50 | self.gamma = nn.Parameter(torch.zeros(1))
51 |
52 | self.relu=nn.ReLU(inplace=True)
53 |
54 | def forward(self, x):
55 |
56 | x_size= x.size()
57 |
58 | branches_1=self.conv3x3(x)
59 | branches_1=self.bn[0](branches_1)
60 |
61 | branches_2=F.conv2d(x,self.conv3x3.weight,padding=2,dilation=2)#share weight
62 | branches_2=self.bn[1](branches_2)
63 |
64 | branches_3=F.conv2d(x,self.conv3x3.weight,padding=4,dilation=4)#share weight
65 | branches_3=self.bn[2](branches_3)
66 |
67 | feat=torch.cat([branches_1,branches_2],dim=1)
68 |
69 | feat_g =feat
70 | # print(feat_g.shape)
71 | feat_g1 = self.relu(self.conv1(feat_g))
72 | feat_g1 = self.norm(feat_g1)
73 |
74 | out1 = feat_g * feat_g1
75 | out1 = self.dconv1(out1)
76 |
83 | # feat=feat_cat.detach()
84 | feat=self.relu(self.conv1x1[0](feat))
85 | feat=self.relu(self.conv3x3_1[0](feat))
86 | att=self.conv3x3_2[0](feat)
87 | att = F.softmax(att, dim=1)
88 |
89 | att_1=att[:,0,:,:].unsqueeze(1)
90 | att_2=att[:,1,:,:].unsqueeze(1)
91 |
92 | fusion_1_2=att_1*branches_1+att_2*branches_2 +out1
93 |
94 |
95 |
96 | feat1=torch.cat([fusion_1_2,branches_3],dim=1)
97 |
98 | feat_g =feat1
99 | feat_g1 = self.relu(self.conv1(feat_g))
100 | feat_g1 = self.norm(feat_g1)
101 | out2 = feat_g * feat_g1
102 | out2 = self.dconv1(out2)
103 |
104 |
105 | # feat=feat_cat.detach()
106 | feat1=self.relu(self.conv1x1[0](feat1))
107 | feat1=self.relu(self.conv3x3_1[0](feat1))
108 | att1=self.conv3x3_2[0](feat1)
109 | att1 = F.softmax(att1, dim=1)
110 |
111 |
112 | att_1_2=att1[:,0,:,:].unsqueeze(1)
113 |
114 | att_3=att1[:,1,:,:].unsqueeze(1)
115 |
116 |
117 | ax=self.relu(self.gamma*(att_1_2*fusion_1_2+att_3*branches_3 +out2)+(1-self.gamma)*x)
118 | ax=self.conv_last(ax)
119 |
120 | return ax
121 |
122 |
123 |
124 | # if __name__=='__main__':
125 | # x=torch.randn(1,512,6,6)
126 |
127 | # net = SAPP(512)
128 | # ax =net(x)
129 |
130 | # print(ax.shape)
131 |
--------------------------------------------------------------------------------
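The core trick in SAPP.forward, isolated: the same 3x3 kernel is reapplied through F.conv2d at dilations 2 and 4, giving three branches with growing receptive fields but a single weight tensor (only .weight is passed, so the dilated branches drop conv3x3's bias, exactly as in the code above):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    conv = nn.Conv2d(16, 16, kernel_size=3, padding=1)
    x = torch.randn(1, 16, 32, 32)
    b1 = conv(x)                                          # dilation 1, with bias
    b2 = F.conv2d(x, conv.weight, padding=2, dilation=2)  # shared kernel, no bias
    b3 = F.conv2d(x, conv.weight, padding=4, dilation=4)  # shared kernel, no bias
    print(b1.shape, b2.shape, b3.shape)  # three maps of torch.Size([1, 16, 32, 32])

--------------------------------------------------------------------------------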
/models/wnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torchvision import models
5 | from models.multi_scale_module import FoldConv_aspp
6 |
7 |
8 | class GateNet(nn.Module):
9 | def __init__(self):
10 | super(GateNet, self).__init__()
11 | ################################vgg16#######################################
12 | feats = list(models.vgg16_bn(pretrained=True).features.children())
13 | feats[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
14 | self.conv1 = nn.Sequential(*feats[:6])
15 | # print(self.conv1)
16 | self.conv2 = nn.Sequential(*feats[6:13])
17 | self.conv3 = nn.Sequential(*feats[13:23])
18 | self.conv4 = nn.Sequential(*feats[23:33])
19 | self.conv5 = nn.Sequential(*feats[33:43])
20 | ################################Gate#######################################
21 | self.attention_feature5 = nn.Sequential(nn.Conv2d(64+32, 2, kernel_size=3, padding=1))
22 | self.attention_feature4 = nn.Sequential(nn.Conv2d(128+64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
23 | nn.Conv2d(64, 2, kernel_size=3, padding=1))
24 | self.attention_feature3 = nn.Sequential(nn.Conv2d(256+128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
25 | nn.Conv2d(128, 2, kernel_size=3, padding=1))
26 | self.attention_feature2 = nn.Sequential(nn.Conv2d(512+256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU(),
27 | nn.Conv2d(256, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
28 | nn.Conv2d(64, 2, kernel_size=3, padding=1))
29 | self.attention_feature1 = nn.Sequential(nn.Conv2d(512+512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU(),
30 | nn.Conv2d(512, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(),
31 | nn.Conv2d(128, 2, kernel_size=3, padding=1))
32 | ###############################Transition Layer########################################
33 | self.dem1 = nn.Sequential(FoldConv_aspp(in_channel=512,
34 | out_channel=512,
35 | out_size=384 // 16,
36 | kernel_size=3,
37 | stride=1,
38 | padding=2,
39 | dilation=2,
40 | win_size=2,
41 | win_padding=0,
42 |
43 | ), nn.BatchNorm2d(512), nn.PReLU())
44 | self.dem2 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
45 | self.dem3 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
46 | self.dem4 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
47 | self.dem5 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU())
48 | ################################FPN branch#######################################
49 | self.output1 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
50 | self.output2 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
51 | self.output3 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
52 | self.output4 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, padding=1),nn.BatchNorm2d(32), nn.PReLU())
53 | self.output5 = nn.Sequential(nn.Conv2d(32, 1, kernel_size=3, padding=1))
54 | ################################Parallel branch#######################################
55 | self.dem1_1 = nn.Sequential(nn.Conv2d(512, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU())
56 | self.dem2_1 = nn.Sequential(nn.Conv2d(256, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU())
57 | self.dem3_1 = nn.Sequential(nn.Conv2d(128, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU())
58 | self.dem4_1 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU())
59 | self.dem5_1 = nn.Sequential(nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.PReLU())
60 | self.out_res = nn.Sequential(nn.Conv2d(32+32+32+32+32+1, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
61 | nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
62 | nn.Conv2d(64, 1, kernel_size=3, padding=1))
63 | #######################################################################
64 |
65 |
66 | for m in self.modules():
67 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
68 | m.inplace = True
69 |
70 |
71 | def forward(self, x):
72 | input = x
73 | B,_,_,_ = input.size()
74 | ################################Encoder block#######################################
75 | E1 = self.conv1(x)
76 | E2 = self.conv2(E1)
77 | E3 = self.conv3(E2)
78 | E4 = self.conv4(E3)
79 | E5 = self.conv5(E4)
80 | ################################Transition Layer#######################################
81 | T5 = self.dem1(E5)
82 | T4 = self.dem2(E4)
83 | T3 = self.dem3(E3)
84 | T2 = self.dem4(E2)
85 | T1 = self.dem5(E1)
86 | ################################Gated FPN#######################################
87 | G5 = self.attention_feature1(torch.cat((E5, T5), 1))
88 | G5 = F.adaptive_avg_pool2d(torch.sigmoid(G5), 1)
89 | D5 = self.output1(G5[:, 0, :, :].unsqueeze(1).repeat(1, 512, 1, 1) * T5)
90 |
91 | G4 = self.attention_feature2(torch.cat((E4, F.interpolate(D5, size=E4.size()[2:], mode='bilinear')), 1))
92 | G4 = F.adaptive_avg_pool2d(torch.sigmoid(G4), 1)
93 | D4 = self.output2(F.interpolate(D5, size=E4.size()[2:], mode='bilinear') + G4[:, 0, :, :].unsqueeze(1).repeat(1, 256, 1, 1) * T4)
94 |
95 | G3 = self.attention_feature3(torch.cat((E3, F.interpolate(D4, size=E3.size()[2:], mode='bilinear')), 1))
96 | G3 = F.adaptive_avg_pool2d(torch.sigmoid(G3), 1)
97 | D3 = self.output3(F.interpolate(D4, size=E3.size()[2:], mode='bilinear') + G3[:, 0, :, :].unsqueeze(1).repeat(1, 128, 1, 1) * T3)
98 |
99 | G2 = self.attention_feature4(torch.cat((E2, F.interpolate(D3, size=E2.size()[2:], mode='bilinear')), 1))
100 | G2 = F.adaptive_avg_pool2d(torch.sigmoid(G2), 1)
101 | D2 = self.output4(F.interpolate(D3, size=E2.size()[2:], mode='bilinear') + G2[:, 0, :, :].unsqueeze(1).repeat(1, 64, 1, 1) * T2)
102 |
103 | G1 = self.attention_feature5(torch.cat((E1, F.interpolate(D2, size=E1.size()[2:], mode='bilinear')), 1))
104 | G1 = F.adaptive_avg_pool2d(torch.sigmoid(G1), 1)
105 | D1 = self.output5(F.interpolate(D2, size=E1.size()[2:], mode='bilinear') + G1[:, 0, :, :].unsqueeze(1).repeat(1, 32, 1, 1) * T1)
106 |
107 | ################################Gated Parallel&Dual branch residual fuse#######################################
108 | R5 = self.dem1_1(T5)
109 | R4 = self.dem2_1(T4)
110 | R3 = self.dem3_1(T3)
111 | R2 = self.dem4_1(T2)
112 | R1 = self.dem5_1(T1)
113 | # The second gate channel scales each parallel-branch feature; all five
114 | # are then resized to E1's resolution and fused with the FPN prediction D1.
115 | gated = [F.interpolate(G[:, 1, :, :].unsqueeze(1).repeat(1, 32, 1, 1) * R, size=E1.size()[2:], mode='bilinear')
116 | for G, R in ((G5, R5), (G4, R4), (G3, R3), (G2, R2), (G1, R1))]
117 | output_res = self.out_res(torch.cat([D1] + gated, 1))
118 | output_res = F.interpolate(output_res, size=input.size()[2:], mode='bilinear')
119 | output_fpn = F.interpolate(D1, size=input.size()[2:], mode='bilinear')
120 | pre_sal = output_fpn + output_res
121 | #######################################################################
122 | return output_fpn, pre_sal
123 |
124 | if __name__ == "__main__":
125 | model = GateNet()
126 | input = torch.randn(4, 1, 384, 384)  # feats[0] was rebuilt above for 1-channel input
127 | output_fpn, pre_sal = model(input)
128 | print(output_fpn.shape, pre_sal.shape)
129 |
--------------------------------------------------------------------------------
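GateNet's gating step, isolated (a sketch with made-up shapes): each attention_feature* head emits a 2-channel map that sigmoid plus global average pooling squashes to two per-image scalars; channel 0 gates the FPN path and channel 1 the parallel residual path. Broadcasting makes the explicit repeat(1, C, 1, 1) in the forward pass unnecessary:

    import torch
    import torch.nn.functional as F

    G = torch.randn(4, 2, 24, 24)                   # raw 2-channel gate logits
    g = F.adaptive_avg_pool2d(torch.sigmoid(G), 1)  # (4, 2, 1, 1)
    T = torch.randn(4, 512, 24, 24)                 # transition feature to gate
    fpn_part = g[:, 0].unsqueeze(1) * T             # broadcast instead of repeat
    print(fpn_part.shape)                           # torch.Size([4, 512, 24, 24])

--------------------------------------------------------------------------------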
/trian_CGR_XS.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | # from losses import *
3 | from data.dataloader import XSDataset, XSDatatest
4 | import torch.nn as nn
5 | import torch
6 |
7 | from models.unet import UNet
8 |
9 | from CGRmodes.CGR import CGRNet
10 | # from models.vggunet import VGGUNet
11 | # from utils.metric import *
12 | from torchvision.transforms import transforms
13 | from evaluation import *
14 | # from utils.metric import *
15 | import torch.nn.functional as F
16 | # from models.newnet import FastSal
17 | import tqdm
18 |
19 | def iou_loss(pred, mask):
20 | weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask)
21 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')  # the old reduce='none' string was truthy and silently forced mean reduction
22 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
23 |
24 | pred = torch.sigmoid(pred)
25 | inter = ((pred*mask)*weit).sum(dim=(2,3))
26 | union = ((pred+mask)*weit).sum(dim=(2,3))
27 | wiou = 1-(inter+1)/(union-inter+1)
28 | loss_total = (wbce + wiou).mean()  # mean() already averages over the batch; no extra /batch division needed
29 | return loss_total
30 |
31 |
32 | class FocalLoss(nn.Module):
33 | def __init__(self, alpha=0.3, gamma=2, logits=True, reduce=True):
34 | super(FocalLoss, self).__init__()
35 | self.alpha = alpha
36 | self.gamma = gamma
37 | self.logits = logits
38 | self.reduce = reduce
39 |
40 | def forward(self, inputs, targets):
41 | if self.logits:
42 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
43 | else:
44 | BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
45 | pt = torch.exp(-BCE_loss)
46 | F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
47 |
48 | if self.reduce:
49 | return torch.mean(F_loss)
50 | else:
51 | return F_loss
52 |
53 |
54 |
55 | def test(testLoader,fold, net, device):
56 | net.to(device)
57 | sig = torch.nn.Sigmoid()
58 | print('start test!')
59 | net.eval()
60 | with torch.no_grad():
61 | # when in test stage, no grad
62 | acc = 0. # Accuracy
63 | SE = 0. # Sensitivity (Recall)
64 | SP = 0. # Specificity
65 | PC = 0. # Precision
66 | F1 = 0. # F1 Score
67 | JS = 0. # Jaccard Similarity
68 | DC = 0. # Dice Coefficient
69 | count = 0
70 | for image, label, path in tqdm.tqdm(testLoader):
71 | image = image.to(device=device, dtype=torch.float32)
72 | label = label.to(device=device, dtype=torch.float32)
73 | # pred,p1,p2,p3,p4,e= net(image)
74 | e1,p2= net(image)
75 | pred = sig(p2)
76 | # print(pred.shape)
77 | acc += get_accuracy(pred,label)
78 | SE += get_sensitivity(pred,label)
79 | SP += get_specificity(pred,label)
80 | PC += get_precision(pred,label)
81 | F1 += get_F1(pred,label)
82 | JS += get_JS(pred,label)
83 | DC += get_DC(pred,label)
84 | count+=1
85 | acc = acc/count
86 | SE = SE/count
87 | SP = SP/count
88 | PC = PC/count
89 | F1 = F1/count
90 | JS = JS/count
91 | DC = DC/count
92 | score = JS + DC
93 | print("\tacc: {:.4f}\tSE: {:.4f}\tSP: {:.4f}\tPC: {:.4f}\tF1: {:.4f} \tJS: {:.4f}".format(acc, SE, SP, PC, F1, JS))
94 | return acc, SE, SP, PC, F1, JS, DC, score
95 |
96 |
97 |
98 |
99 | def train_net(net, device, data_path,test_data_path, fold, epochs=40, batch_size=4, lr=0.00001):
100 | # Load the training set
101 | isbi_dataset = XSDataset(data_path)
102 | train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
103 | batch_size=batch_size,
104 | shuffle=True,
105 | drop_last=True)
106 |
107 | test_dataset = XSDatatest(test_data_path)
108 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
109 | batch_size=1,
110 | shuffle=False)
111 | # Define the RMSprop optimizer
112 | optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
113 | # optimizer = optim.Adam(net.parameters(), lr=lr)
114 |
115 | # Define the loss functions
116 |
117 | # criterion2 = nn.BCEWithLogitsLoss()
118 | # criterion3 = nn.BCELoss()
119 | criterion2 = FocalLoss()
120 |
121 | # Track the best loss, initialised to +infinity
122 | best_loss = float('inf')
123 | result = 0
124 | # Train for the requested number of epochs
125 | # f = open('./segmentation2/UNet.csv', 'w')
126 | # f.write('epoch,loss'+'\n')
127 | for epoch in range(epochs):
128 | # Switch to training mode
129 | net.train()
130 | # Iterate over the training batches
131 | for image, label, edge in train_loader:
132 |
133 | optimizer.zero_grad()
134 | # Copy the batch to the target device
135 | image = image.to(device=device, dtype=torch.float32)
136 | label = label.to(device=device, dtype=torch.float32)
137 | edge = edge.to(device=device, dtype=torch.float32)
138 | # Forward pass
139 | # pred, p1,p2,p3,p4,e= net(image)
140 | # Compute the loss
141 | e1,p2= net(image)
142 | loss = iou_loss(p2, label)+ criterion2(e1, edge)
143 |
144 | print('Train Epoch:{}'.format(epoch))
145 | print('Loss/train', loss.item())
146 | # print('Loss/edge', loss1.item())
147 |
148 | # Keep track of the smallest training loss seen so far
149 | if loss < best_loss:
150 | best_loss = loss
151 | # torch.save(net.state_dict(), './LUNG/fff'+str(fold)+'.pth')
152 | # Back-propagate and update the parameters
153 | loss.backward()
154 | optimizer.step()
155 | # f.write(str(epoch)+","+str(best_loss.item())+"\n")
156 | if epoch>0:
157 | acc, SE, SP, PC, F1, JS, DC, score=test(test_loader,fold, net, device)
158 | if result < score:
159 | result = score
160 | # best_epoch = epoch
161 | torch.save(net.state_dict(), '/home/wangkun/BPGL/result/EPGNet/XS/EPNet_best_'+str(fold)+'.pth')
162 | with open("/home/wangkun/BPGL/result/EPGNet/XS/EPNet_"+str(fold)+".csv", "a") as w:
163 | w.write("epoch="+str(epoch)+",acc="+str(acc)+", SE="+str(SE)+",SP="+str(SP)+",PC="+str(PC)+",F1="+str(F1)+",JS="+str(JS)+",DC="+str(DC)+",Score="+str(score)+"\n")
164 |
165 |
166 | if __name__ == "__main__":
167 | import os
168 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
169 | def weights_init(m):
170 | classname = m.__class__.__name__
171 | # print(classname)
172 | if classname.find('Conv') != -1:
173 | torch.nn.init.xavier_uniform_(m.weight.data)
174 | if m.bias is not None:
175 | torch.nn.init.constant_(m.bias.data, 0.0)
176 | seed=1234
177 | torch.manual_seed(seed)
178 | torch.cuda.manual_seed_all(seed)
179 |
180 | # Use CUDA if available, otherwise fall back to the CPU
181 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
182 | # Build the network (CGRNet below takes 3-channel images, 1 output class)
183 | # net = UNet(n_channels=1, n_classes=1)
184 | # net = MyMGNet(n_channels=1, n_classes=1)
185 | net = CGRNet(n_channels=3, n_classes=1)
186 | # net = VGG19UNet_without_boudary(n_channels=1, n_classes=1)
187 | # net = R2U_Net(img_ch=1, output_ch=1)
188 | # net = CE_Net(num_classes=1, num_channels=1)
189 | # net = VGG19UNet(n_channels=1, n_classes=1)
190 | # net = AttU_Net(img_ch=1, output_ch=1)
191 | # net =get_fcn8s(n_class=1)
192 | # net = UNet_2Plus(in_channels=1, n_classes=1)
193 | # net = DenseUnet(in_ch=1, num_classes=1)
194 | # net = CPFNet()
195 | net.to(device=device)
196 | # Set the dataset paths and start training
197 | fold=1
198 | # data_path = "/home/wangkun/shape-attentive-unet/data/train_96"
199 | # data_path = "/home/wangkun/shape-attentive-unet/data/ISIC/train"
200 | # test_data_path = "/home/wangkun/shape-attentive-unet/data/ISIC/test"
201 | # data_path = "/home/wangkun/shape-attentive-unet/data/LUNG/train"
202 | # test_data_path = "/home/wangkun/shape-attentive-unet/data/LUNG/test"
203 | data_path = "/home/wangkun/data/dataset/XS/Train/"
204 | test_data_path = "/home/wangkun/data/dataset/XS/val/"
205 |
206 |
207 |
208 |
209 | train_net(net, device, data_path,test_data_path, fold)
210 |
211 |
212 |
213 |
214 |
215 | #by kun wang
--------------------------------------------------------------------------------
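A toy sanity check for the weighted BCE + IoU loss defined in this script (a sketch; it assumes the iou_loss above with reduction='none'): a confidently correct logit map should score near zero, its sign-flipped counterpart far higher.

    import torch

    mask = torch.zeros(2, 1, 64, 64)
    mask[:, :, 16:48, 16:48] = 1.0
    good_logits = (mask * 2 - 1) * 8.0            # right sign, high confidence
    print(iou_loss(good_logits, mask).item())     # close to 0
    print(iou_loss(-good_logits, mask).item())    # much larger

--------------------------------------------------------------------------------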