├── __init__.py
├── models
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── resnext.py
│   │   ├── resnextdcn.py
│   │   ├── res2next.py
│   │   ├── res2netv1b.py
│   │   ├── res2net.py
│   │   └── resnetv1b.py
│   ├── base_ops
│   │   ├── __init__.py
│   │   ├── CoordConv.py
│   │   └── DCNv2.py
│   ├── necks
│   │   ├── __init__.py
│   │   ├── jpfm.py
│   │   └── fpn.py
│   ├── model_store.py
│   ├── download.py
│   ├── segbase.py
│   └── PointNuNet.py
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── vcs.xml
│   ├── modules.xml
│   └── PointNu-Net.iml
├── .gitmodules
├── utils
│   ├── __init__.py
│   ├── io.py
│   ├── imsave.py
│   ├── matrix_nms.py
│   ├── imop.py
│   ├── fmix.py
│   ├── _aug.py
│   └── dataloader.py
├── losses
│   ├── __init__.py
│   ├── bce_loss.py
│   ├── focal_loss.py
│   ├── generalise_focal_loss.py
│   ├── lovasz_losses.py
│   └── dice_loss.py
├── requirements.txt
├── eval_pannuke.py
├── LICENSE
├── configs
│   ├── pannuke.yaml
│   ├── pannuke_resnext101dcn.yaml
│   ├── monuseg.yaml
│   ├── cpm17_notype_large.yaml
│   ├── consep_notype_large.yaml
│   ├── consep_notype_middle.yaml
│   ├── consep_notype_tiny.yaml
│   ├── consep_type_large.yaml
│   ├── consep_type_middle.yaml
│   ├── consep_type_tiny.yaml
│   ├── kumar_swin.yaml
│   ├── kumar_resnet.yaml
│   ├── kumar_notype_large.yaml
│   └── kumar_resnext101dcn.yaml
├── train.py
├── train_pannuke.py
├── infer_pannuke.py
├── README.md
└── inference.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/backbone/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/base_ops/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/necks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "PanNuKe-metrics"]
2 | path = PanNuKe-metrics
3 | url = https://github.com/TIA-Lab/PanNuke-metrics
4 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .io import prepare_sub_folder,get_config
2 | from .imsave import collate_func,_imageshow,_imagesave
3 | from .imop import gaussian_2d_kernel,insert_image,label_relaxation
4 |
5 |
6 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .focal_loss import BinaryFocalLoss,FocalLoss_Ori,BCELoss
2 | from .dice_loss import BinaryDiceLoss
3 | from .bce_loss import BCEWithLogitsLossWithOHEM
4 | from .lovasz_losses import lovasz_hinge
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==1.1.0
2 | imgaug==0.4.0
3 | matplotlib==3.4.3
4 | mmcv==2.0.0
5 | mmcv_full==1.4.8
6 | numpy==1.22.4+mkl
7 | opencv_python_headless==4.5.4.60
8 | Pillow==9.5.0
9 | PyYAML==6.0
10 | requests==2.21.0
11 | scikit_image==0.18.3
12 | scipy==1.7.1
13 | skimage==0.0
14 | tifffile==2021.7.2
15 | timm==0.4.12
16 | torch==1.10.0+cu113
17 | torchsummaryX==1.3.0
18 | torchvision==0.11.0+cu113
19 | tqdm==4.62.3
20 |
--------------------------------------------------------------------------------
/.idea/PointNu-Net.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/eval_pannuke.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--name', type=str, default='pannuke_exp')
6 | parser.add_argument('--save_path', type=str, default='pannuke_outputs')
7 | parser.add_argument('--train_fold', type=int, default=1)
8 | parser.add_argument('--test_fold', type=int, default=3)
9 | opts = parser.parse_args()
10 |
11 | pred_path = rf'outputs/{opts.name}/train_{opts.train_fold}_to_test_{opts.test_fold}'
12 | true_path = rf'datasets/PanNuKe/masks/fold{opts.test_fold}'
13 | save_path = opts.save_path
14 | cmd=rf"python PanNuke-metrics/run.py --true_path={true_path} --pred_path={pred_path} --save_path={save_path}"
15 |
16 | os.system(cmd)
17 |
--------------------------------------------------------------------------------
/models/base_ops/CoordConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class CoordConv2d(nn.Conv2d):
7 |
8 | def __init__(self, in_chan, out_chan, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
9 | super(CoordConv2d, self).__init__(
10 | in_chan + 2, out_chan, kernel_size, stride=stride,
11 | padding=padding, dilation=dilation, groups=groups, bias=bias)
12 |
13 | def forward(self, x):
14 | batchsize, H, W = x.size(0), x.size(2), x.size(3)
15 | h_range = torch.linspace(-1, 1, H, device=x.device, dtype=x.dtype)
16 | w_range = torch.linspace(-1, 1, W, device=x.device, dtype=x.dtype)
17 | h_chan, w_chan = torch.meshgrid(h_range, w_range)
18 | h_chan = h_chan.expand([batchsize, 1, -1, -1])
19 | w_chan = w_chan.expand([batchsize, 1, -1, -1])
20 |
21 | feat = torch.cat([h_chan, w_chan, x], dim=1)
22 |
23 | return F.conv2d(feat, self.weight, self.bias,
24 | self.stride, self.padding, self.dilation, self.groups)
25 |
--------------------------------------------------------------------------------
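A minimal usage sketch of CoordConv2d (import path and shapes assumed from this repo layout; not part of the original file): the layer concatenates normalised row/column coordinate channels to the input before convolving, so a module built for in_chan inputs internally convolves in_chan + 2 channels.

    import torch
    from models.base_ops.CoordConv import CoordConv2d

    conv = CoordConv2d(in_chan=64, out_chan=128, kernel_size=3, stride=1, padding=1)
    x = torch.randn(2, 64, 32, 32)    # (batch, channels, H, W)
    y = conv(x)                       # coordinate channels are appended internally
    print(y.shape)                    # torch.Size([2, 128, 32, 32])
    print(conv.weight.shape)          # torch.Size([128, 66, 3, 3]): 64 inputs + 2 coord channels
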
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Kaiseem
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/necks/jpfm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class JPFM(nn.Module):
5 | def __init__(self,in_channel,width=256):
6 | super(JPFM, self).__init__()
7 |
8 | self.out_channel = width*4
9 | self.dilation1 = nn.Sequential(
10 | nn.Conv2d(in_channel, width, 3, padding=1, dilation=1, bias=False),
11 | nn.BatchNorm2d(width),
12 | nn.ReLU(True))
13 | self.dilation2 = nn.Sequential(
14 | nn.Conv2d(in_channel, width, 3, padding=2, dilation=2, bias=False),
15 | nn.BatchNorm2d(width),
16 | nn.ReLU(True))
17 | self.dilation3 = nn.Sequential(
18 | nn.Conv2d(in_channel, width, 3, padding=4, dilation=4, bias=False),
19 | nn.BatchNorm2d(width),
20 | nn.ReLU(True))
21 | self.dilation4 = nn.Sequential(
22 | nn.Conv2d(in_channel, width, 3, padding=8, dilation=8, bias=False),
23 | nn.BatchNorm2d(width),
24 | nn.ReLU(True))
25 |
26 | def forward(self,feat):
27 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], dim=1)
28 | return feat
29 |
30 |
--------------------------------------------------------------------------------
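A short sketch of the JPFM neck above (the 720-channel input is only an illustrative guess, e.g. concatenated HRNet feature maps): the four dilated 3x3 branches keep the spatial size and are concatenated, so out_channel is always width * 4.

    import torch
    from models.necks.jpfm import JPFM

    neck = JPFM(in_channel=720, width=256)
    feat = torch.randn(1, 720, 64, 64)
    out = neck(feat)                     # four parallel dilated branches, concatenated on channels
    print(out.shape, neck.out_channel)   # torch.Size([1, 1024, 64, 64]) 1024
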
/utils/io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | IMG_EXTENSIONS = [
4 | '.jpg', '.JPG', '.jpeg', '.JPEG',
5 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
6 |     '.tif', '.TIF', '.tiff', '.TIFF', '.npy', '.mat'
7 | ]
8 |
9 | def prepare_sub_folder(output_directory):
10 | image_directory = os.path.join(output_directory, 'images')
11 | if not os.path.exists(image_directory):
12 | print("Creating directory: {}".format(image_directory))
13 | os.makedirs(image_directory)
14 | checkpoint_directory = os.path.join(output_directory, 'checkpoints')
15 | if not os.path.exists(checkpoint_directory):
16 | print("Creating directory: {}".format(checkpoint_directory))
17 | os.makedirs(checkpoint_directory)
18 | return checkpoint_directory, image_directory
19 |
20 | def get_config(config):
21 | with open(config,'r') as stream:
22 | return yaml.load(stream, Loader=yaml.FullLoader)
23 |
24 | def is_image_file(filename):
25 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
26 |
27 | def make_dataset(dir, max_dataset_size=float("inf")):
28 | images = []
29 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
30 | for root, _, fnames in sorted(os.walk(dir)):
31 | for fname in fnames:
32 | if is_image_file(fname):
33 | path = os.path.join(root, fname)
34 | images.append(path)
35 | return images[:min(max_dataset_size, len(images))]
36 |
--------------------------------------------------------------------------------
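A hedged usage sketch of the helpers above (the paths are illustrative and must exist for make_dataset to find anything):

    from utils.io import get_config, prepare_sub_folder, make_dataset

    config = get_config('configs/pannuke.yaml')               # nested dict parsed from the YAML file
    ckpt_dir, img_dir = prepare_sub_folder('outputs/demo')    # creates outputs/demo/checkpoints and outputs/demo/images
    files = make_dataset('datasets/PanNuKe', max_dataset_size=10)  # up to 10 image-like files, found recursively
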
/configs/pannuke.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet64
16 | pretrain: true
17 | frozen_stages: -1
18 | norm_eval: false
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | num_classes: 6
23 | kernel_size: 1
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 8 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 | dataroot: ./datasets/PanNuKe
42 |
--------------------------------------------------------------------------------
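For orientation, the YAML above is read by utils/io.py:get_config into a plain nested dictionary, so the training scripts index it exactly by the section and key names shown; a small sketch (the echoed values come from this file):

    from utils import get_config

    config = get_config('configs/pannuke.yaml')
    print(config['model']['backbone'])     # hrnet64
    print(config['train']['batch_size'])   # 8
    print(config['train']['lr'])           # 0.0001
    print(config['dataroot'])              # ./datasets/PanNuKe
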
/configs/pannuke_resnext101dcn.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: resnext101dcn
16 | pretrain: true
17 | frozen_stages: -1
18 | norm_eval: false
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | num_classes: 6
23 | kernel_size: 1
24 | output_stride: 4
25 |
26 |
27 | train:
28 | max_epoch: 100 # maximum number of training epochs
29 | batch_size: 8 # batch size
30 | num_workers: 4 #
31 | optim: adamw
32 | lr: 0.0001 # initial learning rate 5e-4
33 | weight_decay: 0.0001 # weight decay 1e-4
34 | beta1: 0.9 # Adam parameter
35 | beta2: 0.999 # Adam parameter
36 | lr_policy: multistep # learning rate scheduler
37 | gamma: 0.1 # how much to decay learning rate
38 | use_mixed: false
39 | lambda_ins: 1 # weight of image instance segmentation loss
40 | lambda_cate: 1 # weight of image category classification loss
41 |
42 | dataroot: ./datasets/PanNuKe
--------------------------------------------------------------------------------
/configs/monuseg.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet64
16 | pretrain: true
17 | frozen_stages: -1
18 | norm_eval: false
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | num_classes: 2
23 | kernel_size: 1
24 | output_stride: 4
25 |
26 |
27 | train:
28 | max_epoch: 100 # maximum number of training epochs
29 | batch_size: 8 # batch size
30 | num_workers: 4 #
31 | optim: adamw
32 | lr: 0.0005 # initial learning rate 5e-4
33 | weight_decay: 0.0001 # weight decay 1e-4
34 | beta1: 0.9 # Adam parameter
35 | beta2: 0.999 # Adam parameter
36 | lr_policy: multistep # learning rate scheduler
37 | gamma: 0.1 # how much to decay learning rate
38 | use_mixed: false
39 | lambda_ins: 1 # weight of image instance segmentation loss
40 | lambda_cate: 1 # weight of image category classification loss
41 |
42 |
43 | dataroot: ./datasets/monuseg
44 | stainnorm: normed
--------------------------------------------------------------------------------
/configs/cpm17_notype_large.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet64
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | kernel_size: 1
23 | num_classes: 2
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 8 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 |
42 | dataroot: ./datasets/cpm17
43 | stainnorm: normed
44 | image_norm_mean:
45 | - 0.7949688
46 | - 0.55828372
47 | - 0.70746591
48 | image_norm_std:
49 | - 0.1847547
50 | - 0.22577311
51 | - 0.17084152
--------------------------------------------------------------------------------
/configs/consep_notype_large.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet64
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | kernel_size: 1
23 | num_classes: 2
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 8 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 |
42 | dataroot: ./datasets/CoNSeP
43 | stainnorm: null
44 | image_norm_mean:
45 | - 0.82899987
46 | - 0.70380247
47 | - 0.84849265
48 | image_norm_std:
49 | - 0.16051538
50 | - 0.19952672
51 | - 0.12816858
52 |
--------------------------------------------------------------------------------
/configs/consep_notype_middle.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet32
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 4
22 | kernel_size: 1
23 | num_classes: 2
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 8 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 |
42 | dataroot: ./datasets/CoNSeP
43 | stainnorm: null
44 | image_norm_mean:
45 | - 0.82899987
46 | - 0.70380247
47 | - 0.84849265
48 | image_norm_std:
49 | - 0.16051538
50 | - 0.19952672
51 | - 0.12816858
52 |
--------------------------------------------------------------------------------
/configs/consep_notype_tiny.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet18
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 128
20 | ins_out_channels: 128
21 | stacked_convs: 4
22 | kernel_size: 1
23 | num_classes: 2
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 16 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 |
42 | dataroot: ./datasets/CoNSeP
43 | stainnorm: null
44 | image_norm_mean:
45 | - 0.82899987
46 | - 0.70380247
47 | - 0.84849265
48 | image_norm_std:
49 | - 0.16051538
50 | - 0.19952672
51 | - 0.12816858
52 |
--------------------------------------------------------------------------------
/configs/consep_type_large.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet64
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | kernel_size: 1
23 | num_classes: 5
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 8 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 |
42 | dataroot: ./datasets/CoNSeP
43 | stainnorm: null
44 | image_norm_mean:
45 | - 0.82899987
46 | - 0.70380247
47 | - 0.84849265
48 | image_norm_std:
49 | - 0.16051538
50 | - 0.19952672
51 | - 0.12816858
52 |
--------------------------------------------------------------------------------
/configs/consep_type_middle.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet32
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 4
22 | kernel_size: 1
23 | num_classes: 5
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 8 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 |
42 | dataroot: ./datasets/CoNSeP
43 | stainnorm: null
44 | image_norm_mean:
45 | - 0.82899987
46 | - 0.70380247
47 | - 0.84849265
48 | image_norm_std:
49 | - 0.16051538
50 | - 0.19952672
51 | - 0.12816858
52 |
--------------------------------------------------------------------------------
/configs/consep_type_tiny.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet18
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 128
20 | ins_out_channels: 128
21 | stacked_convs: 4
22 | kernel_size: 1
23 | num_classes: 5
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 16 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 |
42 | dataroot: ./datasets/CoNSeP
43 | stainnorm: null
44 | image_norm_mean:
45 | - 0.82899987
46 | - 0.70380247
47 | - 0.84849265
48 | image_norm_std:
49 | - 0.16051538
50 | - 0.19952672
51 | - 0.12816858
52 |
--------------------------------------------------------------------------------
/configs/kumar_swin.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: swin
16 | pretrain: true
17 | frozen_stages: -1
18 | norm_eval: false
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | kernel_size: 1
23 | num_classes: 2
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 8 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 | dataroot: ./datasets/kumar
42 |
43 | stainnorm: normed
44 |
45 | image_norm_mean:
46 | - 0.71174131
47 | - 0.5287984
48 | - 0.63705888
49 | image_norm_std:
50 | - 0.17307392
51 | - 0.19106038
52 | - 0.14219015
53 |
54 |
--------------------------------------------------------------------------------
/configs/kumar_resnet.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: resnet101
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | kernel_size: 1
23 | num_classes: 2
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 8 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 | dataroot: ./datasets/kumar
42 |
43 | stainnorm: normed
44 |
45 | image_norm_mean:
46 | - 0.71174131
47 | - 0.5287984
48 | - 0.63705888
49 | image_norm_std:
50 | - 0.17307392
51 | - 0.19106038
52 | - 0.14219015
53 |
54 |
--------------------------------------------------------------------------------
/configs/kumar_notype_large.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: hrnet64
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | kernel_size: 1
23 | num_classes: 2
24 | output_stride: 4
25 |
26 | train:
27 | max_epoch: 100 # maximum number of training epochs
28 | batch_size: 8 # batch size
29 | num_workers: 4 #
30 | optim: adamw
31 | lr: 0.0001 # initial learning rate 5e-4
32 | weight_decay: 0.0001 # weight decay 1e-4
33 | beta1: 0.9 # Adam parameter
34 | beta2: 0.999 # Adam parameter
35 | lr_policy: multistep # learning rate scheduler
36 | gamma: 0.1 # how much to decay learning rate
37 | use_mixed: false
38 | lambda_ins: 1 # weight of image instance segmentation loss
39 | lambda_cate: 1 # weight of image category classification loss
40 |
41 | dataroot: ./datasets/kumar
42 |
43 | stainnorm: normed
44 |
45 | image_norm_mean:
46 | - 0.71174131
47 | - 0.5287984
48 | - 0.63705888
49 | image_norm_std:
50 | - 0.17307392
51 | - 0.19106038
52 | - 0.14219015
53 |
54 |
--------------------------------------------------------------------------------
/configs/kumar_resnext101dcn.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_epoch: 5 # How often do you want to save output images during training
6 | image_display_epoch: 5 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_epoch: 5 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | mask_rescoring: false
12 |
13 | # model options
14 | model:
15 | backbone: resnext101dcn
16 | pretrain: true
17 | frozen_stages: 1
18 | norm_eval: true
19 | seg_feat_channels: 256
20 | ins_out_channels: 256
21 | stacked_convs: 7
22 | kernel_size: 1
23 | num_classes: 2
24 | output_stride: 4
25 |
26 |
27 | train:
28 | max_epoch: 100 # maximum number of training epochs
29 | batch_size: 8 # batch size
30 | num_workers: 4 #
31 | optim: adamw
32 | lr: 0.0001 # initial learning rate 5e-4
33 | weight_decay: 0.0001 # weight decay 1e-4
34 | beta1: 0.9 # Adam parameter
35 | beta2: 0.999 # Adam parameter
36 | lr_policy: multistep # learning rate scheduler
37 | gamma: 0.1 # how much to decay learning rate
38 | use_mixed: false
39 | lambda_ins: 1 # weight of image instance segmentation loss
40 | lambda_cate: 1 # weight of image category classification loss
41 |
42 | dataroot: ./datasets/kumar
43 |
44 | stainnorm: normed
45 |
46 | image_norm_mean:
47 | - 0.71174131
48 | - 0.5287984
49 | - 0.63705888
50 | image_norm_std:
51 | - 0.17307392
52 | - 0.19106038
53 | - 0.14219015
54 |
55 |
--------------------------------------------------------------------------------
/losses/bce_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 |
7 |
8 |
9 | class BCEWithLogitsLossWithOHEM(nn.Module):
10 |
11 | def __init__(self, ohem_ratio=1.0, pos_weight=None, eps=1e-7):
12 | super(BCEWithLogitsLossWithOHEM, self).__init__()
13 | self.criterion = nn.BCEWithLogitsLoss(reduction='none',
14 | pos_weight=pos_weight)
15 | self.ohem_ratio = ohem_ratio
16 | self.eps = eps
17 |
18 | def forward(self, pred, target):
19 | loss = self.criterion(pred, target)
20 | mask = _ohem_mask(loss, self.ohem_ratio)
21 | loss = loss * mask
22 | return loss.sum() / (mask.sum() + self.eps)
23 |
24 | def set_ohem_ratio(self, ohem_ratio):
25 | self.ohem_ratio = ohem_ratio
26 |
27 | def _ohem_mask(loss, ohem_ratio):
28 | with torch.no_grad():
29 | values, _ = torch.topk(loss.reshape(-1),
30 | int(loss.nelement() * ohem_ratio))
31 | mask = loss >= values[-1]
32 | return mask.float()
33 | class CrossEntropyLossWithOHEM(nn.Module):
34 |
35 | def __init__(self, ohem_ratio=1.0, weight=None, ignore_index=-100,
36 | eps=1e-7):
37 | super(CrossEntropyLossWithOHEM, self).__init__()
38 | self.criterion = nn.CrossEntropyLoss(weight=weight,
39 | ignore_index=ignore_index,
40 | reduction='none')
41 | self.ohem_ratio = ohem_ratio
42 | self.eps = eps
43 |
44 | def forward(self, pred, target):
45 | loss = self.criterion(pred, target)
46 | mask = _ohem_mask(loss, self.ohem_ratio)
47 | loss = loss * mask
48 | return loss.sum() / (mask.sum() + self.eps)
49 |
50 | def set_ohem_ratio(self, ohem_ratio):
51 | self.ohem_ratio = ohem_ratio
52 |
--------------------------------------------------------------------------------
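A minimal sketch of the OHEM-weighted BCE loss above (random tensors, illustrative ratio): with ohem_ratio=1.0 it behaves like a plain mean-reduced BCEWithLogitsLoss, while smaller ratios average only over the hardest elements.

    import torch
    from losses.bce_loss import BCEWithLogitsLossWithOHEM

    criterion = BCEWithLogitsLossWithOHEM(ohem_ratio=0.25)  # keep the hardest 25% of elements
    logits = torch.randn(4, 1, 64, 64)
    target = torch.randint(0, 2, (4, 1, 64, 64)).float()
    loss = criterion(logits, target)                        # sum over kept elements / number kept
    print(loss.item())
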
/models/necks/fpn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | class FPN(nn.Module):
5 | def __init__(self,channels=[256,512,1024,2048]):
6 | super(FPN, self).__init__()
7 | self.in_planes = 64
8 | # Top layer
9 | self.toplayer = nn.Conv2d(channels[3], 256, kernel_size=1, stride=1, padding=0) # Reduce channels
10 |
11 | # Smooth layers
12 | self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
13 | self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
14 | self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
15 |
16 | # Lateral layers
17 | self.latlayer1 = nn.Conv2d(channels[2], 256, kernel_size=1, stride=1, padding=0)
18 | self.latlayer2 = nn.Conv2d(channels[1], 256, kernel_size=1, stride=1, padding=0)
19 | self.latlayer3 = nn.Conv2d(channels[0], 256, kernel_size=1, stride=1, padding=0)
20 |
21 | def _upsample_add(self, x, y):
22 | '''Upsample and add two feature maps.
23 | Args:
24 | x: (Variable) top feature map to be upsampled.
25 | y: (Variable) lateral feature map.
26 | Returns:
27 | (Variable) added feature map.
28 | Note in PyTorch, when input size is odd, the upsampled feature map
29 | with `F.upsample(..., scale_factor=2, mode='nearest')`
30 | may not be equal to the lateral feature map size.
31 | e.g.
32 | original input size: [N,_,15,15] ->
33 | conv2d feature map size: [N,_,8,8] ->
34 | upsampled feature map size: [N,_,16,16]
35 | So we choose bilinear upsample which supports arbitrary output sizes.
36 | '''
37 | _, _, H, W = y.size()
38 | return F.interpolate(x, size=(H, W), mode='bilinear',align_corners=True) + y
39 |
40 | def forward(self, c2, c3, c4, c5):
41 | # Top-down
42 | p5 = self.toplayer(c5)
43 | p4 = self._upsample_add(p5, self.latlayer1(c4))
44 | p3 = self._upsample_add(p4, self.latlayer2(c3))
45 | p2 = self._upsample_add(p3, self.latlayer3(c2))
46 | # Smooth
47 | p4 = self.smooth1(p4)
48 | p3 = self.smooth2(p3)
49 | p2 = self.smooth3(p2)
50 | return p2, p3, p4, p5
--------------------------------------------------------------------------------
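A usage sketch of the FPN neck with ResNet-style channel widths (dummy tensors; shapes are illustrative): every pyramid level is reduced to 256 channels and p2-p4 are smoothed after the top-down additions.

    import torch
    from models.necks.fpn import FPN

    fpn = FPN(channels=[256, 512, 1024, 2048])      # c2..c5 widths of a ResNet-50/101
    c2 = torch.randn(1, 256, 64, 64)
    c3 = torch.randn(1, 512, 32, 32)
    c4 = torch.randn(1, 1024, 16, 16)
    c5 = torch.randn(1, 2048, 8, 8)
    p2, p3, p4, p5 = fpn(c2, c3, c4, c5)
    print([p.shape[1] for p in (p2, p3, p4, p5)])   # [256, 256, 256, 256]
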
/models/model_store.py:
--------------------------------------------------------------------------------
1 | """Model store which provides pretrained models."""
2 | from __future__ import print_function
3 |
4 | import os
5 | import zipfile
6 |
7 | from .download import download, check_sha1
8 |
9 | __all__ = ['get_model_file', 'get_resnet_file']
10 |
11 | _model_sha1 = {name: checksum for checksum, name in [
12 | ('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'),
13 | ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
14 | ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
15 | ]}
16 |
17 | encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
18 | _url_format = '{repo_url}encoding/models/{file_name}.zip'
19 |
20 |
21 | def short_hash(name):
22 | if name not in _model_sha1:
23 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
24 | return _model_sha1[name][:8]
25 |
26 |
27 | def get_resnet_file(name, root='~/.torch/models'):
28 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
29 | root = os.path.expanduser(root)
30 |
31 | file_path = os.path.join(root, file_name + '.pth')
32 | sha1_hash = _model_sha1[name]
33 | if os.path.exists(file_path):
34 | if check_sha1(file_path, sha1_hash):
35 | return file_path
36 | else:
37 | print('Mismatch in the content of model file {} detected.'
38 | ' Downloading again.'.format(file_path))
39 | else:
40 | print('Model file {} is not found. Downloading.'.format(file_path))
41 |
42 | if not os.path.exists(root):
43 | os.makedirs(root)
44 |
45 | zip_file_path = os.path.join(root, file_name + '.zip')
46 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url)
47 | if repo_url[-1] != '/':
48 | repo_url = repo_url + '/'
49 | download(_url_format.format(repo_url=repo_url, file_name=file_name),
50 | path=zip_file_path,
51 | overwrite=True)
52 | with zipfile.ZipFile(zip_file_path) as zf:
53 | zf.extractall(root)
54 | os.remove(zip_file_path)
55 |
56 | if check_sha1(file_path, sha1_hash):
57 | return file_path
58 | else:
59 | raise ValueError('Downloaded file has different hash. Please try again.')
60 |
61 |
62 | def get_model_file(name, root='~/.torch/models'):
63 | root = os.path.expanduser(root)
64 | file_path = os.path.join(root, name + '.pth')
65 | if os.path.exists(file_path):
66 | return file_path
67 | else:
68 | raise ValueError('Model file is not found. Please download or train it first.')
--------------------------------------------------------------------------------
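Hedged usage of get_resnet_file above (requires network access to the upstream S3 bucket): the weights are fetched as a zip, extracted under ~/.torch/models and verified against the recorded SHA-1 before the .pth path is returned.

    from models.model_store import get_resnet_file

    # downloads (or reuses) resnet101-2a57e44d.pth under ~/.torch/models
    weights_path = get_resnet_file('resnet101')
    print(weights_path)
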
/train.py:
--------------------------------------------------------------------------------
1 | import os,shutil
2 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
3 |
4 | from utils.dataloader import NucleiDataset
5 | from trainer import Trainer
6 | from torch.utils.data import DataLoader
7 | import sys
8 | import torch.nn as nn
9 | from utils import prepare_sub_folder,get_config,collate_func
10 | import torch
11 | import numpy as np
12 | import math
13 | import argparse
14 | import random
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--config', type=str, default='configs/kumar_notype_large.yaml')
17 | parser.add_argument('--name', type=str, default='tmp')
18 | parser.add_argument('--output_dir', type=str, default='outputs')
19 | parser.add_argument('--seed', type=int, default=10)
20 | opts = parser.parse_args()
21 |
22 | def check_manual_seed(seed):
23 | """ If manual seed is not specified, choose a
24 | random one and communicate it to the user.
25 | Args:
26 | seed: seed to check
27 | """
28 | seed = seed or random.randint(1, 10000)
29 | random.seed(seed)
30 | np.random.seed(seed)
31 | torch.manual_seed(seed)
32 | torch.cuda.manual_seed(seed)
33 | # ia.random.seed(seed)
34 |
35 | print("Using manual seed: {seed}".format(seed=seed))
36 | return
37 |
38 | if __name__ == '__main__':
39 | config=get_config(opts.config)
40 | train_dataset=NucleiDataset(config,opts.seed,is_train=True)
41 | check_manual_seed(opts.seed)
42 | train_loader=DataLoader(dataset=train_dataset, batch_size=config['train']['batch_size'], shuffle=True, drop_last=True, num_workers=config['train']['num_workers'],collate_fn=collate_func,pin_memory=True)
43 |
44 | output_directory = os.path.join(opts.output_dir, opts.name)
45 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
46 | shutil.copy(opts.config,os.path.join(output_directory,'config.yaml'))
47 |
48 | trainer = Trainer(config)
49 | trainer.cuda()
50 |
51 | iteration=0
52 | iter_per_epoch=len(train_loader)
53 |
54 | for epoch in range(config['train']['max_epoch']):
55 | for train_data in train_loader:
56 | for k in train_data.keys():
57 | if not isinstance(train_data[k], list):
58 | train_data[k] = train_data[k].cuda().detach()
59 | else:
60 | train_data[k] = [s.cuda().detach() if s is not None else s for s in train_data[k]]
61 |
62 | ins_loss, cate_loss, maskiou_loss = trainer.seg_updata_FMIX(train_data)
63 |
64 | sys.stdout.write(
65 | f'\r epoch:{epoch} step:{iteration}/{iter_per_epoch} ins_loss: {ins_loss} cate_loss: {cate_loss} maskiou_loss: {maskiou_loss}')
66 | iteration += 1
67 | if (epoch + 1) % 20 == 0:
68 | trainer.save(checkpoint_directory, epoch)
69 | trainer.scheduler.step()
70 |
71 | trainer.save(checkpoint_directory, epoch)
72 |
73 |
--------------------------------------------------------------------------------
/train_pannuke.py:
--------------------------------------------------------------------------------
1 | from utils.dataloader import NucleiDataset,PannukeDataset
2 | from trainer import Trainer
3 | from torch.utils.data import DataLoader
4 | import sys
5 | from utils import prepare_sub_folder,get_config,collate_func
6 | import torch
7 | import numpy as np
8 | import os,shutil
9 | import argparse
10 | import random
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--config', type=str, default='configs/pannuke.yaml')
14 | parser.add_argument('--name', type=str, default='pannuke_experiment')
15 | parser.add_argument('--train_fold', type=int, default=2)
16 | parser.add_argument('--val_fold', type=int, default=1)
17 | parser.add_argument('--test_fold', type=int, default=3)
18 | parser.add_argument('--output_dir', type=str, default='outputs')
19 | parser.add_argument('--seed', type=int, default=10)
20 | opts = parser.parse_args()
21 |
22 | def check_manual_seed(seed):
23 | """ If manual seed is not specified, choose a
24 | random one and communicate it to the user.
25 | Args:
26 | seed: seed to check
27 | """
28 | seed = seed or random.randint(1, 10000)
29 | random.seed(seed)
30 | np.random.seed(seed)
31 | torch.manual_seed(seed)
32 | torch.cuda.manual_seed(seed)
33 | # ia.random.seed(seed)
34 |
35 | print("Using manual seed: {seed}".format(seed=seed))
36 | return
37 |
38 | if __name__ == '__main__':
39 | config=get_config(opts.config)
40 | check_manual_seed(opts.seed)
41 | train_dataset=PannukeDataset(data_root=config['dataroot'], seed=opts.seed, is_train=True, fold=opts.train_fold,output_stride=config['model']['output_stride'])
42 | train_loader=DataLoader(dataset=train_dataset, batch_size=config['train']['batch_size'], shuffle=True, drop_last=True, num_workers=config['train']['num_workers'],persistent_workers=True,collate_fn=collate_func,pin_memory=True)
43 |
44 | output_directory = os.path.join(opts.output_dir, opts.name, 'train_{}_to_test_{}'.format( opts.train_fold,opts.test_fold))
45 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
46 | shutil.copy(opts.config,os.path.join(output_directory,'config.yaml'))
47 |
48 | trainer = Trainer(config)
49 | trainer.cuda()
50 |
51 | iteration=0
52 | iter_per_epoch=len(train_loader)
53 | for epoch in range(config['train']['max_epoch']):
54 | for train_data in train_loader:
55 | for k in train_data.keys():
56 | if not isinstance(train_data[k],list):
57 | train_data[k]=train_data[k].cuda().detach()
58 | else:
59 | train_data[k] = [s.cuda().detach() if s is not None else s for s in train_data[k]]
60 | ins_loss, cate_loss,maskiou_loss=trainer.seg_updata(train_data)
61 | sys.stdout.write(f'\r epoch:{epoch} step:{iteration}/{iter_per_epoch} ins_loss: {ins_loss} cate_loss: {cate_loss} maskiou_loss: {maskiou_loss}')
62 | iteration+=1
63 | trainer.scheduler.step()
64 |
65 | if (epoch+1)%50==0:
66 | trainer.save(checkpoint_directory, epoch)
67 |
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/infer_pannuke.py:
--------------------------------------------------------------------------------
1 | from trainer import Trainer
2 | try:
3 | from itertools import izip as zip
4 | except ImportError: # will be 3.x series
5 | pass
6 | import scipy.io as scio
7 | from utils.dataloader import NucleiDataset,PannukeDataset
8 | import torch
9 | import os
10 | from torch.utils.data import DataLoader
11 | import argparse
12 | import numpy as np
13 | from utils import get_config,collate_func
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--output_dir', type=str, default='outputs')
16 | parser.add_argument('--name', type=str, default='pannuke227')
17 | parser.add_argument('--train_fold', type=int, default=2)
18 | parser.add_argument('--test_fold', type=int, default=3)
19 | parser.add_argument('--epoch',type=int,default=100)
20 | opts = parser.parse_args()
21 |
22 | def stack_prediction(seg_masks,cate_labels):
23 | out_seg=np.zeros((256,256,6))
24 | idx_num=1
25 | for mask,label in zip(seg_masks,cate_labels):
26 | assert label!=5
27 | out_seg[:,:,label]=np.maximum(out_seg[:,:,label],mask*idx_num)
28 | idx_num+=1
29 | out_seg[:,:,5]=np.sum(out_seg[:,:,:5],axis=-1)==0
30 | return out_seg
31 |
32 | if __name__ == '__main__':
33 | opts.config=os.path.join(opts.output_dir,'{}'.format(opts.name),'train_{}_to_test_{}/config.yaml'.format( opts.train_fold,opts.test_fold))
34 | config=get_config(opts.config)
35 |
36 | #train_dataset=NucleiDataset(data_root=config['dataroot'],is_train=True,stain_norm=stain_norm_type)
37 | test_dataset=PannukeDataset(data_root=config['dataroot'], is_train=False, fold=opts.test_fold,output_stride=config['model']['output_stride'])
38 | test_loader=DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=0,collate_fn=collate_func,pin_memory=True)
39 | config['model']['kernel_size']=1
40 | config['train']['use_mixed']=False
41 | trainer = Trainer(config)
42 | trainer.cuda()
43 |
44 | state_path = os.path.join(opts.output_dir,opts.name,'train_{}_to_test_{}'.format( opts.train_fold,opts.test_fold),'checkpoints/model_{}.pt'.format('%04d' % (opts.epoch)))
45 | #state_path = os.path.join(opts.output_dir,opts.name,'train_{}_to_test_{}'.format( opts.train_fold,opts.test_fold),'checkpoints/model_ema.pt')
46 | state_dict = torch.load(state_path)
47 |
48 | trainer.model.load_state_dict(state_dict['seg'])
49 | predictions=[]
50 | for test_data in test_loader:
51 | for k in test_data.keys():
52 | if not isinstance(test_data[k], list):
53 | test_data[k] = test_data[k].cuda().detach()
54 | else:
55 | test_data[k] = [s.cuda().detach() if s is not None else s for s in test_data[k]]
56 | with torch.no_grad():
57 | img=test_data['image']
58 | output = trainer.prediction(img, score_thr=0.4, update_thr=0.2)
59 | if output is not None:
60 | seg_masks, cate_labels, cate_scores = output
61 | seg_masks = seg_masks.cpu().numpy()
62 | cate_labels = cate_labels.cpu().numpy()
63 | cate_scores = cate_scores.cpu().numpy()
64 | predictions.append(stack_prediction(seg_masks, cate_labels))
65 | else:
66 | predictions.append(np.zeros((256,256,6)))
67 |
68 | predictions=np.stack(predictions,0).astype(np.int16)
69 | save_fp= os.path.join(opts.output_dir,opts.name,'train_{}_to_test_{}/masks.npy'.format( opts.train_fold,opts.test_fold))
70 |
71 | np.save(save_fp,predictions)
72 |
73 |
74 |
75 |
76 |
--------------------------------------------------------------------------------
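The packing rule used by stack_prediction above is easier to see on toy data; the numpy lines below mirror that rule on 4x4 masks for illustration only (they do not call into the script, which parses CLI arguments at import time): instance i of class c is written as the integer i into channel c, and channel 5 flags background pixels that received no instance.

    import numpy as np

    out_seg = np.zeros((4, 4, 6))
    masks = [np.array([[1, 1, 0, 0]] * 4), np.array([[0, 0, 1, 1]] * 4)]  # toy binary instance masks
    labels = [0, 2]                                                       # predicted classes (never 5)
    for idx, (mask, label) in enumerate(zip(masks, labels), start=1):
        out_seg[:, :, label] = np.maximum(out_seg[:, :, label], mask * idx)
    out_seg[:, :, 5] = np.sum(out_seg[:, :, :5], axis=-1) == 0            # background channel
    print(out_seg[:, :, 0], out_seg[:, :, 2], out_seg[:, :, 5], sep='\n')
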
/models/download.py:
--------------------------------------------------------------------------------
1 | """Download files with progress bar."""
2 | import os
3 | import hashlib
4 | import requests
5 | from tqdm import tqdm
6 |
7 | def check_sha1(filename, sha1_hash):
8 | """Check whether the sha1 hash of the file content matches the expected hash.
9 | Parameters
10 | ----------
11 | filename : str
12 | Path to the file.
13 | sha1_hash : str
14 | Expected sha1 hash in hexadecimal digits.
15 | Returns
16 | -------
17 | bool
18 | Whether the file content matches the expected hash.
19 | """
20 | sha1 = hashlib.sha1()
21 | with open(filename, 'rb') as f:
22 | while True:
23 | data = f.read(1048576)
24 | if not data:
25 | break
26 | sha1.update(data)
27 |
28 | sha1_file = sha1.hexdigest()
29 | l = min(len(sha1_file), len(sha1_hash))
30 | return sha1.hexdigest()[0:l] == sha1_hash[0:l]
31 |
32 | def download(url, path=None, overwrite=False, sha1_hash=None):
33 | """Download an given URL
34 | Parameters
35 | ----------
36 | url : str
37 | URL to download
38 | path : str, optional
39 | Destination path to store downloaded file. By default stores to the
40 | current directory with same name as in url.
41 | overwrite : bool, optional
42 | Whether to overwrite destination file if already exists.
43 | sha1_hash : str, optional
44 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
45 | but doesn't match.
46 | Returns
47 | -------
48 | str
49 | The file path of the downloaded file.
50 | """
51 | if path is None:
52 | fname = url.split('/')[-1]
53 | else:
54 | path = os.path.expanduser(path)
55 | if os.path.isdir(path):
56 | fname = os.path.join(path, url.split('/')[-1])
57 | else:
58 | fname = path
59 |
60 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
61 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
62 | if not os.path.exists(dirname):
63 | os.makedirs(dirname)
64 |
65 | print('Downloading %s from %s...'%(fname, url))
66 | r = requests.get(url, stream=True)
67 | if r.status_code != 200:
68 | raise RuntimeError("Failed downloading url %s"%url)
69 | total_length = r.headers.get('content-length')
70 | with open(fname, 'wb') as f:
71 | if total_length is None: # no content length header
72 | for chunk in r.iter_content(chunk_size=1024):
73 | if chunk: # filter out keep-alive new chunks
74 | f.write(chunk)
75 | else:
76 | total_length = int(total_length)
77 | for chunk in tqdm(r.iter_content(chunk_size=1024),
78 | total=int(total_length / 1024. + 0.5),
79 | unit='KB', unit_scale=False, dynamic_ncols=True):
80 | f.write(chunk)
81 |
82 | if sha1_hash and not check_sha1(fname, sha1_hash):
83 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \
84 | 'The repo may be outdated or download may be incomplete. ' \
85 | 'If the "repo_url" is overridden, consider switching to ' \
86 | 'the default repo.'.format(fname))
87 |
88 | return fname
89 |
--------------------------------------------------------------------------------
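Hedged usage of download/check_sha1 above (the URL and checksum are placeholders, not real artifacts): when sha1_hash is given, an existing file that fails the check is re-downloaded.

    from models.download import download, check_sha1

    url = 'https://example.com/files/weights.zip'                       # placeholder URL
    fname = download(url, path='checkpoints/weights.zip', overwrite=False)
    ok = check_sha1(fname, 'da39a3ee5e6b4b0d3255bfef95601890afd80709')  # placeholder SHA-1 (full or prefix)
    print(fname, ok)
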
/models/base_ops/DCNv2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision.ops import DeformConv2d, deform_conv2d
4 | import math
5 | class DeformConv(DeformConv2d):
6 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,dilation=1,groups=1, bias=None):
7 | super(DeformConv,self).__init__(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size, stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias)
8 | channels_ = groups * 3 * self.kernel_size[0] * self.kernel_size[1]
9 | self.conv_offset_mask = nn.Conv2d(self.in_channels,
10 | channels_,
11 | kernel_size=self.kernel_size,
12 | stride=self.stride,
13 | padding=self.padding,
14 | bias=True)
15 | self.init_offset()
16 |
17 | def init_offset(self):
18 | self.conv_offset_mask.weight.data.zero_()
19 | self.conv_offset_mask.bias.data.zero_()
20 |
21 | def forward(self, input):
22 | out = self.conv_offset_mask(input)
23 | o1, o2, mask = torch.chunk(out, 3, dim=1)
24 | offset = torch.cat((o1, o2), dim=1)
25 | mask = torch.sigmoid(mask)
26 | return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
27 | padding=self.padding, dilation=self.dilation, mask=mask)
28 |
29 | class DeformConv123(DeformConv2d):
30 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=None):
31 |         super(DeformConv123, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
32 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
33 | channels_ = self.groups * 2 * self.kernel_size[0] * self.kernel_size[1]
34 | self.conv_offset = nn.Conv2d(self.in_channels,
35 | channels_,
36 |                                      kernel_size=self.kernel_size,
37 | stride=self.stride,
38 | padding=self.padding,
39 | bias=True)
40 | nn.init.constant_(self.conv_offset.weight, 0)
41 | self.conv_offset.register_backward_hook(self._set_lr)
42 | self.conv_mask = nn.Conv2d(self.in_channels,
43 | self.groups * 1 * self.kernel_size[0] * self.kernel_size[1],
44 | kernel_size=self.kernel_size,
45 | stride=self.stride,
46 | padding=self.padding,
47 | bias=True)
48 | nn.init.constant_(self.conv_mask.weight, 0)
49 | self.conv_mask.register_backward_hook(self._set_lr)
50 |
51 | @staticmethod
52 | def _set_lr(module, grad_input, grad_output):
53 | grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
54 | grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
55 |
56 | def init_offset(self):
57 | n = self.in_channels
58 | for k in self.kernel_size:
59 | n *= k
60 | stdv = 1. / math.sqrt(n)
61 | self.conv_offset.weight.data.uniform_(-stdv, stdv)
62 | self.conv_offset.bias.data.zero_()
63 |
64 | def forward(self, input):
65 | offset = self.conv_offset(input)
66 | mask = torch.sigmoid(self.conv_mask(input))
67 | return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
68 | padding=self.padding, dilation=self.dilation, mask=mask)
--------------------------------------------------------------------------------
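Below is a minimal sketch of `DeformConv` used as a drop-in 3x3 convolution (random input, arbitrary channel sizes; it assumes a torchvision build whose `deform_conv2d` supports the modulated `mask` argument, i.e. DCNv2):

```python
import torch
from models.base_ops.DCNv2 import DeformConv

# DeformConv predicts its own offsets and modulation masks from the input feature map.
dcn = DeformConv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
x = torch.randn(2, 64, 32, 32)
y = dcn(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```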
/utils/imsave.py:
--------------------------------------------------------------------------------
1 | import torchvision.utils as vutils
2 | import torch
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import random
6 | import time
7 | from PIL import Image
8 | import colorsys
9 | import cv2
10 | def write_2images(image_outputs, display_image_num, image_directory, postfix):
11 | __write_images(image_outputs[:], display_image_num, '%s/gen_%s.png' % (image_directory, postfix))
12 |
13 | def __write_images(image_outputs, display_image_num, file_name):
14 | image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels
15 | image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0)
16 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True)
17 | vutils.save_image(image_grid, file_name, nrow=1)
18 |
19 | def collate_func(batch_dic):
20 | output={}
21 | #for k in ['labels','ins_masks']:
22 | for k in ['cate_labels', 'ins_labels','ins_ind_labels']:
23 | output[k]=[dic[k] if dic[k] is not None else None for dic in batch_dic]
24 | output['image']=torch.stack([dic['image'] for dic in batch_dic])
25 | return output
26 |
27 | def convert_labels(arr):
28 | w,h=arr.shape
29 | output=np.zeros([w,h,3],dtype=np.uint8)
30 | for i in np.unique(arr):
31 | if i ==0:continue
32 | output[arr==i]=[random.randint(64,255),random.randint(64,255),random.randint(64,255)]
33 | return output
34 |
35 | def _imageshow(image,pred,gt,unpaired_pred_ind,unpaired_gt_ind,title=None,cmap='tab20b'):
36 |     plt.rcParams['figure.figsize'] = (8.0, 6.0)  # set the figure size
37 |     plt.rcParams['image.interpolation'] = 'nearest'  # set the interpolation style
38 | plt.rcParams['savefig.dpi'] = 300
39 | plt.rcParams['figure.dpi'] = 300
40 |
41 | plt.subplot(231)
42 | plt.imshow(image)
43 | if title is not None:
44 | plt.title(title)
45 | plt.axis('off')
46 | plt.subplot(232)
47 | plt.imshow(convert_labels(pred))
48 | plt.title(f'pred_num {len(np.unique(pred))}')
49 | plt.axis('off')
50 | plt.subplot(233)
51 | plt.imshow(convert_labels(gt))
52 | plt.title(f'gt_num {len(np.unique(gt))}')
53 | plt.axis('off')
54 |     plt.subplot(234)  # leave the unused fourth panel of the 2x3 grid empty
55 | plt.axis('off')
56 | unmatched_pred=np.zeros_like(pred)
57 | for i in np.unique(unpaired_pred_ind):
58 | if i ==0:
59 | continue
60 | unmatched_pred[pred==i]=i
61 | plt.subplot(235)
62 | plt.imshow(convert_labels(unmatched_pred))
63 | plt.title(f'unmatched_pred {len(np.unique(unpaired_pred_ind))}')
64 | plt.axis('off')
65 | unmatched_gt=np.zeros_like(gt)
66 | for i in np.unique(unpaired_gt_ind):
67 | if i ==0:
68 | continue
69 | unmatched_gt[gt==i]=i
70 | plt.subplot(236)
71 | plt.imshow(convert_labels(unmatched_gt))
72 | plt.title(f'unmatched_gt {len(np.unique(unpaired_gt_ind))}')
73 | plt.axis('off')
74 | plt.subplots_adjust(top=0.95, bottom=0.05, right=0.99, left=0.01, hspace=0.1, wspace=0.1)
75 | plt.tight_layout()
76 | #plt.savefig('D:/{}.png'.format(time.time()))
77 | plt.show()
78 |
79 |
80 | def random_colors(N, bright=True):
81 | """Generate random colors.
82 |
83 | To get visually distinct colors, generate them in HSV space then
84 | convert to RGB.
85 | """
86 | brightness = 1.0 if bright else 0.7
87 | hsv = [(i / N, 1, brightness) for i in range(N)]
88 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
89 | random.shuffle(colors)
90 | return colors
91 |
92 |
93 | def _imagesave(image, mask, label,save_path):
94 | colors={
95 | 1: [255, 0, 0],
96 | 2:[0, 255, 0],
97 | 3: [0, 0, 255],
98 | 4: [255, 255, 0],
99 | 5 : [255, 165, 0]
100 | }
101 |
102 | inst_rng_colors = random_colors(len(np.unique(mask)))
103 | inst_rng_colors = np.clip(inst_rng_colors,0,255) * 255
104 | inst_rng_colors = inst_rng_colors.astype(np.uint8).tolist()
105 |
106 | image=image.copy()
107 | if image.shape[2]==4:
108 | image=image[:,:,0:3].copy()
109 |
110 |
111 | for ui in np.unique(mask):
112 | if ui == 0: continue
113 | binary = ((mask == ui) * 255).astype(np.uint8)
114 | contours, hierarchy = cv2.findContours(binary.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
115 | if label is not None:
116 | image = cv2.drawContours(image, contours, -1, tuple(colors[label[ui-1]]), 2)
117 | else:
118 | image = cv2.drawContours(image, contours, -1, tuple(inst_rng_colors[ui-1]), 2)
119 |
120 | Image.fromarray(image).save(save_path)
121 |
--------------------------------------------------------------------------------
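A small sketch of `_imagesave` overlaying instance contours on an image (all arrays below are synthetic; instance id `i` in `mask` takes its colour from `label[i-1]`, which must be one of the class keys 1-5, or pass `label=None` for random per-instance colours):

```python
import numpy as np
from utils.imsave import _imagesave

image = np.zeros((64, 64, 3), dtype=np.uint8)   # placeholder tile
mask = np.zeros((64, 64), dtype=np.int32)       # instance map, 0 = background
mask[10:20, 10:20] = 1
mask[30:45, 30:45] = 2
label = [1, 3]                                  # class id per instance
_imagesave(image, mask, label, 'overlay.png')
```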
/losses/focal_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | def binary_focal_loss(x, y, alpha=0.25, gamma=2., reduction='none'):
7 | pt = x.detach() * (y.detach() * 2 - 1)
8 | w = (1 - pt).pow(gamma)
9 | w[y == 0] *= (1 - alpha)
10 | w[y > 0] *= alpha
11 | # a = torch.where(y < 0, alpha, (1 - alpha))
12 | loss = F.binary_cross_entropy(x, y, w, reduction=reduction)
13 | return loss
14 |
15 | class BinaryFocalLoss(nn.Module):
16 | def __init__(self, alpha=0.25, gamma=2):
17 | super(BinaryFocalLoss, self).__init__()
18 | self.alpha = alpha
19 | self.gamma = gamma
20 | self.smooth = 1e-6 # set '1e-4' when train with FP16
21 |
22 | def forward(self, output, target):
23 | prob = torch.sigmoid(output)
24 | prob = torch.clamp(prob, self.smooth, 1.0 - self.smooth)
25 | loss=-target*(1-self.alpha)*((1-prob)**self.gamma)*torch.log(prob)-(1-target)*self.alpha*(prob**self.gamma)*torch.log(1-prob)
26 | return loss
27 |
28 | class BCELoss(nn.Module):
29 | def __init__(self):
30 | super(BCELoss, self).__init__()
31 | self.smooth = 1e-6 # set '1e-4' when train with FP16
32 |
33 | def forward(self, output, target):
34 | prob = torch.sigmoid(output)
35 | prob = torch.clamp(prob, self.smooth, 1.0 - self.smooth)
36 | loss=-target*torch.log(prob)-(1-target)*torch.log(1-prob)
37 | return loss
38 |
39 | class FocalLoss_Ori(nn.Module):
40 | """
41 |     This is an implementation of Focal Loss with smooth-label cross entropy support, as proposed in
42 |     'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
43 |         Focal_Loss = -1 * alpha * (1 - pt)**gamma * log(pt)
44 |     :param num_class: number of classes
45 |     :param alpha: (tensor, list or float) the per-class scalar weighting factor for this criterion
46 |     :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
47 |             putting more focus on hard, misclassified examples
48 |     :param smooth: (float, double) smoothing value used in the cross entropy
49 |     :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
50 | """
51 |
52 | def __init__(self, num_class, alpha=[0.25, 0.75], gamma=2, balance_index=-1, size_average=True):
53 | super(FocalLoss_Ori, self).__init__()
54 | self.num_class = num_class
55 | self.alpha = alpha
56 | self.gamma = gamma
57 | self.size_average = size_average
58 | self.eps = 1e-6
59 |
60 | if isinstance(self.alpha, (list, tuple)):
61 | assert len(self.alpha) == self.num_class
62 | self.alpha = torch.Tensor(list(self.alpha))
63 | elif isinstance(self.alpha, (float, int)):
64 | assert 0 < self.alpha < 1.0, 'alpha should be in `(0,1)`)'
65 | assert balance_index > -1
66 | alpha = torch.ones((self.num_class))
67 | alpha *= 1 - self.alpha
68 | alpha[balance_index] = self.alpha
69 | self.alpha = alpha
70 | elif isinstance(self.alpha, torch.Tensor):
71 | self.alpha = self.alpha
72 | else:
73 | raise TypeError('Not support alpha type, expect `int|float|list|tuple|torch.Tensor`')
74 |
75 |     def forward(self, logit, target):  # note: `logit` is consumed as class probabilities (gathered and log-ed directly)
76 |
77 | if logit.dim() > 2:
78 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
79 | logit = logit.view(logit.size(0), logit.size(1), -1)
80 | logit = logit.transpose(1, 2).contiguous() # [N,C,d1*d2..] -> [N,d1*d2..,C]
81 | logit = logit.view(-1, logit.size(-1)) # [N,d1*d2..,C]-> [N*d1*d2..,C]
82 | target = target.view(-1, 1) # [N,d1,d2,...]->[N*d1*d2*...,1]
83 |
84 | # -----------legacy way------------
85 | # idx = target.cpu().long()
86 | # one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
87 | # one_hot_key = one_hot_key.scatter_(1, idx, 1)
88 | # if one_hot_key.device != logit.device:
89 | # one_hot_key = one_hot_key.to(logit.device)
90 | # pt = (one_hot_key * logit).sum(1) + epsilon
91 |
92 | # ----------memory saving way--------
93 |         pt = logit.gather(1, target).view(-1) + self.eps  # add eps to avoid log(0)
94 | logpt = pt.log()
95 |
96 |         if self.alpha.device != logpt.device:
97 |             self.alpha = self.alpha.to(logpt.device)
98 |         alpha_class = self.alpha.gather(0, target.view(-1))
99 | logpt = alpha_class * logpt
100 | loss = -1 * torch.pow(torch.sub(1.0, pt), self.gamma) * logpt
101 |
102 | if self.size_average:
103 | loss = loss.mean()
104 | else:
105 | loss = loss.sum()
106 | return loss
107 |
108 |
--------------------------------------------------------------------------------
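A brief sketch of calling the binary losses on random tensors (both return an element-wise loss map, so the caller is expected to reduce it):

```python
import torch
from losses import BinaryFocalLoss, BCELoss

logits = torch.randn(4, 1, 64, 64)                   # raw network outputs
targets = (torch.rand(4, 1, 64, 64) > 0.5).float()   # binary ground truth

focal = BinaryFocalLoss(alpha=0.25, gamma=2)(logits, targets).mean()
bce = BCELoss()(logits, targets).mean()
print(focal.item(), bce.item())
```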
/utils/matrix_nms.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):
5 | """Matrix NMS for multi-class masks.
6 | Args:
7 | seg_masks (Tensor): shape (n, h, w)
8 | cate_labels (Tensor): shape (n), mask labels in descending order
9 | cate_scores (Tensor): shape (n), mask scores in descending order
10 |         kernel (str): 'linear' or 'gaussian'
11 | sigma (float): std in gaussian method
12 | sum_masks (Tensor): The sum of seg_masks
13 | Returns:
14 | Tensor: cate_scores_update, tensors of shape (n)
15 | """
16 | n_samples = len(cate_labels)
17 | if n_samples == 0:
18 | return []
19 | if sum_masks is None:
20 | sum_masks = seg_masks.sum((1, 2)).float()
21 | seg_masks = seg_masks.reshape(n_samples, -1).float()
22 | # inter.
23 | inter_matrix = torch.mm(seg_masks, seg_masks.transpose(1, 0))
24 | # union.
25 | sum_masks_x = sum_masks.expand(n_samples, n_samples)
26 | # iou.
27 | iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1)
28 | # label_specific matrix.
29 | cate_labels_x = cate_labels.expand(n_samples, n_samples)
30 | label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1)
31 |
32 | # IoU compensation
33 | compensate_iou, _ = (iou_matrix * label_matrix).max(0)
34 | compensate_iou = compensate_iou.expand(n_samples, n_samples).transpose(1, 0)
35 |
36 | # IoU decay
37 | decay_iou = iou_matrix * label_matrix
38 |
39 | # matrix nms
40 | if kernel == 'gaussian':
41 | decay_matrix = torch.exp(-1 * sigma * (decay_iou ** 2))
42 | compensate_matrix = torch.exp(-1 * sigma * (compensate_iou ** 2))
43 | decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
44 | elif kernel == 'linear':
45 | decay_matrix = (1-decay_iou)/(1-compensate_iou)
46 | decay_coefficient, _ = decay_matrix.min(0)
47 | else:
48 | raise NotImplementedError
49 |
50 | # update the score.
51 | cate_scores_update = cate_scores * decay_coefficient
52 | return cate_scores_update
53 |
54 |
55 | def multiclass_nms(multi_bboxes,
56 | multi_scores,
57 | score_thr,
58 | nms_cfg,
59 | max_num=-1,
60 | score_factors=None):
61 | """NMS for multi-class bboxes.
62 | Args:
63 | multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
64 | multi_scores (Tensor): shape (n, #class), where the 0th column
65 | contains scores of the background class, but this will be ignored.
66 | score_thr (float): bbox threshold, bboxes with scores lower than it
67 | will not be considered.
68 |         nms_cfg (dict): NMS type and threshold configuration
69 | max_num (int): if there are more than max_num bboxes after NMS,
70 | only top max_num will be kept.
71 | score_factors (Tensor): The factors multiplied to scores before
72 | applying NMS
73 | Returns:
74 | tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels
75 | are 0-based.
76 | """
77 | num_classes = multi_scores.shape[1]
78 | bboxes, labels = [], []
79 | nms_cfg_ = nms_cfg.copy()
80 | nms_type = nms_cfg_.pop('type', 'nms')
81 |     nms_op = getattr(nms_wrapper, nms_type)  # note: `nms_wrapper` (an external NMS module, e.g. from mmdet) is not imported in this file
82 | for i in range(1, num_classes):
83 | cls_inds = multi_scores[:, i] > score_thr
84 | if not cls_inds.any():
85 | continue
86 | # get bboxes and scores of this class
87 | if multi_bboxes.shape[1] == 4:
88 | _bboxes = multi_bboxes[cls_inds, :]
89 | else:
90 | _bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4]
91 | _scores = multi_scores[cls_inds, i]
92 | if score_factors is not None:
93 | _scores *= score_factors[cls_inds]
94 | cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)
95 | cls_dets, _ = nms_op(cls_dets, **nms_cfg_)
96 | cls_labels = multi_bboxes.new_full((cls_dets.shape[0], ),
97 | i - 1,
98 | dtype=torch.long)
99 | bboxes.append(cls_dets)
100 | labels.append(cls_labels)
101 | if bboxes:
102 | bboxes = torch.cat(bboxes)
103 | labels = torch.cat(labels)
104 | if bboxes.shape[0] > max_num:
105 | _, inds = bboxes[:, -1].sort(descending=True)
106 | inds = inds[:max_num]
107 | bboxes = bboxes[inds]
108 | labels = labels[inds]
109 | else:
110 | bboxes = multi_bboxes.new_zeros((0, 5))
111 | labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
112 |
113 | return bboxes, labels
--------------------------------------------------------------------------------
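A minimal sketch of Matrix NMS on synthetic masks (note that `cate_scores` must already be sorted in descending order, as the docstring states):

```python
import torch
from utils.matrix_nms import matrix_nms

seg_masks = (torch.rand(5, 64, 64) > 0.5).float()     # candidate binary masks
cate_labels = torch.randint(0, 3, (5,))                # class id per mask
cate_scores, _ = torch.rand(5).sort(descending=True)   # confidences, descending

new_scores = matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0)
keep = new_scores > 0.05                               # threshold the decayed scores
```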
/models/segbase.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from .backbone.resnetv1b import resnet50_v1s,resnet101_v1s
3 | from .backbone.res2netv1b import res2net50_v1b
4 | from .backbone.resnext import resnext50_32x4d,resnext101_32x8d
5 | from .backbone.seg_hrnet import hrnet_w18_v2,hrnet_w32,hrnet_w44,hrnet_w48,hrnet_w64
6 | from .backbone.resnextdcn import resnext101_32x8d_dcn
7 | from .backbone.swim_transformer import swim_large
8 | from torch.nn.modules.batchnorm import _BatchNorm
9 | __all__ = ['SegBaseModel']
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | class SegBaseModel(nn.Module):
14 | r"""Base Model for Semantic Segmentation
15 | Parameters
16 | ----------
17 |     backbone : string
18 |         Backbone network type. Supported values: 'resnet50', 'resnet101', 'resnext50',
19 |         'resnext101', 'resnext101dcn', 'res2net50', 'hrnet18', 'hrnet32', 'hrnet64', 'swin'.
20 |     pretrained_base : bool
21 |         Whether to load ImageNet pre-trained weights for the backbone.
22 |     norm_eval : bool
23 |         Whether to keep backbone BatchNorm layers in eval mode during training.
24 |     """
25 | def __init__(self, backbone='enc', pretrained_base=False, frozen_stages=-1,norm_eval=False, **kwargs):
26 | super(SegBaseModel, self).__init__()
27 | self.norm_eval=norm_eval
28 | self.frozen_stages=frozen_stages
29 | if backbone == 'resnext101dcn':
30 | self.pretrained = resnext101_32x8d_dcn(pretrained=pretrained_base, dilated=False, **kwargs)
31 | elif backbone == 'resnet50':
32 | self.pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=False, **kwargs)
33 | elif backbone == 'hrnet18':
34 | self.pretrained = hrnet_w18_v2(pretrained=pretrained_base, dilated=False, **kwargs)
35 | elif backbone =='hrnet32':
36 | self.pretrained=hrnet_w32(pretrained=pretrained_base, **kwargs)
37 | elif backbone =='hrnet64':
38 | self.pretrained=hrnet_w64(pretrained=pretrained_base, **kwargs)
39 | elif backbone == 'resnet101':
40 | self.pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=False, **kwargs)
41 | elif backbone == 'resnext101':
42 | self.pretrained = resnext101_32x8d(pretrained=pretrained_base, dilated=False, **kwargs)
43 | elif backbone == 'res2net50':
44 | self.pretrained = res2net50_v1b(pretrained=pretrained_base, **kwargs)
45 | elif backbone == 'resnext50':
46 | self.pretrained = resnext50_32x4d(pretrained=pretrained_base, dilated=False, **kwargs)
47 | elif backbone == 'swin':
48 | self.pretrained = swim_large(pretrained=False)
49 | else:
50 | raise RuntimeError('unknown backbone: {}'.format(backbone))
51 | self._train()
52 |
53 |
54 | def set_requires_grad(self, nets, requires_grad=False):
55 |         """Set requires_grad=False for all the networks to avoid unnecessary computations
56 | Parameters:
57 | nets (network list) -- a list of networks
58 | requires_grad (bool) -- whether the networks require gradients or not
59 | """
60 | if not isinstance(nets, list):
61 | nets = [nets]
62 | for net in nets:
63 | if net is not None:
64 | for param in net.parameters():
65 | param.requires_grad = requires_grad
66 |
67 | def freeze(self):
68 | self.set_requires_grad(self.pretrained,False)
69 | self.set_requires_grad([self.pretrained.conv1,self.pretrained.bn1],True)
70 |
71 | def unfreeze(self):
72 | self.set_requires_grad([self.pretrained],True)
73 |
74 | def _train(self, mode=True):
75 | super(SegBaseModel, self).train(mode)
76 | self._freeze_stages()
77 | if mode and self.norm_eval:
78 | print('Freeze backbone BN using running mean and std')
79 | for m in self.modules():
80 |                 # trick: eval() only has an effect on BatchNorm here
81 | if isinstance(m, _BatchNorm):
82 | m.eval()
83 |
84 | def _freeze_stages(self):
85 | if self.frozen_stages >= 0:
86 | self.pretrained.bn1.eval()
87 | for m in [self.pretrained.conv1, self.pretrained.bn1]:
88 | for param in m.parameters():
89 | param.requires_grad = False
90 | if hasattr(self.pretrained, 'conv2'):
91 | self.pretrained.bn2.eval()
92 | for m in [self.pretrained.conv2, self.pretrained.bn2]:
93 | for param in m.parameters():
94 | param.requires_grad = False
95 |
96 | print(f'Freezing backbone stage {self.frozen_stages}')
97 | for i in range(1, self.frozen_stages + 1):
98 | m = getattr(self.pretrained, 'layer{}'.format(i))
99 | m.eval()
100 | for param in m.parameters():
101 | param.requires_grad = False
102 |
103 |             # trick: keep conv1/bn1 trainable since ImageNet mean/std normalization is not used
104 | if hasattr(self.pretrained,'conv1'):
105 | print('active train conv1 and bn1')
106 | self.set_requires_grad([self.pretrained.conv1,self.pretrained.bn1],True)
107 |
108 | def base_forward(self, x):
109 | """forwarding pre-trained network"""
110 | x = self.pretrained.conv1(x)
111 | x = self.pretrained.bn1(x)
112 | x = self.pretrained.relu(x)
113 | c1 = self.pretrained.maxpool(x)
114 | c2 = self.pretrained.layer1(c1)
115 | c3 = self.pretrained.layer2(c2)
116 | c4 = self.pretrained.layer3(c3)
117 | c5 = self.pretrained.layer4(c4)
118 | return c1, c2, c3, c4, c5
119 |
120 |
121 |
122 |
123 |
--------------------------------------------------------------------------------
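A small sketch of the feature pyramid returned by `base_forward` with a ResNet-50 backbone (input size chosen arbitrarily; `SegBaseModel` is normally subclassed by the full model, and importing it assumes all backbone modules referenced at the top of this file are present):

```python
import torch
from models.segbase import SegBaseModel

model = SegBaseModel(backbone='resnet50', pretrained_base=False)
feats = model.base_forward(torch.randn(1, 3, 256, 256))
for i, f in enumerate(feats, start=1):
    print(f'c{i}:', tuple(f.shape))   # c1..c5 at strides 4, 4, 8, 16, 32
```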
/losses/generalise_focal_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | def reduce_loss(loss, reduction):
5 | """Reduce loss as specified.
6 | Args:
7 | loss (Tensor): Elementwise loss tensor.
8 | reduction (str): Options are "none", "mean" and "sum".
9 | Return:
10 | Tensor: Reduced loss tensor.
11 | """
12 | reduction_enum = F._Reduction.get_enum(reduction)
13 | # none: 0, elementwise_mean:1, sum: 2
14 | if reduction_enum == 0:
15 | return loss
16 | elif reduction_enum == 1:
17 | return loss.mean()
18 | elif reduction_enum == 2:
19 | return loss.sum()
20 |
21 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
22 | """Apply element-wise weight and reduce loss.
23 | Args:
24 | loss (Tensor): Element-wise loss.
25 | weight (Tensor): Element-wise weights.
26 | reduction (str): Same as built-in losses of PyTorch.
27 |         avg_factor (float): Average factor when computing the mean of losses.
28 | Returns:
29 | Tensor: Processed loss values.
30 | """
31 | # if weight is specified, apply element-wise weight
32 | if weight is not None:
33 | loss = loss * weight
34 |
35 | # if avg_factor is not specified, just reduce the loss
36 | if avg_factor is None:
37 | loss = reduce_loss(loss, reduction)
38 | else:
39 | # if reduction is mean, then average the loss by avg_factor
40 | if reduction == 'mean':
41 | loss = loss.sum() / avg_factor
42 | # if reduction is 'none', then do nothing, otherwise raise an error
43 | elif reduction != 'none':
44 | raise ValueError('avg_factor can not be used with reduction="sum"')
45 | return loss
46 |
47 |
48 | def quality_focal_loss(
49 | pred, # (n, 80)
50 | label, # (n) 0, 1-80: 0 is neg, 1-80 is positive
51 | score, # (n) reg target 0-1, only positive is good
52 | weight=None,
53 | beta=2.0,
54 | reduction='mean',
55 | avg_factor=None):
56 | # all goes to 0
57 | pred_sigmoid = pred.sigmoid()
58 | pt = pred_sigmoid
59 | zerolabel = pt.new_zeros(pred.shape)
60 | loss = F.binary_cross_entropy_with_logits(
61 | pred, zerolabel, reduction='none') * pt.pow(beta)
62 |
63 | label = label - 1
64 | pos = (label >= 0).nonzero().squeeze(1)
65 | a = pos
66 | b = label[pos].long()
67 |
68 | # positive goes to bbox quality
69 | pt = score[a] - pred_sigmoid[a, b]
70 | loss[a, b] = F.binary_cross_entropy_with_logits(
71 | pred[a, b], score[a], reduction='none') * pt.pow(beta)
72 |
73 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
74 | return loss
75 |
76 |
77 | def distribution_focal_loss(
78 | pred,
79 | label,
80 | weight=None,
81 | reduction='mean',
82 | avg_factor=None):
83 | disl = label.long()
84 | disr = disl + 1
85 |
86 | wl = disr.float() - label
87 | wr = label - disl.float()
88 |
89 | loss = F.cross_entropy(pred, disl, reduction='none') * wl \
90 | + F.cross_entropy(pred, disr, reduction='none') * wr
91 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
92 | return loss
93 |
94 | class QualityFocalLoss(nn.Module):
95 | def __init__(self,
96 | use_sigmoid=True,
97 | beta=2.0,
98 | reduction='mean',
99 | loss_weight=1.0):
100 | super(QualityFocalLoss, self).__init__()
101 | assert use_sigmoid is True, 'Only sigmoid in QFL supported now.'
102 | self.use_sigmoid = use_sigmoid
103 | self.beta = beta
104 | self.reduction = reduction
105 | self.loss_weight = loss_weight
106 |
107 | def forward(self,
108 | pred,
109 | target,
110 | score,
111 | weight=None,
112 | avg_factor=None,
113 | reduction_override=None):
114 | assert reduction_override in (None, 'none', 'mean', 'sum')
115 | reduction = (
116 | reduction_override if reduction_override else self.reduction)
117 | if self.use_sigmoid:
118 | loss_cls = self.loss_weight * quality_focal_loss(
119 | pred,
120 | target,
121 | score,
122 | weight,
123 | beta=self.beta,
124 | reduction=reduction,
125 | avg_factor=avg_factor)
126 | else:
127 | raise NotImplementedError
128 | return loss_cls
129 |
130 | class DistributionFocalLoss(nn.Module):
131 |
132 | def __init__(self,
133 | reduction='mean',
134 | loss_weight=1.0):
135 | super(DistributionFocalLoss, self).__init__()
136 | self.reduction = reduction
137 | self.loss_weight = loss_weight
138 |
139 | def forward(self,
140 | pred,
141 | target,
142 | weight=None,
143 | avg_factor=None,
144 | reduction_override=None):
145 | assert reduction_override in (None, 'none', 'mean', 'sum')
146 | reduction = (
147 | reduction_override if reduction_override else self.reduction)
148 | loss_cls = self.loss_weight * distribution_focal_loss(
149 | pred,
150 | target,
151 | weight,
152 | reduction=reduction,
153 | avg_factor=avg_factor)
154 | return loss_cls
--------------------------------------------------------------------------------
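A minimal sketch of `QualityFocalLoss` on random predictions (label `0` denotes background, labels `1..C` the foreground classes, and `score` carries the quality/IoU targets used for the positive samples):

```python
import torch
from losses.generalise_focal_loss import QualityFocalLoss

pred = torch.randn(8, 5)             # class logits: 8 samples, 5 foreground classes
label = torch.randint(0, 6, (8,))    # 0 = background, 1-5 = foreground class ids
score = torch.rand(8)                # quality targets in [0, 1]

qfl = QualityFocalLoss(beta=2.0, loss_weight=1.0)
loss = qfl(pred, label, score)       # scalar (reduction='mean' by default)
```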
/README.md:
--------------------------------------------------------------------------------
1 | # PointNu-Net
2 |
3 | ## PointNu-Net: Keypoint-assisted Convolutional Neural Network for Simultaneous Multi-tissue Histology Nuclei Segmentation and Classification. [Paper](https://ieeexplore.ieee.org/abstract/document/10148651) [ArXiv](https://arxiv.org/pdf/2111.01557.pdf)
4 | Kai Yao, Kaizhu Huang, Jie Sun, Amir Hussain, Curran Jude \
5 | University of Liverpool and Xi'an Jiaotong-Liverpool University
6 |
7 | **Abstract**
8 |
9 | Automatic nuclei segmentation and classification play a vital role in digital pathology. However, previous works are mostly built on data with limited diversity and small sizes, making the results questionable or misleading in actual downstream tasks. In this paper, we aim to build a reliable and robust method capable of dealing with data from the ‘clinical wild’. Specifically, we study and design a new method to simultaneously detect, segment, and classify nuclei from Haematoxylin and Eosin (H\&E) stained histopathology data, and evaluate our approach using the recent largest dataset: PanNuke. We address the detection and classification of each nucleus as a novel semantic keypoint estimation problem to determine the center point of each nucleus. Next, the corresponding class-agnostic masks for nuclei center points are obtained using dynamic instance segmentation. Meanwhile, we propose a novel Joint Pyramid Fusion Module (JPFM) to model the cross-scale dependencies, thus enhancing the local features for better nuclei detection and classification. By decoupling two simultaneous challenging tasks and taking advantage of JPFM, our method can benefit from class-aware detection and class-agnostic segmentation, thus leading to a significant performance boost. We demonstrate the superior performance of our proposed approach for nuclei segmentation and classification across 19 different tissue types, delivering new benchmark results.
10 |
11 | ## News:
12 | \[2023/12/25\] We added additional experimental results on CPM17 (PQ 0.706, AJI 0.712). \
13 | \[2023/5/15\] Our paper has been accepted by TETCI. \
14 | \[2023/5/1\] We released the training and inference code, and the training instructions.
15 |
16 |
17 | ## 1. Installation
18 |
19 | Clone this repo.
20 | ```bash
21 | git clone https://github.com/Kaiseem/PointNu-Net.git
22 | cd PointNu-Net/
23 | ```
24 |
25 | This code requires PyTorch 1.10+ and Python 3+. Please install the dependencies by running:
26 | ```bash
27 | pip install -r requirements.txt
28 | ```
29 |
30 |
31 | ## 2. Data preparation
32 |
33 | For the smaller Kumar and CoNSeP datasets, we prepare the data following [Hover-Net](https://github.com/vqdang/hover_net).
34 |
35 | We provide the [processed Kumar and CoNSeP datasets](https://drive.google.com/file/d/1_eI_ii6xcNe_77NWx7Qo8_KndK5UwPBO/view?usp=sharing). Also, we provide the [processed CPM17 dataset](https://drive.google.com/file/d/1igsZ2oBUmylPSzhsXn_3TCOWZNe_Zfd8/view?usp=sharing).
36 |
37 | The [PanNuKe](https://arxiv.org/pdf/2003.10778v7.pdf) dataset can be found [here](https://warwick.ac.uk/fac/sci/dcs/research/tia/data/pannuke)
38 |
39 | Download and unzip all the files so that the folder structure looks like this:
40 |
41 | ```none
42 | PointNu-Net
43 | ├── ...
44 | ├── datasets
45 | │ ├── kumar
46 | │ │ ├── train
47 | │ │ ├── test
48 | │ ├── CoNSeP
49 | │ │ ├── train
50 | │ │ ├── test
51 | │ ├── PanNuKe
52 | │ │ ├── images
53 | │ │ │ ├── fold1
54 | │ │ │ │ ├── images.npy
55 | │ │ │ │ ├── types.npy
56 | │ │ │ ├── fold2
57 | │ │ │ │ ├── images.npy
58 | │ │ │ │ ├── types.npy
59 | │ │ │ ├── fold3
60 | │ │ │ │ ├── images.npy
61 | │ │ │ │ ├── types.npy
62 | │ │ ├── masks
63 | │ │ │ ├── fold1
64 | │ │ │ │ ├── masks.npy
65 | │ │ │ ├── fold2
66 | │ │ │ │ ├── masks.npy
67 | │ │ │ ├── fold3
68 | │ │ │ │ ├── masks.npy
69 | ├── ...
70 | ```
71 |
72 | ## 3. Training and Inference
73 | To reproduce the reported performance, you need at least one RTX 3090 GPU.
74 |
75 | Download the ImageNet pre-trained checkpoints from the official [HRNet](https://github.com/HRNet/HRNet-Image-Classification) repository.
76 |
77 |
78 |
79 | 1) Kumar Dataset
80 |
81 |
82 | Run the following command to train the model:
83 | ```bash
84 | python train.py --name=kumar_exp --seed=888 --config=configs/kumar_notype_large.yaml
85 | ```
86 |
87 | Run the following command to run inference:
88 | ```bash
89 | python inference.py --name=kumar_exp
90 | ```
91 |
92 |
93 |
94 |
95 | 2) CoNSeP Dataset
96 |
97 |
98 | Run the following command to train the model:
99 | ```bash
100 | python train.py --name=consep_exp --seed=888 --config=configs/consep_type_large.yaml
101 | ```
102 |
103 | Run the following command to run inference:
104 | ```bash
105 | python inference.py --name=consep_exp
106 | ```
107 |
108 |
109 |
110 |
111 |
112 | 3) PanNuKe Dataset
113 |
114 |
115 | Run the following command to train the model:
116 | ```bash
117 | python train_pannuke.py --name=pannuke_exp --seed=888 --train_fold={} --val_fold={} --test_fold={}
118 | ```
119 | [train_fold, val_fold, test_fold] should be selected from {[1, 2, 3], [2, 1, 3], [3, 2, 1]}
120 |
121 | Run the following command to run inference:
122 | ```bash
123 | python infer_pannuke.py --name=pannuke_exp --train_fold={} --test_fold={}
124 | ```
125 |
126 | Run the following command to evaluate the performance:
127 | ```bash
128 | python eval_pannuke.py --name=pannuke_exp --train_fold={} --test_fold={}
129 | ```
130 |
131 |
132 |
133 |
134 | ## Citation
135 | If our work or code helps you, please consider citing our paper. Thank you!
136 |
137 | ```
138 | @article{yao2023pointnu,
139 | title={PointNu-Net: Keypoint-Assisted Convolutional Neural Network for Simultaneous Multi-Tissue Histology Nuclei Segmentation and Classification},
140 | author={Yao, Kai and Huang, Kaizhu and Sun, Jie and Hussain, Amir},
141 | journal={IEEE Transactions on Emerging Topics in Computational Intelligence},
142 | year={2023},
143 | }
144 | ```
145 |
--------------------------------------------------------------------------------
/utils/imop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage import shift  # scipy.ndimage.interpolation is deprecated
3 | import cv2
4 | from collections import Counter
5 | import math
6 | from scipy import ndimage
7 | def get_ins_info(seg_mask,method='bbox'):
8 | methods=['bbox','circle','area']
9 |     assert method in methods, f'instance segmentation information method should be one of {methods}'
10 | if method=='circle':
11 | contours, hierachy = cv2.findContours((seg_mask * 255).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
12 | (center_w, center_h), EC_radius = cv2.minEnclosingCircle(contours[0])
13 | return center_w,center_h,EC_radius*2,EC_radius*2
14 | elif method=='bbox':
15 | bbox_x, bbox_y, bbox_w, bbox_h = cv2.boundingRect(np.array(seg_mask).astype(np.uint8))
16 | center_w = bbox_x + bbox_w / 2
17 | center_h = bbox_y + bbox_h / 2
18 | return center_w, center_h, bbox_w, bbox_h
19 | elif method=='area':
20 |         center_h, center_w = ndimage.center_of_mass(seg_mask)  # the measurements submodule is deprecated
21 | equal_diameter=(np.sum(seg_mask)/3.1415)**0.5*2
22 | return center_w,center_h,equal_diameter,equal_diameter
23 | else:
24 | raise NotImplementedError
25 |
26 | def gaussian_radius(det_size, min_overlap=0.7):
27 |
28 | #https://github.com/princeton-vl/CornerNet/blob/master/sample/utils.py
29 | height, width = det_size
30 |
31 | a1 = 1
32 | b1 = (height + width)
33 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
34 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
35 | r1 = (b1 - sq1) / 2
36 |
37 | a2 = 4
38 | b2 = 2 * (height + width)
39 | c2 = (1 - min_overlap) * width * height
40 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
41 | r2 = (b2 - sq2) / 2
42 |
43 | a3 = 4 * min_overlap
44 | b3 = -2 * min_overlap * (height + width)
45 | c3 = (min_overlap - 1) * width * height
46 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
47 | r3 = (b3 + sq3) / 2
48 |
49 | return min(r1, r2, r3)
50 |
51 | def gaussian_radius_new(det_size, min_overlap=0.7):
52 | #https://github.com/princeton-vl/CornerNet/blob/master/sample/utils.py
53 | height, width = det_size
54 |
55 | a1 = 1
56 | b1 = (height + width)
57 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
58 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
59 | r1 = (b1 - sq1) / (2 * a1)
60 |
61 | a2 = 4
62 | b2 = 2 * (height + width)
63 | c2 = (1 - min_overlap) * width * height
64 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
65 | r2 = (b2 - sq2) / (2 * a2)
66 |
67 | a3 = 4 * min_overlap
68 | b3 = -2 * min_overlap * (height + width)
69 | c3 = (min_overlap - 1) * width * height
70 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
71 | r3 = (b3 + sq3) / (2 * a3)
72 |
73 | return min(r1, r2, r3)
74 |
75 | def gaussian2D(shape, sigma=1.):
76 | m, n = [(ss - 1.) / 2. for ss in shape]
77 | y, x = np.ogrid[-m:m+1,-n:n+1]
78 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
79 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
80 | return h
81 |
82 | def draw_gaussian(heatmap, center, radius):
83 | diameter = float(2 * (int(radius)+1) + 1)
84 | gaussian = gaussian2D((diameter, diameter), sigma = radius/3)#gaussian_2d_kernel(int(diameter),radius/3)#
85 | coord_w, coord_h = center
86 | height, width = heatmap.shape
87 |     temp=np.zeros((height,width), dtype=float)  # np.float is removed in recent NumPy
88 | temp = insert_image(temp, gaussian, coord_h, coord_w)
89 | np.maximum(heatmap, temp, out=heatmap)
90 | return temp
91 |
92 | def gaussian_2d_kernel(kernel_size=3, sigma=0):
93 | kernel = np.zeros([kernel_size, kernel_size])
94 | center = kernel_size // 2
95 | if sigma == 0:
96 | sigma = ((kernel_size - 1) * 0.5 - 1) * 0.3 + 0.8
97 | s = 2 * (sigma ** 2)
98 |
99 | for i in range(0, kernel_size):
100 | for j in range(0, kernel_size):
101 | x = i - center
102 | y = j - center
103 | kernel[i, j] = np.exp(-(x ** 2 + y ** 2) / s)
104 | return kernel
105 |
106 | def insert_image(img, kernel,h,w):
107 | ks=kernel.shape[0]
108 | if ks !=0:
109 | half_ks=ks//2
110 | img=np.pad(img,((half_ks,half_ks),(half_ks,half_ks)))
111 | img[h:h+ks,w:w+ks]=kernel
112 | return img[half_ks:-half_ks,half_ks:-half_ks]
113 | else:
114 | img[h:h+1,w:w+1]=kernel
115 | return img
116 |
117 | def label_relaxation(msk, border_window=3):
118 |     msk=msk.astype(float)
119 | border=border_window//2
120 | output=np.zeros_like(msk)
121 | for i in range(-border,border+1,1):
122 | for j in range(-border, border + 1, 1):
123 | output+=shift(msk,shift=[i,j],cval=0)
124 | output/=border_window**2
125 | return output
126 |
127 | def ensemble_img(ins_list,cate_list,score_list):
128 | N,w,h=ins_list.shape
129 | output_img=np.zeros((w,h),dtype=np.int16)
130 | #print(ins_list.shape,cate_list.shape,score_list.shape)
131 | for i in range(N):
132 | ins_num = i+1
133 | ins_=ins_list[i]
134 | score_=score_list[i]
135 | if np.sum(np.logical_and(output_img>0,ins_))==0:
136 | output_img=np.where(output_img>0,output_img,ins_*ins_num)
137 | else:
138 | compared_num,_ = Counter((output_img*ins_).flatten()).most_common(2)[1]
139 | assert compared_num>0
140 | #print( Counter((output_img*ins_).flatten()).most_common(2))
141 | compared_num=int(compared_num)
142 | compared_score=score_list[compared_num-1]
143 | if np.sum(np.logical_and(output_img==compared_num,ins_))/np.sum(np.logical_or(output_img==compared_num,ins_))>0.5:
144 | if compared_score>score_:
145 | pass
146 | else:
147 | output_img[output_img==compared_num]=0
148 | output_img=np.where(output_img>0,output_img,ins_*ins_num)
149 | else:
150 | output_img = np.where(output_img > 0, output_img, ins_ * ins_num)
151 | return output_img
--------------------------------------------------------------------------------
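A short sketch of combining the Gaussian-heatmap helpers to build a keypoint target for one synthetic instance (the center coordinates are cast to integers because `insert_image` uses them for slicing):

```python
import numpy as np
from utils.imop import get_ins_info, gaussian_radius, draw_gaussian

heatmap = np.zeros((64, 64), dtype=np.float32)
seg_mask = np.zeros((64, 64), dtype=np.uint8)   # one synthetic nucleus mask
seg_mask[20:32, 24:40] = 1

cw, ch, w, h = get_ins_info(seg_mask, method='bbox')
radius = max(1, int(gaussian_radius((h, w), min_overlap=0.7)))
draw_gaussian(heatmap, center=(int(cw), int(ch)), radius=radius)  # writes the peak in place
```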
/models/backbone/resnext.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 | __all__ = ['ResNext', 'resnext50_32x4d', 'resnext101_32x8d']
5 |
6 | model_urls = {
7 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
8 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
9 | }
10 |
11 |
12 | class Bottleneck(nn.Module):
13 | expansion = 4
14 |
15 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
16 | base_width=64, dilation=1, norm_layer=None, **kwargs):
17 | super(Bottleneck, self).__init__()
18 | width = int(planes * (base_width / 64.)) * groups
19 |
20 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False)
21 | self.bn1 = norm_layer(width)
22 | self.conv2 = nn.Conv2d(width, width, 3, stride, dilation, dilation, groups, bias=False)
23 | self.bn2 = norm_layer(width)
24 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False)
25 | self.bn3 = norm_layer(planes * self.expansion)
26 | self.relu = nn.ReLU(True)
27 | self.downsample = downsample
28 | self.stride = stride
29 |
30 | def forward(self, x):
31 | identity = x
32 |
33 | out = self.conv1(x)
34 | out = self.bn1(out)
35 | out = self.relu(out)
36 |
37 | out = self.conv2(out)
38 | out = self.bn2(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv3(out)
42 | out = self.bn3(out)
43 |
44 | if self.downsample is not None:
45 | identity = self.downsample(x)
46 |
47 | out += identity
48 | out = self.relu(out)
49 |
50 | return out
51 |
52 |
53 | class ResNext(nn.Module):
54 |
55 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1,
56 | width_per_group=64, dilated=False, norm_layer=nn.BatchNorm2d,deep_stem=False, **kwargs):
57 | super(ResNext, self).__init__()
58 | self.inplanes = 64
59 | self.groups = groups
60 | self.base_width = width_per_group
61 | if deep_stem:
62 | self.conv1 = nn.Sequential(
63 | nn.Conv2d(3, 64, 3, 2, 1, bias=False),
64 | norm_layer(64),
65 | nn.ReLU(True),
66 | nn.Conv2d(64, 64, 3, 1, 1, bias=False),
67 | norm_layer(64),
68 | nn.ReLU(True),
69 | nn.Conv2d(64, 128, 3, 1, 1, bias=False)
70 | )
71 | else:
72 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
73 | self.bn1 = norm_layer(self.inplanes)
74 | self.relu = nn.ReLU(True)
75 | self.maxpool = nn.MaxPool2d(3, 2, 1)
76 |
77 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
78 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
79 | if dilated:
80 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer)
81 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer)
82 | else:
83 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
84 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
85 |
86 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
87 | self.fc = nn.Linear(512 * block.expansion, num_classes)
88 |
89 | for m in self.modules():
90 | if isinstance(m, nn.Conv2d):
91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
92 | elif isinstance(m, nn.BatchNorm2d):
93 | nn.init.constant_(m.weight, 1)
94 | nn.init.constant_(m.bias, 0)
95 |
96 | if zero_init_residual:
97 | for m in self.modules():
98 | if isinstance(m, Bottleneck):
99 | nn.init.constant_(m.bn3.weight, 0)
100 |
101 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):
102 | downsample = None
103 | if stride != 1 or self.inplanes != planes * block.expansion:
104 | downsample = nn.Sequential(
105 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
106 | norm_layer(planes * block.expansion)
107 | )
108 |
109 | layers = list()
110 | if dilation in (1, 2):
111 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
112 | self.base_width, norm_layer=norm_layer))
113 | elif dilation == 4:
114 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
115 | self.base_width, dilation=2, norm_layer=norm_layer))
116 | else:
117 | raise RuntimeError("=> unknown dilation size: {}".format(dilation))
118 | self.inplanes = planes * block.expansion
119 | for _ in range(1, blocks):
120 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width,
121 | dilation=dilation, norm_layer=norm_layer))
122 |
123 | return nn.Sequential(*layers)
124 |
125 | def forward(self, x):
126 | x = self.conv1(x)
127 | x = self.bn1(x)
128 | x = self.relu(x)
129 | x = self.maxpool(x)
130 |
131 | x = self.layer1(x)
132 | x = self.layer2(x)
133 | x = self.layer3(x)
134 | x = self.layer4(x)
135 |
136 | x = self.avgpool(x)
137 | x = x.view(x.size(0), -1)
138 | x = self.fc(x)
139 |
140 | return x
141 |
142 |
143 | def resnext50_32x4d(pretrained=False, **kwargs):
144 | kwargs['groups'] = 32
145 | kwargs['width_per_group'] = 4
146 | model = ResNext(Bottleneck, [3, 4, 6, 3], **kwargs)
147 | if pretrained:
148 | state_dict = model_zoo.load_url(model_urls['resnext50_32x4d'])
149 | model.load_state_dict(state_dict)
150 | return model
151 |
152 |
153 | def resnext101_32x8d(pretrained=False, **kwargs):
154 | kwargs['groups'] = 32
155 | kwargs['width_per_group'] = 8
156 | model = ResNext(Bottleneck, [3, 4, 23, 3], **kwargs)
157 | if pretrained:
158 | state_dict = model_zoo.load_url(model_urls['resnext101_32x8d'])
159 | model.load_state_dict(state_dict)
160 | return model
161 |
162 |
163 | if __name__ == '__main__':
164 | model = resnext101_32x8d()
--------------------------------------------------------------------------------
/models/backbone/resnextdcn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 | from ..base_ops.DCNv2 import DeformConv
4 | __all__ = ['ResNext', 'resnext101_32x8d_dcn']
5 |
6 | model_urls = {
7 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
8 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
9 | }
10 |
11 | class Bottleneck(nn.Module):
12 | expansion = 4
13 |
14 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
15 | base_width=64, dilation=1, norm_layer=None, use_dcn=False, **kwargs):
16 | super(Bottleneck, self).__init__()
17 | width = int(planes * (base_width / 64.)) * groups
18 |
19 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False)
20 | self.bn1 = norm_layer(width)
21 | if use_dcn:
22 | self.conv2 = DeformConv(width, width, kernel_size=3, stride=stride,
23 | padding=dilation,dilation=dilation, groups=groups)
24 |
25 | # self.conv2.conv_offset.weight.data.zero_()
26 | # self.conv2.conv_offset.bias.data.zero_()
27 | else:
28 | self.conv2 = nn.Conv2d(width, width, 3, stride, dilation, dilation, groups, bias=False)
29 | self.bn2 = norm_layer(width)
30 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False)
31 | self.bn3 = norm_layer(planes * self.expansion)
32 | self.relu = nn.ReLU(True)
33 | self.downsample = downsample
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | identity = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv3(out)
48 | out = self.bn3(out)
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu(out)
55 |
56 | return out
57 |
58 |
59 | class ResNext(nn.Module):
60 |
61 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1,
62 | width_per_group=64, norm_layer=nn.BatchNorm2d,deep_stem=False, **kwargs):
63 | super(ResNext, self).__init__()
64 |
65 | self.inplanes = 64
66 | self.groups = groups
67 | self.base_width = width_per_group
68 | if deep_stem:
69 | self.conv1 = nn.Sequential(
70 | nn.Conv2d(3, 64, 3, 2, 1, bias=False),
71 | norm_layer(64),
72 | nn.ReLU(True),
73 | nn.Conv2d(64, 64, 3, 1, 1, bias=False),
74 | norm_layer(64),
75 | nn.ReLU(True),
76 | nn.Conv2d(64, 128, 3, 1, 1, bias=False)
77 | )
78 | else:
79 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
80 | self.bn1 = norm_layer(self.inplanes)
81 | self.relu = nn.ReLU(True)
82 | self.maxpool = nn.MaxPool2d(3, 2, 1)
83 |
84 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer,use_dcn=False)
85 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer,use_dcn=True)
86 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer,use_dcn=True)
87 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer,use_dcn=True)
88 |
89 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
90 | self.fc = nn.Linear(512 * block.expansion, num_classes)
91 |
92 | for m in self.modules():
93 | if isinstance(m, nn.Conv2d):
94 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
95 | elif isinstance(m, nn.BatchNorm2d):
96 | nn.init.constant_(m.weight, 1)
97 | nn.init.constant_(m.bias, 0)
98 |
99 | if zero_init_residual:
100 | for m in self.modules():
101 | if isinstance(m, Bottleneck):
102 | nn.init.constant_(m.bn3.weight, 0)
103 |
104 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d,use_dcn=False):
105 | downsample = None
106 | if stride != 1 or self.inplanes != planes * block.expansion:
107 | downsample = nn.Sequential(
108 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
109 | norm_layer(planes * block.expansion)
110 | )
111 |
112 | layers = list()
113 | if dilation in (1, 2):
114 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
115 | self.base_width, norm_layer=norm_layer,use_dcn=use_dcn))
116 | elif dilation == 4:
117 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
118 | self.base_width, dilation=2, norm_layer=norm_layer,use_dcn=use_dcn))
119 | else:
120 | raise RuntimeError("=> unknown dilation size: {}".format(dilation))
121 |
122 | self.inplanes = planes * block.expansion
123 | for _ in range(1, blocks):
124 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width,
125 | dilation=1, norm_layer=norm_layer,use_dcn=use_dcn))
126 |
127 | return nn.Sequential(*layers)
128 |
129 | def forward(self, x):
130 | x = self.conv1(x)
131 | x = self.bn1(x)
132 | x = self.relu(x)
133 | x = self.maxpool(x)
134 |
135 | x = self.layer1(x)
136 | x = self.layer2(x)
137 | x = self.layer3(x)
138 | x = self.layer4(x)
139 |
140 | x = self.avgpool(x)
141 | x = x.view(x.size(0), -1)
142 | x = self.fc(x)
143 |
144 | return x
145 |
146 | def resnext101_32x8d_dcn(pretrained=False, **kwargs):
147 | kwargs['groups'] = 32
148 | kwargs['width_per_group'] = 8
149 | model = ResNext(Bottleneck, [3, 4, 23, 3], **kwargs)
150 | if pretrained:
151 | old_dict = model_zoo.load_url(model_urls['resnext101_32x8d'])
152 |
153 | model_dict = model.state_dict()
154 | for i,(k,v) in enumerate(old_dict.items()):
155 | if k not in model_dict:
156 | print(k)
157 | pass
158 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
159 | model_dict.update(old_dict)
160 | model.load_state_dict(model_dict)
161 |
162 | return model
163 |
164 |
165 | if __name__ == '__main__':
166 | model = resnext101_32x8d_dcn()
--------------------------------------------------------------------------------
/models/backbone/res2next.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn import init
6 | import torch
7 | import torch.utils.model_zoo as model_zoo
8 |
9 | __all__ = ['res2next50']
10 | model_urls = {
11 | 'res2next50': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2next50_4s-6ef7e7bf.pth',
12 | }
13 |
14 | class Bottle2neckX(nn.Module):
15 | expansion = 4
16 |
17 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None, scale = 4, stype='normal'):
18 | """ Constructor
19 | Args:
20 | inplanes: input channel dimensionality
21 | planes: output channel dimensionality
22 | baseWidth: base width.
23 | cardinality: num of convolution groups.
24 | stride: conv stride. Replaces pooling layer.
25 |             scale: number of scales.
26 |             stype: 'normal': normal block. 'stage': first block of a new stage.
27 | """
28 | super(Bottle2neckX, self).__init__()
29 |
30 | D = int(math.floor(planes * (baseWidth/64.0)))
31 | C = cardinality
32 |
33 | self.conv1 = nn.Conv2d(inplanes, D*C*scale, kernel_size=1, stride=1, padding=0, bias=False)
34 | self.bn1 = nn.BatchNorm2d(D*C*scale)
35 |
36 | if scale == 1:
37 | self.nums = 1
38 | else:
39 | self.nums = scale -1
40 | if stype == 'stage':
41 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
42 | convs = []
43 | bns = []
44 | for i in range(self.nums):
45 | convs.append(nn.Conv2d(D*C, D*C, kernel_size=3, stride = stride, padding=1, groups=C, bias=False))
46 | bns.append(nn.BatchNorm2d(D*C))
47 | self.convs = nn.ModuleList(convs)
48 | self.bns = nn.ModuleList(bns)
49 |
50 | self.conv3 = nn.Conv2d(D*C*scale, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
51 | self.bn3 = nn.BatchNorm2d(planes * 4)
52 | self.relu = nn.ReLU(inplace=True)
53 |
54 | self.downsample = downsample
55 | self.width = D*C
56 | self.stype = stype
57 | self.scale = scale
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | spx = torch.split(out, self.width, 1)
67 | for i in range(self.nums):
68 | if i==0 or self.stype=='stage':
69 | sp = spx[i]
70 | else:
71 | sp = sp + spx[i]
72 | sp = self.convs[i](sp)
73 | sp = self.relu(self.bns[i](sp))
74 | if i==0:
75 | out = sp
76 | else:
77 | out = torch.cat((out, sp), 1)
78 | if self.scale != 1 and self.stype=='normal':
79 | out = torch.cat((out, spx[self.nums]),1)
80 | elif self.scale != 1 and self.stype=='stage':
81 | out = torch.cat((out, self.pool(spx[self.nums])),1)
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 |
95 | class Res2NeXt(nn.Module):
96 | def __init__(self, block, baseWidth, cardinality, layers, num_classes, scale=4):
97 | """ Constructor
98 | Args:
99 | baseWidth: baseWidth for ResNeXt.
100 | cardinality: number of convolution groups.
101 | layers: config of layers, e.g., [3, 4, 6, 3]
102 | num_classes: number of classes
103 | scale: scale in res2net
104 | """
105 | super(Res2NeXt, self).__init__()
106 |
107 | self.cardinality = cardinality
108 | self.baseWidth = baseWidth
109 | self.num_classes = num_classes
110 | self.inplanes = 64
111 | self.output_size = 64
112 | self.scale = scale
113 |
114 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
115 | self.bn1 = nn.BatchNorm2d(64)
116 | self.relu = nn.ReLU(inplace=True)
117 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
118 | self.layer1 = self._make_layer(block, 64, layers[0])
119 | self.layer2 = self._make_layer(block, 128, layers[1], 2)
120 | self.layer3 = self._make_layer(block, 256, layers[2], 2)
121 | self.layer4 = self._make_layer(block, 512, layers[3], 2)
122 | self.avgpool = nn.AdaptiveAvgPool2d(1)
123 | self.fc = nn.Linear(512 * block.expansion, num_classes)
124 |
125 | for m in self.modules():
126 | if isinstance(m, nn.Conv2d):
127 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
128 | m.weight.data.normal_(0, math.sqrt(2. / n))
129 | elif isinstance(m, nn.BatchNorm2d):
130 | m.weight.data.fill_(1)
131 | m.bias.data.zero_()
132 |
133 | def _make_layer(self, block, planes, blocks, stride=1):
134 | downsample = None
135 | if stride != 1 or self.inplanes != planes * block.expansion:
136 | downsample = nn.Sequential(
137 | nn.Conv2d(self.inplanes, planes * block.expansion,
138 | kernel_size=1, stride=stride, bias=False),
139 | nn.BatchNorm2d(planes * block.expansion),
140 | )
141 |
142 | layers = []
143 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample, scale=self.scale, stype='stage'))
144 | self.inplanes = planes * block.expansion
145 | for i in range(1, blocks):
146 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, scale=self.scale))
147 |
148 | return nn.Sequential(*layers)
149 |
150 | def forward(self, x):
151 | x = self.conv1(x)
152 | x = self.bn1(x)
153 | x = self.relu(x)
154 | x = self.maxpool1(x)
155 | x = self.layer1(x)
156 | x = self.layer2(x)
157 | x = self.layer3(x)
158 | x = self.layer4(x)
159 | x = self.avgpool(x)
160 | x = x.view(x.size(0), -1)
161 | x = self.fc(x)
162 |
163 | return x
164 | def res2next50(pretrained=False, **kwargs):
165 | """ Construct Res2NeXt-50.
166 | The default scale is 4.
167 | Args:
168 | pretrained (bool): If True, returns a model pre-trained on ImageNet
169 | """
170 | model = Res2NeXt(Bottle2neckX, layers = [3, 4, 6, 3], baseWidth = 4, cardinality=8, scale = 4, num_classes=1000)
171 | if pretrained:
172 | model.load_state_dict(model_zoo.load_url(model_urls['res2next50']))
173 | return model
174 |
175 | if __name__ == '__main__':
176 | images = torch.rand(1, 3, 224, 224).cuda(0)
177 | model = res2next50(pretrained=True)
178 | model = model.cuda(0)
179 | print(model(images).size())
180 |
--------------------------------------------------------------------------------
/utils/fmix.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import numpy as np
5 | from scipy.stats import beta
6 |
7 |
8 | def fftfreqnd(h, w=None, z=None):
9 | """ Get bin values for discrete fourier transform of size (h, w, z)
10 | :param h: Required, first dimension size
11 | :param w: Optional, second dimension size
12 | :param z: Optional, third dimension size
13 | """
14 | fz = fx = 0
15 | fy = np.fft.fftfreq(h)
16 |
17 | if w is not None:
18 | fy = np.expand_dims(fy, -1)
19 |
20 | if w % 2 == 1:
21 | fx = np.fft.fftfreq(w)[: w // 2 + 2]
22 | else:
23 | fx = np.fft.fftfreq(w)[: w // 2 + 1]
24 |
25 | if z is not None:
26 | fy = np.expand_dims(fy, -1)
27 | if z % 2 == 1:
28 | fz = np.fft.fftfreq(z)[:, None]
29 | else:
30 | fz = np.fft.fftfreq(z)[:, None]
31 |
32 | return np.sqrt(fx * fx + fy * fy + fz * fz)
33 |
34 |
35 | def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
36 | """ Samples a fourier image with given size and frequencies decayed by decay power
37 | :param freqs: Bin values for the discrete fourier transform
38 | :param decay_power: Decay power for frequency decay prop 1/f**d
39 | :param ch: Number of channels for the resulting mask
40 | :param h: Required, first dimension size
41 | :param w: Optional, second dimension size
42 | :param z: Optional, third dimension size
43 | """
44 | scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)])) ** decay_power)
45 |
46 | param_size = [ch] + list(freqs.shape) + [2]
47 | param = np.random.randn(*param_size)
48 |
49 | scale = np.expand_dims(scale, -1)[None, :]
50 |
51 | return scale * param
52 |
53 |
54 | def make_low_freq_image(decay, shape, ch=1):
55 | """ Sample a low frequency image from fourier space
56 |     :param decay: Decay power for frequency decay prop 1/f**d
57 | :param shape: Shape of desired mask, list up to 3 dims
58 | :param ch: Number of channels for desired mask
59 | """
60 | freqs = fftfreqnd(*shape)
61 | spectrum = get_spectrum(freqs, decay, ch, *shape)#.reshape((1, *shape[:-1], -1))
62 | spectrum = spectrum[:, 0] + 1j * spectrum[:, 1]
63 | mask = np.real(np.fft.irfftn(spectrum, shape))
64 |
65 | if len(shape) == 1:
66 | mask = mask[:1, :shape[0]]
67 | if len(shape) == 2:
68 | mask = mask[:1, :shape[0], :shape[1]]
69 | if len(shape) == 3:
70 | mask = mask[:1, :shape[0], :shape[1], :shape[2]]
71 |
72 | mask = mask
73 | mask = (mask - mask.min())
74 | mask = mask / mask.max()
75 | return mask
76 |
77 |
78 | def sample_lam(alpha, reformulate=False):
79 | """ Sample a lambda from symmetric beta distribution with given alpha
80 | :param alpha: Alpha value for beta distribution
81 | :param reformulate: If True, uses the reformulation of [1].
82 | """
83 | if reformulate:
84 | lam = beta.rvs(alpha+1, alpha)
85 | else:
86 | lam = beta.rvs(alpha, alpha)
87 |
88 | return lam
89 |
90 |
91 | def binarise_mask(mask, lam, in_shape, max_soft=0.0):
92 | """ Binarises a given low frequency image such that it has mean lambda.
93 | :param mask: Low frequency image, usually the result of `make_low_freq_image`
94 | :param lam: Mean value of final mask
95 | :param in_shape: Shape of inputs
96 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
97 | :return:
98 | """
99 | idx = mask.reshape(-1).argsort()[::-1]
100 | mask = mask.reshape(-1)
101 | num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(lam * mask.size)
102 |
103 | eff_soft = max_soft
104 | if max_soft > lam or max_soft > (1-lam):
105 | eff_soft = min(lam, 1-lam)
106 |
107 | soft = int(mask.size * eff_soft)
108 | num_low = num - soft
109 | num_high = num + soft
110 |
111 | mask[idx[:num_high]] = 1
112 | mask[idx[num_low:]] = 0
113 | mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))
114 |
115 | mask = mask.reshape((1, *in_shape))
116 | return mask
117 |
118 |
119 | def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
120 | """ Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises
121 | it based on this lambda
122 | :param alpha: Alpha value for beta distribution from which to sample mean of mask
123 | :param decay_power: Decay power for frequency decay prop 1/f**d
124 | :param shape: Shape of desired mask, list up to 3 dims
125 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
126 | :param reformulate: If True, uses the reformulation of [1].
127 | """
128 | if isinstance(shape, int):
129 | shape = (shape,)
130 |
131 | # Choose lambda
132 | lam = sample_lam(alpha, reformulate)
133 |
134 | # Make mask, get mean / std
135 | mask = make_low_freq_image(decay_power, shape)
136 | mask = binarise_mask(mask, lam, shape, max_soft)
137 |
138 | return lam, mask
139 |
140 |
141 | def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
142 | """
143 | :param x: Image batch on which to apply fmix of shape [b, c, shape*]
144 | :param alpha: Alpha value for beta distribution from which to sample mean of mask
145 | :param decay_power: Decay power for frequency decay prop 1/f**d
146 | :param shape: Shape of desired mask, list up to 3 dims
147 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
148 | :param reformulate: If True, uses the reformulation of [1].
149 | :return: mixed input, permutation indices, lambda value of mix,
150 | """
151 | lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
152 | index = np.random.permutation(x.shape[0])
153 |
154 | x1, x2 = x * mask, x[index] * (1-mask)
155 | return x1+x2, index, lam
156 |
157 |
158 | class FMixBase:
159 | r""" FMix augmentation
160 | Args:
161 | decay_power (float): Decay power for frequency decay prop 1/f**d
162 | alpha (float): Alpha value for beta distribution from which to sample mean of mask
163 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
164 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
165 | reformulate (bool): If True, uses the reformulation of [1].
166 | """
167 |
168 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
169 | super().__init__()
170 | self.decay_power = decay_power
171 | self.reformulate = reformulate
172 | self.size = size
173 | self.alpha = alpha
174 | self.max_soft = max_soft
175 | self.index = None
176 | self.lam = None
177 |
178 | def __call__(self, x):
179 | raise NotImplementedError
180 |
181 | def loss(self, *args, **kwargs):
182 | raise NotImplementedError
183 |
184 |
185 |
--------------------------------------------------------------------------------
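A minimal usage sketch for the functions above, assuming the snippet is run from the repository root; the batch shape and hyper-parameter values are illustrative, not defaults taken from this code base.

import numpy as np
from utils.fmix import sample_mask, sample_and_apply

# assumed batch: 8 RGB images, channel-first, 256x256
x = np.random.rand(8, 3, 256, 256)

# draw the mixing ratio lam ~ Beta(alpha, alpha) and a binary Fourier-space mask
lam, mask = sample_mask(alpha=1.0, decay_power=3.0, shape=(256, 256))

# mix the batch with a random permutation of itself in one call
mixed, index, lam = sample_and_apply(x, alpha=1.0, decay_power=3.0, shape=(256, 256))
# a mixed-target loss would then weight y and y[index] by lam and (1 - lam)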
/utils/_aug.py:
--------------------------------------------------------------------------------
1 | from imgaug import augmenters as iaa
2 | import cv2
3 | import numpy as np
4 | from scipy.ndimage import measurements
5 | """
6 | from Hover-Net augmentation
7 | https://github.com/vqdang/hover_net/blob/master/dataloader/augs.py
8 | """
9 |
10 | ####
11 | def fix_mirror_padding(ann):
12 | """Deal with duplicated instances due to mirroring in interpolation
13 | during shape augmentation (scale, rotation etc.).
14 |
15 | """
16 | current_max_id = np.amax(ann)
17 | inst_list = list(np.unique(ann))
18 | inst_list.remove(0) # 0 is background
19 | for inst_id in inst_list:
20 | inst_map = np.array(ann == inst_id, np.uint8)
21 | remapped_ids = measurements.label(inst_map)[0]
22 | remapped_ids[remapped_ids > 1] += current_max_id
23 | ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
24 | current_max_id = np.amax(ann)
25 | return ann
26 |
27 |
28 | ####
29 | def gaussian_blur(images, random_state, parents, hooks, max_ksize=3):
30 | """Apply Gaussian blur to input images."""
31 | img = images[0] # aleju input batch as default (always=1 in our case)
32 | ksize = random_state.randint(0, max_ksize, size=(2,))
33 | ksize = tuple((ksize * 2 + 1).tolist())
34 |
35 | ret = cv2.GaussianBlur(
36 | img, ksize, sigmaX=0, sigmaY=0, borderType=cv2.BORDER_REPLICATE
37 | )
38 | ret = np.reshape(ret, img.shape)
39 | ret = ret.astype(np.uint8)
40 | return [ret]
41 |
42 |
43 | ####
44 | def median_blur(images, random_state, parents, hooks, max_ksize=3):
45 | """Apply median blur to input images."""
46 | img = images[0] # aleju input batch as default (always=1 in our case)
47 | ksize = random_state.randint(0, max_ksize)
48 | ksize = ksize * 2 + 1
49 | ret = cv2.medianBlur(img, ksize)
50 | ret = ret.astype(np.uint8)
51 | return [ret]
52 |
53 |
54 | ####
55 | def add_to_hue(images, random_state, parents, hooks, range=(-8, 8)):
56 | """Perturbe the hue of input images."""
57 | img = images[0] # aleju input batch as default (always=1 in our case)
58 | hue = random_state.uniform(*range)
59 | hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
60 | if hsv.dtype.itemsize == 1:
61 | # OpenCV uses 0-179 for 8-bit images
62 | hsv[..., 0] = (hsv[..., 0] + hue) % 180
63 | else:
64 | # OpenCV uses 0-360 for floating point images
65 | hsv[..., 0] = (hsv[..., 0] + 2 * hue) % 360
66 | ret = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
67 | ret = ret.astype(np.uint8)
68 | return [ret]
69 |
70 |
71 | ####
72 | def add_to_saturation(images, random_state, parents, hooks, range=(-0.2, 0.2)):
73 | """Perturbe the saturation of input images."""
74 | img = images[0] # aleju input batch as default (always=1 in our case)
75 | value = 1 + random_state.uniform(*range)
76 | gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
77 | ret = img * value + (gray * (1 - value))[:, :, np.newaxis]
78 | ret = np.clip(ret, 0, 255)
79 | ret = ret.astype(np.uint8)
80 | return [ret]
81 |
82 |
83 | ####
84 | def add_to_contrast(images, random_state, parents, hooks, range=(0.75, 1.25)):
85 | """Perturbe the contrast of input images."""
86 | img = images[0] # aleju input batch as default (always=1 in our case)
87 | value = random_state.uniform(*range)
88 | mean = np.mean(img, axis=(0, 1), keepdims=True)
89 | ret = img * value + mean * (1 - value)
90 |     ret = np.clip(ret, 0, 255)
91 | ret = ret.astype(np.uint8)
92 | return [ret]
93 |
94 |
95 | ####
96 | def add_to_brightness(images, random_state, parents, hooks, range=(-26, 26)):
97 | """Perturbe the brightness of input images."""
98 | img = images[0] # aleju input batch as default (always=1 in our case)
99 | value = random_state.uniform(*range)
100 | ret = np.clip(img + value, 0, 255)
101 | ret = ret.astype(np.uint8)
102 | return [ret]
103 |
104 |
105 | def get_augmentation(mode, rng,input_shape=(256,256)):
106 | if mode == "train":
107 | print('Using train augmentation')
108 | shape_augs = [
109 | # * order = ``0`` -> ``cv2.INTER_NEAREST``
110 | # * order = ``1`` -> ``cv2.INTER_LINEAR``
111 | # * order = ``2`` -> ``cv2.INTER_CUBIC``
112 | # * order = ``3`` -> ``cv2.INTER_CUBIC``
113 | # * order = ``4`` -> ``cv2.INTER_CUBIC``
114 | # ! for pannuke v0, no rotation or translation, just flip to avoid mirror padding
115 | iaa.Affine(
116 | # scale images to 80-120% of their size, individually per axis
117 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
118 | # translate by -A to +A percent (per axis)
119 | translate_percent={"x": (-0.01, 0.01), "y": (-0.01, 0.01)},
120 | shear=(-5, 5), # shear by -5 to +5 degrees
121 | rotate=(-179, 179), # rotate by -179 to +179 degrees
122 | order=0, # default 0 use nearest neighbour
123 | backend="cv2", # opencv for fast processing
124 | seed=rng,
125 | ),
126 | # set position to 'center' for center crop
127 | # else 'uniform' for random crop
128 | iaa.CropToFixedSize(
129 | input_shape[0], input_shape[1], position="center"
130 | ),
131 | iaa.Fliplr(0.5, seed=rng),
132 | iaa.Flipud(0.5, seed=rng),
133 | ]
134 |
135 | input_augs = [
136 | iaa.OneOf(
137 | [
138 | iaa.Lambda(
139 | seed=rng,
140 | func_images= gaussian_blur,
141 | ),
142 | iaa.Lambda(
143 | seed=rng,
144 | func_images= median_blur,
145 | ),
146 | iaa.AdditiveGaussianNoise(
147 | loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5
148 | ),
149 | ]
150 | ),
151 | iaa.Sequential(
152 | [
153 | iaa.Lambda(
154 | seed=rng,
155 | func_images=add_to_hue,
156 | ),
157 | iaa.Lambda(
158 | seed=rng,
159 | func_images=add_to_saturation
160 | ,
161 | ),
162 | iaa.Lambda(
163 | seed=rng,
164 | func_images=add_to_brightness
165 | ,
166 | ),
167 | iaa.Lambda(
168 | seed=rng,
169 | func_images= add_to_contrast,
170 | ),
171 | ],
172 | random_order=True,
173 | ),
174 | ]
175 | elif mode == "test":
176 | print('Using test augmentation')
177 | shape_augs = [
178 | # set position to 'center' for center crop
179 | # else 'uniform' for random crop
180 | iaa.CropToFixedSize(
181 | input_shape[0], input_shape[1], position="center"
182 | )
183 | ]
184 | input_augs = []
185 |     else:
186 |         raise ValueError(f'get_augmentation: unknown mode {mode!r}')
187 |
188 |     return iaa.Sequential(shape_augs), iaa.Sequential(input_augs)
--------------------------------------------------------------------------------
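A usage sketch for get_augmentation; the deterministic copy of the shape augmenters is what keeps the image and its instance label map spatially aligned. The array sizes below are assumptions for illustration.

import numpy as np
from utils._aug import get_augmentation

shape_augs, input_augs = get_augmentation("train", rng=5, input_shape=(256, 256))

img = np.random.randint(0, 255, (300, 300, 3), dtype=np.uint8)  # assumed image tile
ann = np.random.randint(0, 10, (300, 300), dtype=np.int32)      # assumed instance map

# freeze the random state so image and annotation receive the same spatial transform
shape_det = shape_augs.to_deterministic()
img_aug = shape_det.augment_image(img)
ann_aug = shape_det.augment_image(ann)

# photometric augmentations are applied to the image only
img_aug = input_augs.augment_image(img_aug)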
/models/PointNuNet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .segbase import SegBaseModel
7 | from .necks.jpfm import JPFM
8 | from .base_ops.CoordConv import CoordConv2d
9 | from .necks.fpn import FPN
10 |
11 | class PointNuNet(SegBaseModel):
12 | def __init__(self, nclass, backbone='resnet50', pretrained_base=True, frozen_stages=-1,norm_eval=False, seg_feat_channels=256, stacked_convs=7,ins_out_channels=256, kernel_size=1,output_stride=4, **kwargs):
13 | super(PointNuNet, self).__init__(backbone, pretrained_base=pretrained_base, frozen_stages=frozen_stages,norm_eval=norm_eval, **kwargs)
14 |
15 | if 'res' in backbone:
16 | self.fpn=FPN()
17 | self.forward=self.forward_res
18 | elif 'hrnet' in backbone:
19 | if '32' in backbone:
20 | c=480
21 | elif '64' in backbone:
22 | c=960
23 | elif '18' in backbone:
24 | c=270
25 | else:
26 | raise NotImplementedError
27 | self.jpfm_1=JPFM(in_channel=c)
28 | self.jpfm_2=JPFM(in_channel=c)
29 | self.jpfm_3=JPFM(in_channel=c)
30 | self.forward=self.forward_hrnet
31 | elif 'swin' in backbone:
32 | self.fpn=FPN(channels=[192,384,768,1536])
33 | self.forward=self.forward_swin
34 | else:
35 | raise NotImplementedError
36 | self.output_stride=output_stride
37 | self.heads=_PointNuNetHead(num_classes=nclass,
38 | in_channels=1024,
39 | seg_feat_channels=seg_feat_channels,
40 | stacked_convs=stacked_convs,
41 | ins_out_channels=ins_out_channels,
42 | kernel_size=kernel_size)
43 |
44 | def forward_swin(self, x):
45 | c2, c3, c4, c5 = self.pretrained(x)
46 | c2, c3, c4, c5=self.fpn(c2, c3, c4, c5)
47 | x0_h, x0_w = c2.size(2), c2.size(3)
48 | c3 = F.interpolate(c3, size=(x0_h, x0_w), mode='bilinear', align_corners=True)
49 | c4 = F.interpolate(c4, size=(x0_h, x0_w), mode='bilinear', align_corners=True)
50 | c5 = F.interpolate(c5, size=(x0_h, x0_w), mode='bilinear', align_corners=True)
51 | cat_x = torch.cat([c2,c3,c4,c5], 1)
52 |
53 | output=self.heads(cat_x, cat_x, cat_x)
54 | return output
55 |
56 | def forward_res(self, x):
57 | c1, c2, c3, c4, c5 = self.base_forward(x)
58 | c2, c3, c4, c5=self.fpn(c2, c3, c4, c5)
59 | x0_h, x0_w = c2.size(2), c2.size(3)
60 | c3 = F.interpolate(c3, size=(x0_h, x0_w), mode='bilinear', align_corners=True)
61 | c4 = F.interpolate(c4, size=(x0_h, x0_w), mode='bilinear', align_corners=True)
62 | c5 = F.interpolate(c5, size=(x0_h, x0_w), mode='bilinear', align_corners=True)
63 | cat_x = torch.cat([c2,c3,c4,c5], 1)
64 |
65 | output=self.heads(cat_x, cat_x, cat_x)
66 | return output
67 |
68 | def forward_hrnet(self, x):
69 | c2,c3,c4,c5 = self.pretrained(x)
70 | x0_h, x0_w = c2.size(2), c2.size(3)
71 | c3 = F.interpolate(c3, size=(x0_h, x0_w), mode='bilinear', align_corners=True)
72 | c4 = F.interpolate(c4, size=(x0_h, x0_w), mode='bilinear', align_corners=True)
73 | c5 = F.interpolate(c5, size=(x0_h, x0_w), mode='bilinear', align_corners=True)
74 | cat_x = torch.cat([c2,c3,c4,c5], 1)
75 |
76 | f1=self.jpfm_1(cat_x)
77 | f2=self.jpfm_2(cat_x)
78 | f3=self.jpfm_3(cat_x)
79 | if self.output_stride!=4:
80 | f2=F.interpolate(f2, size=(256//self.output_stride, 256//self.output_stride), mode='bilinear', align_corners=True)
81 | f3=F.interpolate(f3, size=(256//self.output_stride, 256//self.output_stride), mode='bilinear', align_corners=True)
82 | output=self.heads(f1,f2,f3)
83 |
84 | return output
85 |
86 | class _PointNuNetHead(nn.Module):
87 | def __init__(self,num_classes,
88 | in_channels=256*4,
89 | seg_feat_channels=256,
90 | stacked_convs=7,
91 | ins_out_channels=256,
92 | kernel_size=1
93 | ):
94 |
95 | super(_PointNuNetHead,self).__init__()
96 | self.num_classes = num_classes
97 | self.cate_out_channels = self.num_classes - 1
98 | self.in_channels = in_channels
99 | self.stacked_convs = stacked_convs
100 | self.seg_feat_channels = seg_feat_channels
101 | self.seg_out_channels = ins_out_channels
102 | self.ins_out_channels = ins_out_channels
103 | self.kernel_out_channels = (self.ins_out_channels * kernel_size * kernel_size)
104 |
105 | self._init_layers()
106 | self.init_weight()
107 |
108 | def _init_layers(self):
109 | self.mask_convs = nn.ModuleList()
110 | self.kernel_convs = nn.ModuleList()
111 | self.cate_convs = nn.ModuleList()
112 |
113 | for i in range(self.stacked_convs):
114 | chn = self.in_channels if i == 0 else self.seg_feat_channels
115 | conv = CoordConv2d if i ==0 else nn.Conv2d
116 | self.kernel_convs.append(nn.Sequential(
117 | conv(chn, self.seg_feat_channels, 3, 1, 1, bias=False),
118 | nn.BatchNorm2d(self.seg_feat_channels),
119 | nn.ReLU(True),
120 | ))
121 |
122 | chn = self.in_channels if i == 0 else self.seg_feat_channels
123 | self.cate_convs.append(nn.Sequential(
124 | nn.Conv2d(chn, self.seg_feat_channels, 3, 1, 1, bias=False),
125 | nn.BatchNorm2d(self.seg_feat_channels),
126 | nn.ReLU(True),
127 | ))
128 |
129 | self.head_kernel = nn.Conv2d(self.seg_feat_channels, self.kernel_out_channels, 1,padding=0)
130 | self.head_cate = nn.Conv2d(self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
131 |
132 | self.mask_convs.append(nn.Sequential(
133 | nn.Conv2d(self.in_channels, self.seg_feat_channels, 3, 1, 1, bias=False),
134 | nn.BatchNorm2d(self.seg_feat_channels),
135 | nn.ReLU(True),
136 | nn.Conv2d(self.seg_feat_channels, self.seg_feat_channels, 3, 1, 1, bias=False),
137 | nn.BatchNorm2d(self.seg_feat_channels),
138 | nn.ReLU(True),))
139 |
140 | self.mask_convs.append(nn.Sequential(
141 | nn.ConvTranspose2d(self.seg_feat_channels, self.seg_feat_channels, 4, 2, padding=1, output_padding=0,bias=False),
142 | nn.BatchNorm2d(self.seg_feat_channels),
143 | nn.ReLU(True),
144 | nn.Conv2d(self.seg_feat_channels, self.seg_feat_channels, 3, 1, 1, bias=False),
145 | nn.BatchNorm2d(self.seg_feat_channels),
146 | nn.ReLU(True),))
147 |
148 | self.mask_convs.append(nn.Sequential(
149 | nn.ConvTranspose2d(self.seg_feat_channels, self.seg_feat_channels, 4, 2, padding=1, output_padding=0,bias=False),
150 | nn.BatchNorm2d(self.seg_feat_channels),
151 | nn.ReLU(True),
152 | nn.Conv2d(self.seg_feat_channels, self.seg_feat_channels, 3, 1, 1, bias=False),
153 | nn.BatchNorm2d(self.seg_feat_channels),
154 | nn.ReLU(True)))
155 |
156 | self.head_mask = nn.Sequential(
157 | nn.Conv2d(self.seg_feat_channels, self.seg_out_channels, 1, padding=0, bias=False),
158 | nn.BatchNorm2d(self.seg_out_channels),
159 | nn.ReLU(True))
160 |
161 | def init_weight(self):
162 | prior_prob = 0.01
163 | bias_init = float(-math.log((1 - prior_prob) / prior_prob))
164 | torch.nn.init.normal_(self.head_cate.weight, std=0.01)
165 | torch.nn.init.constant_(self.head_cate.bias, bias_init)
166 |
167 | def forward(self, feats,f2,f3):
168 | # feature branch
169 | mask_feat=feats
170 | for i, mask_layer in enumerate(self.mask_convs):
171 | mask_feat = mask_layer(mask_feat)
172 | feature_pred = self.head_mask(mask_feat)
173 |
174 | # kernel branch
175 | kernel_feat=f2
176 | for i, kernel_layer in enumerate(self.kernel_convs):
177 | kernel_feat = kernel_layer(kernel_feat)
178 | kernel_pred = self.head_kernel(kernel_feat)
179 |
180 | # cate branch
181 | cate_feat=f3
182 | for i, cate_layer in enumerate(self.cate_convs):
183 | cate_feat = cate_layer(cate_feat)
184 | cate_pred = self.head_cate(cate_feat)
185 | return feature_pred, kernel_pred, cate_pred
186 |
187 |
188 |
--------------------------------------------------------------------------------
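The head above returns three maps: a mask feature map, per-location dynamic kernels, and per-location class scores. The sketch below shows the SOLO-style dynamic convolution that such outputs are typically decoded with; the shapes and the decoding step are illustrative assumptions (the repository's own post-processing lives in its trainer/inference code), not a verbatim copy of this project's method.

import torch
import torch.nn.functional as F

# assumed sizes: E = ins_out_channels, k = kernel_size, C = nclass - 1
E, k, C, H, W = 256, 1, 5, 64, 64
feature_pred = torch.randn(1, E, H, W)         # output of head_mask
kernel_pred = torch.randn(1, E * k * k, H, W)  # output of head_kernel
cate_pred = torch.randn(1, C, H, W)            # output of head_cate

# every spatial location of kernel_pred is one dynamic conv kernel
kernels = kernel_pred[0].permute(1, 2, 0).reshape(H * W, E, k, k)
ins_logits = F.conv2d(feature_pred, kernels)   # (1, H*W, H, W): one mask per location
ins_masks = ins_logits.sigmoid() > 0.5

# class scores at the same locations decide which candidate masks are kept
scores = cate_pred.sigmoid()[0].permute(1, 2, 0).reshape(H * W, C)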
/models/backbone/res2netv1b.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 | import torch
6 | import torch.nn.functional as F
7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b']
8 |
9 |
10 | model_urls = {
11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
13 | }
14 |
15 |
16 | class Bottle2neck(nn.Module):
17 | expansion = 4
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'):
20 | """ Constructor
21 | Args:
22 | inplanes: input channel dimensionality
23 | planes: output channel dimensionality
24 | stride: conv stride. Replaces pooling layer.
25 | downsample: None when stride = 1
26 | baseWidth: basic width of conv3x3
27 | scale: number of scale.
28 | type: 'normal': normal set. 'stage': first block of a new stage.
29 | """
30 | super(Bottle2neck, self).__init__()
31 |
32 | width = int(math.floor(planes * (baseWidth/64.0)))
33 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False)
34 | self.bn1 = nn.BatchNorm2d(width*scale)
35 |
36 | if scale == 1:
37 | self.nums = 1
38 | else:
39 | self.nums = scale -1
40 | if stype == 'stage':
41 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
42 | convs = []
43 | bns = []
44 | for i in range(self.nums):
45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False))
46 | bns.append(nn.BatchNorm2d(width))
47 | self.convs = nn.ModuleList(convs)
48 | self.bns = nn.ModuleList(bns)
49 |
50 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False)
51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
52 |
53 | self.relu = nn.ReLU(inplace=True)
54 | self.downsample = downsample
55 | self.stype = stype
56 | self.scale = scale
57 | self.width = width
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | spx = torch.split(out, self.width, 1)
67 | for i in range(self.nums):
68 | if i==0 or self.stype=='stage':
69 | sp = spx[i]
70 | else:
71 | sp = sp + spx[i]
72 | sp = self.convs[i](sp)
73 | sp = self.relu(self.bns[i](sp))
74 | if i==0:
75 | out = sp
76 | else:
77 | out = torch.cat((out, sp), 1)
78 | if self.scale != 1 and self.stype=='normal':
79 | out = torch.cat((out, spx[self.nums]),1)
80 | elif self.scale != 1 and self.stype=='stage':
81 | out = torch.cat((out, self.pool(spx[self.nums])),1)
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 | class Res2Net(nn.Module):
95 |
96 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000):
97 | self.inplanes = 64
98 | super(Res2Net, self).__init__()
99 | self.baseWidth = baseWidth
100 | self.scale = scale
101 | self.conv1 = nn.Sequential(
102 | nn.Conv2d(3, 32, 3, 2, 1, bias=False),
103 | nn.BatchNorm2d(32),
104 | nn.ReLU(inplace=True),
105 | nn.Conv2d(32, 32, 3, 1, 1, bias=False),
106 | nn.BatchNorm2d(32),
107 | nn.ReLU(inplace=True),
108 | nn.Conv2d(32, 64, 3, 1, 1, bias=False)
109 | )
110 | self.bn1 = nn.BatchNorm2d(64)
111 | self.relu = nn.ReLU()
112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
113 | self.layer1 = self._make_layer(block, 64, layers[0])
114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
117 | self.avgpool = nn.AdaptiveAvgPool2d(1)
118 | self.fc = nn.Linear(512 * block.expansion, num_classes)
119 |
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
123 | elif isinstance(m, nn.BatchNorm2d):
124 | nn.init.constant_(m.weight, 1)
125 | nn.init.constant_(m.bias, 0)
126 |
127 | def _make_layer(self, block, planes, blocks, stride=1):
128 | downsample = None
129 | if stride != 1 or self.inplanes != planes * block.expansion:
130 | downsample = nn.Sequential(
131 | nn.AvgPool2d(kernel_size=stride, stride=stride,
132 | ceil_mode=True, count_include_pad=False),
133 | nn.Conv2d(self.inplanes, planes * block.expansion,
134 | kernel_size=1, stride=1, bias=False),
135 | nn.BatchNorm2d(planes * block.expansion),
136 | )
137 |
138 | layers = []
139 | layers.append(block(self.inplanes, planes, stride, downsample=downsample,
140 | stype='stage', baseWidth = self.baseWidth, scale=self.scale))
141 | self.inplanes = planes * block.expansion
142 | for i in range(1, blocks):
143 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale))
144 |
145 | return nn.Sequential(*layers)
146 |
147 | def forward(self, x):
148 | x = self.conv1(x)
149 | x = self.bn1(x)
150 | x = self.relu(x)
151 | x = self.maxpool(x)
152 |
153 | x = self.layer1(x)
154 | x = self.layer2(x)
155 | x = self.layer3(x)
156 | x = self.layer4(x)
157 |
158 | x = self.avgpool(x)
159 | x = x.view(x.size(0), -1)
160 | x = self.fc(x)
161 |
162 | return x
163 |
164 |
165 | def res2net50_v1b(pretrained=False, **kwargs):
166 | """Constructs a Res2Net-50_v1b model.
167 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s.
168 | Args:
169 | pretrained (bool): If True, returns a model pre-trained on ImageNet
170 | """
171 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
172 | if pretrained:
173 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s']))
174 | return model
175 |
176 | def res2net101_v1b(pretrained=False, **kwargs):
177 | """Constructs a Res2Net-50_v1b_26w_4s model.
178 | Args:
179 | pretrained (bool): If True, returns a model pre-trained on ImageNet
180 | """
181 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
182 | if pretrained:
183 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
184 | return model
185 |
186 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs):
187 | """Constructs a Res2Net-50_v1b_26w_4s model.
188 | Args:
189 | pretrained (bool): If True, returns a model pre-trained on ImageNet
190 | """
191 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
192 | if pretrained:
193 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s']))
194 | return model
195 |
196 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs):
197 | """Constructs a Res2Net-50_v1b_26w_4s model.
198 | Args:
199 | pretrained (bool): If True, returns a model pre-trained on ImageNet
200 | """
201 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
202 | if pretrained:
203 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
204 | return model
205 |
206 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs):
207 | """Constructs a Res2Net-50_v1b_26w_4s model.
208 | Args:
209 | pretrained (bool): If True, returns a model pre-trained on ImageNet
210 | """
211 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth = 26, scale = 4, **kwargs)
212 | if pretrained:
213 |         model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s']))  # note: this key is absent from model_urls above, so pretrained=True raises a KeyError
214 | return model
215 |
216 |
217 |
218 |
219 |
220 | if __name__ == '__main__':
221 | images = torch.rand(1, 3, 224, 224).cuda(0)
222 | model = res2net50_v1b_26w_4s(pretrained=True)
223 | model = model.cuda(0)
224 | print(model(images).size())
225 |
--------------------------------------------------------------------------------
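A quick CPU shape check for the constructors above; skipping the pretrained weights and running on CPU are assumptions made only to keep the example self-contained (the __main__ block in the file uses CUDA and pretrained weights).

import torch
from models.backbone.res2netv1b import res2net50_v1b

model = res2net50_v1b(pretrained=False, num_classes=10)
model.eval()
with torch.no_grad():
    out = model(torch.rand(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 10])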
/losses/lovasz_losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch
3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
4 | """
5 |
6 | from __future__ import print_function, division
7 |
8 | import torch
9 | from torch.autograd import Variable
10 | import torch.nn.functional as F
11 | import numpy as np
12 |
13 | try:
14 | from itertools import ifilterfalse
15 | except ImportError: # py3k
16 | from itertools import filterfalse as ifilterfalse
17 |
18 |
19 | def lovasz_grad(gt_sorted):
20 | """
21 | Computes gradient of the Lovasz extension w.r.t sorted errors
22 | See Alg. 1 in paper
23 | """
24 | p = len(gt_sorted)
25 | gts = gt_sorted.sum()
26 | intersection = gts - gt_sorted.float().cumsum(0)
27 | union = gts + (1 - gt_sorted).float().cumsum(0)
28 | jaccard = 1. - intersection / union
29 | if p > 1: # cover 1-pixel case
30 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
31 | return jaccard
32 |
33 |
34 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
35 | """
36 | IoU for foreground class
37 | binary: 1 foreground, 0 background
38 | """
39 | if not per_image:
40 | preds, labels = (preds,), (labels,)
41 | ious = []
42 | for pred, label in zip(preds, labels):
43 | intersection = ((label == 1) & (pred == 1)).sum()
44 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
45 | if not union:
46 | iou = EMPTY
47 | else:
48 | iou = float(intersection) / float(union)
49 | ious.append(iou)
50 |     iou = mean(ious)  # mean across images if per_image
51 | return 100 * iou
52 |
53 |
54 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
55 | """
56 | Array of IoU for each (non ignored) class
57 | """
58 | if not per_image:
59 | preds, labels = (preds,), (labels,)
60 | ious = []
61 | for pred, label in zip(preds, labels):
62 | iou = []
63 | for i in range(C):
64 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
65 | intersection = ((label == i) & (pred == i)).sum()
66 | union = ((label == i) | ((pred == i) & (label != ignore))).sum()
67 | if not union:
68 | iou.append(EMPTY)
69 | else:
70 | iou.append(float(intersection) / float(union))
71 | ious.append(iou)
72 |     ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image
73 | return 100 * np.array(ious)
74 |
75 |
76 | # --------------------------- BINARY LOSSES ---------------------------
77 |
78 |
79 | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
80 | """
81 | Binary Lovasz hinge loss
82 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
83 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
84 | per_image: compute the loss per image instead of per batch
85 | ignore: void class id
86 | """
87 | if per_image:
88 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
89 | for log, lab in zip(logits, labels))
90 | else:
91 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
92 | return loss
93 |
94 |
95 | def lovasz_hinge_flat(logits, labels):
96 | """
97 | Binary Lovasz hinge loss
98 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
99 | labels: [P] Tensor, binary ground truth labels (0 or 1)
100 | ignore: label to ignore
101 | """
102 | if len(labels) == 0:
103 | # only void pixels, the gradients should be 0
104 | return logits.sum() * 0.
105 | signs = 2. * labels.float() - 1.
106 | errors = (1. - logits * Variable(signs))
107 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
108 | perm = perm.data
109 | gt_sorted = labels[perm]
110 | grad = lovasz_grad(gt_sorted)
111 | loss = torch.dot(F.relu(errors_sorted), Variable(grad))
112 | return loss
113 |
114 |
115 | def flatten_binary_scores(scores, labels, ignore=None):
116 | """
117 | Flattens predictions in the batch (binary case)
118 | Remove labels equal to 'ignore'
119 | """
120 | scores = scores.view(-1)
121 | labels = labels.view(-1)
122 | if ignore is None:
123 | return scores, labels
124 | valid = (labels != ignore)
125 | vscores = scores[valid]
126 | vlabels = labels[valid]
127 | return vscores, vlabels
128 |
129 |
130 | class StableBCELoss(torch.nn.modules.Module):
131 | def __init__(self):
132 | super(StableBCELoss, self).__init__()
133 |
134 | def forward(self, input, target):
135 | neg_abs = - input.abs()
136 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
137 | return loss.mean()
138 |
139 |
140 | def binary_xloss(logits, labels, ignore=None):
141 | """
142 | Binary Cross entropy loss
143 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
144 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
145 | ignore: void class id
146 | """
147 | logits, labels = flatten_binary_scores(logits, labels, ignore)
148 | loss = StableBCELoss()(logits, Variable(labels.float()))
149 | return loss
150 |
151 |
152 | # --------------------------- MULTICLASS LOSSES ---------------------------
153 |
154 |
155 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
156 | """
157 | Multi-class Lovasz-Softmax loss
158 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
159 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
160 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
161 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
162 | per_image: compute the loss per image instead of per batch
163 | ignore: void class labels
164 | """
165 | if per_image:
166 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
167 | for prob, lab in zip(probas, labels))
168 | else:
169 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
170 | return loss
171 |
172 |
173 | def lovasz_softmax_flat(probas, labels, classes='present'):
174 | """
175 | Multi-class Lovasz-Softmax loss
176 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
177 | labels: [P] Tensor, ground truth labels (between 0 and C - 1)
178 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
179 | """
180 | if probas.numel() == 0:
181 | # only void pixels, the gradients should be 0
182 | return probas * 0.
183 | C = probas.size(1)
184 | losses = []
185 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
186 | for c in class_to_sum:
187 | fg = (labels == c).float() # foreground for class c
188 |         if (classes == 'present' and fg.sum() == 0):
189 | continue
190 | if C == 1:
191 | if len(classes) > 1:
192 | raise ValueError('Sigmoid output possible only with 1 class')
193 | class_pred = probas[:, 0]
194 | else:
195 | class_pred = probas[:, c]
196 | errors = (Variable(fg) - class_pred).abs()
197 | errors_sorted, perm = torch.sort(errors, 0, descending=True)
198 | perm = perm.data
199 | fg_sorted = fg[perm]
200 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
201 | return mean(losses)
202 |
203 |
204 | def flatten_probas(probas, labels, ignore=None):
205 | """
206 | Flattens predictions in the batch
207 | """
208 | if probas.dim() == 3:
209 | # assumes output of a sigmoid layer
210 | B, H, W = probas.size()
211 | probas = probas.view(B, 1, H, W)
212 | B, C, H, W = probas.size()
213 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
214 | labels = labels.view(-1)
215 | if ignore is None:
216 | return probas, labels
217 | valid = (labels != ignore)
218 | vprobas = probas[valid.nonzero().squeeze()]
219 | vlabels = labels[valid]
220 | return vprobas, vlabels
221 |
222 |
223 | def xloss(logits, labels, ignore=None):
224 | """
225 | Cross entropy loss
226 | """
227 | return F.cross_entropy(logits, Variable(labels), ignore_index=255)
228 |
229 |
230 | # --------------------------- HELPER FUNCTIONS ---------------------------
231 | def isnan(x):
232 | return x != x
233 |
234 |
235 | def mean(l, ignore_nan=False, empty=0):
236 | """
237 | nanmean compatible with generators.
238 | """
239 | l = iter(l)
240 | if ignore_nan:
241 | l = ifilterfalse(isnan, l)
242 | try:
243 | n = 1
244 | acc = next(l)
245 | except StopIteration:
246 | if empty == 'raise':
247 | raise ValueError('Empty mean')
248 | return empty
249 | for n, v in enumerate(l, 2):
250 | acc += v
251 | if n == 1:
252 | return acc
253 | return acc / n
--------------------------------------------------------------------------------
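A minimal sketch of calling the binary hinge variant defined above on random data; the tensor sizes are arbitrary.

import torch
from losses.lovasz_losses import lovasz_hinge

logits = torch.randn(2, 64, 64, requires_grad=True)   # raw per-pixel scores
labels = (torch.rand(2, 64, 64) > 0.5).long()         # binary ground-truth masks
loss = lovasz_hinge(logits, labels, per_image=True)
loss.backward()
print(float(loss))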
/models/backbone/res2net.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 | import torch
6 | import torch.nn.functional as F
7 | __all__ = ['Res2Net', 'res2net50']
8 |
9 |
10 | model_urls = {
11 | 'res2net50_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_4s-06e79181.pth',
12 | 'res2net50_48w_2s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_48w_2s-afed724a.pth',
13 | 'res2net50_14w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_14w_8s-6527dddc.pth',
14 | 'res2net50_26w_6s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_6s-19041792.pth',
15 | 'res2net50_26w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_8s-2c7c9f12.pth',
16 | 'res2net101_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_26w_4s-02a759a1.pth',
17 | }
18 |
19 | class Bottle2neck(nn.Module):
20 | expansion = 4
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'):
23 | """ Constructor
24 | Args:
25 | inplanes: input channel dimensionality
26 | planes: output channel dimensionality
27 | stride: conv stride. Replaces pooling layer.
28 | downsample: None when stride = 1
29 | baseWidth: basic width of conv3x3
30 | scale: number of scale.
31 | type: 'normal': normal set. 'stage': first block of a new stage.
32 | """
33 | super(Bottle2neck, self).__init__()
34 |
35 | width = int(math.floor(planes * (baseWidth/64.0)))
36 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False)
37 | self.bn1 = nn.BatchNorm2d(width*scale)
38 |
39 | if scale == 1:
40 | self.nums = 1
41 | else:
42 | self.nums = scale -1
43 | if stype == 'stage':
44 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
45 | convs = []
46 | bns = []
47 | for i in range(self.nums):
48 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False))
49 | bns.append(nn.BatchNorm2d(width))
50 | self.convs = nn.ModuleList(convs)
51 | self.bns = nn.ModuleList(bns)
52 |
53 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False)
54 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
55 |
56 | self.relu = nn.ReLU(inplace=True)
57 | self.downsample = downsample
58 | self.stype = stype
59 | self.scale = scale
60 | self.width = width
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | spx = torch.split(out, self.width, 1)
70 | for i in range(self.nums):
71 | if i==0 or self.stype=='stage':
72 | sp = spx[i]
73 | else:
74 | sp = sp + spx[i]
75 | sp = self.convs[i](sp)
76 | sp = self.relu(self.bns[i](sp))
77 | if i==0:
78 | out = sp
79 | else:
80 | out = torch.cat((out, sp), 1)
81 | if self.scale != 1 and self.stype=='normal':
82 | out = torch.cat((out, spx[self.nums]),1)
83 | elif self.scale != 1 and self.stype=='stage':
84 | out = torch.cat((out, self.pool(spx[self.nums])),1)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 | class Res2Net(nn.Module):
98 |
99 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000):
100 | self.inplanes = 64
101 | super(Res2Net, self).__init__()
102 | self.baseWidth = baseWidth
103 | self.scale = scale
104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
105 | bias=False)
106 | self.bn1 = nn.BatchNorm2d(64)
107 | self.relu = nn.ReLU(inplace=True)
108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
109 | self.layer1 = self._make_layer(block, 64, layers[0])
110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
113 | self.avgpool = nn.AdaptiveAvgPool2d(1)
114 | self.fc = nn.Linear(512 * block.expansion, num_classes)
115 |
116 | for m in self.modules():
117 | if isinstance(m, nn.Conv2d):
118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
119 | elif isinstance(m, nn.BatchNorm2d):
120 | nn.init.constant_(m.weight, 1)
121 | nn.init.constant_(m.bias, 0)
122 |
123 | def _make_layer(self, block, planes, blocks, stride=1):
124 | downsample = None
125 | if stride != 1 or self.inplanes != planes * block.expansion:
126 | downsample = nn.Sequential(
127 | nn.Conv2d(self.inplanes, planes * block.expansion,
128 | kernel_size=1, stride=stride, bias=False),
129 | nn.BatchNorm2d(planes * block.expansion),
130 | )
131 |
132 | layers = []
133 | layers.append(block(self.inplanes, planes, stride, downsample=downsample,
134 | stype='stage', baseWidth = self.baseWidth, scale=self.scale))
135 | self.inplanes = planes * block.expansion
136 | for i in range(1, blocks):
137 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def forward(self, x):
142 | x = self.conv1(x)
143 | x = self.bn1(x)
144 | x = self.relu(x)
145 | x = self.maxpool(x)
146 |
147 | x = self.layer1(x)
148 | x = self.layer2(x)
149 | x = self.layer3(x)
150 | x = self.layer4(x)
151 |
152 | x = self.avgpool(x)
153 | x = x.view(x.size(0), -1)
154 | x = self.fc(x)
155 |
156 | return x
157 |
158 |
159 | def res2net50(pretrained=False, **kwargs):
160 | """Constructs a Res2Net-50 model.
161 | Res2Net-50 refers to the Res2Net-50_26w_4s.
162 | Args:
163 | pretrained (bool): If True, returns a model pre-trained on ImageNet
164 | """
165 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
166 | if pretrained:
167 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))
168 | return model
169 |
170 | def res2net50_26w_4s(pretrained=False, **kwargs):
171 | """Constructs a Res2Net-50_26w_4s model.
172 | Args:
173 | pretrained (bool): If True, returns a model pre-trained on ImageNet
174 | """
175 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
176 | if pretrained:
177 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s']))
178 | return model
179 |
180 | def res2net101_26w_4s(pretrained=False, **kwargs):
181 | """Constructs a Res2Net-50_26w_4s model.
182 | Args:
183 | pretrained (bool): If True, returns a model pre-trained on ImageNet
184 | """
185 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
186 | if pretrained:
187 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_26w_4s']))
188 | return model
189 |
190 | def res2net50_26w_6s(pretrained=False, **kwargs):
191 | """Constructs a Res2Net-50_26w_4s model.
192 | Args:
193 | pretrained (bool): If True, returns a model pre-trained on ImageNet
194 | """
195 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 6, **kwargs)
196 | if pretrained:
197 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_6s']))
198 | return model
199 |
200 | def res2net50_26w_8s(pretrained=False, **kwargs):
201 | """Constructs a Res2Net-50_26w_4s model.
202 | Args:
203 | pretrained (bool): If True, returns a model pre-trained on ImageNet
204 | """
205 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 8, **kwargs)
206 | if pretrained:
207 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_8s']))
208 | return model
209 |
210 | def res2net50_48w_2s(pretrained=False, **kwargs):
211 | """Constructs a Res2Net-50_48w_2s model.
212 | Args:
213 | pretrained (bool): If True, returns a model pre-trained on ImageNet
214 | """
215 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 48, scale = 2, **kwargs)
216 | if pretrained:
217 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_48w_2s']))
218 | return model
219 |
220 | def res2net50_14w_8s(pretrained=False, **kwargs):
221 | """Constructs a Res2Net-50_14w_8s model.
222 | Args:
223 | pretrained (bool): If True, returns a model pre-trained on ImageNet
224 | """
225 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs)
226 | if pretrained:
227 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_14w_8s']))
228 | return model
229 |
230 |
231 |
232 | if __name__ == '__main__':
233 | images = torch.rand(1, 3, 224, 224).cuda(0)
234 | model = res2net101_26w_4s(pretrained=True)
235 | model = model.cuda(0)
236 | print(model(images).size())
237 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from trainer import Trainer
7 |
8 | try:
9 | from itertools import izip as zip
10 | except ImportError: # will be 3.x series
11 | pass
12 | import scipy.io as scio
13 |
14 | import torch
15 | import os
16 | import tifffile
17 | import numpy as np
18 | from PIL import Image
19 | from utils import get_config,_imageshow,_imagesave
20 | from skimage.util.shape import view_as_windows
21 | from utils.imop import get_ins_info
22 | from utils.metrics import get_fast_aji,get_fast_pq, get_dice_1,remap_label
23 | from torchvision import transforms
24 | import argparse
25 | import cv2
26 | from collections import Counter
27 | import time
28 | import scipy.io as scio
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--output_dir', type=str, default='outputs')
32 | parser.add_argument('--name', type=str, default='tmp')
33 | parser.add_argument('--epoch',type=int,default=100)
34 | parser.add_argument('--load_size',type=int,default=1024)
35 | parser.add_argument('--patch_size',type=int,default=256)
36 | parser.add_argument('--stride',type=int,default=128)
37 |
38 | opts = parser.parse_args()
39 |
40 | if __name__ == '__main__':
41 |
42 | opts.config=os.path.join(opts.output_dir,'{}/config.yaml'.format(opts.name))
43 |
44 | config=get_config(opts.config)
45 | trainer = Trainer(config)
46 | trainer.cuda()
47 |
48 | load_size = opts.load_size
49 | patch_size = opts.patch_size
50 | stride = opts.stride
51 |
52 | state_path = os.path.join(opts.output_dir,'{}/checkpoints/model_{}.pt'.format(opts.name, '%04d' % (opts.epoch)))
53 | state_dict = torch.load(state_path)
54 |
55 | trainer.model.load_state_dict(state_dict['seg'])
56 | trainer.model.eval()
57 | if not config['image_norm_mean']:
58 | _mean = (0.5, 0.5, 0.5)
59 | _std = (0.5, 0.5, 0.5)
60 | else:
61 | _mean = tuple(config['image_norm_mean'])
62 | _std = tuple(config['image_norm_std'])
63 | im_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(_mean, _std)])
64 |
65 | ajis = []
66 | dices = []
67 | pqs = []
68 | dqs = []
69 | sqs = []
70 | root = os.path.join(config['dataroot'], 'test')
71 | stain_norm_type=config['stainnorm']
72 |
73 | test_dir_fp=os.path.join(root, 'Images') if stain_norm_type is None else os.path.join(root, f'Images_{stain_norm_type}')
74 | print(test_dir_fp,stain_norm_type,_mean,_std)
75 | test_img_fp=os.listdir(test_dir_fp)
76 |
77 | test_meta = [os.path.splitext(p)[0] for p in test_img_fp]
78 | test_img_fp=[os.path.join(test_dir_fp,f) for f in test_img_fp]
79 | for i, (test_fp, test_file_name) in enumerate(zip(test_img_fp,test_meta)):
80 | if 'tif' in test_fp:
81 | with tifffile.TiffFile(test_fp) as f:
82 | test_img = f.asarray()
83 | else:
84 | test_img = np.array(Image.open(test_fp))
85 | original_img=test_img.copy()#tifffile.TiffFile(os.path.join(root,'Images',f'{test_file_name}.tif')).asarray()#
86 |
87 | test_gt_path = os.path.join(root, 'Labels', test_file_name + '.mat')
88 | test_GT = scio.loadmat(test_gt_path)['inst_map'].astype(np.int32)
89 |
90 | im_size = test_img.shape[0]
91 |
92 | assert patch_size % stride == 0 and load_size % patch_size == 0
93 | pad_size = (load_size - im_size + patch_size) // 2
94 |
95 | test_img = np.pad(test_img, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)), "reflect")
96 | crop_test_imgs = view_as_windows(test_img, (patch_size, patch_size, 3), (stride, stride, 3))[:, :, 0]
97 |
98 |
99 | pred_crop_test_imgs = []
100 | ins_num = 1
101 |
102 | output_seg = np.zeros((load_size + patch_size, load_size + patch_size), dtype=np.int32)
103 | score_list={}
104 | for i in range(crop_test_imgs.shape[0]):
105 | for j in range(crop_test_imgs.shape[1]):
106 | crop_test_img = crop_test_imgs[i, j]
107 | crop_test_img = im_transform(crop_test_img).unsqueeze(0).cuda()
108 | with torch.no_grad():
109 | output = trainer.prediction_single(crop_test_img)
110 | #output = trainer.prediction_fast(crop_test_img)
111 | if output is not None:
112 | seg_masks, cate_labels, cate_scores = output
113 | else:
114 | continue
115 | seg_masks = seg_masks.cpu().numpy()
116 | cate_labels = cate_labels.cpu().numpy()
117 | cate_scores = cate_scores.cpu().numpy()
118 | for ins_id in range(seg_masks.shape[0]):
119 | seg_ = seg_masks[ins_id]
120 | label_ = cate_labels[ins_id]
121 | score_ = cate_scores[ins_id]
122 | center_w, center_h, width, height = get_ins_info(seg_, method='bbox')
123 |
124 | center_h = np.ceil(center_h)
125 | center_w = np.ceil(center_w)
126 | offset_h = i * stride
127 | offset_w = j * stride
128 | if center_h >= patch_size // 2 - stride // 2 and center_h <= patch_size // 2 + stride // 2 and center_w >= patch_size // 2 - stride // 2 and center_w <= patch_size // 2 + stride // 2:
129 | focus_area = output_seg[offset_h:offset_h + patch_size, offset_w:offset_w + patch_size].copy()
130 | if np.sum(np.logical_and(focus_area > 0, seg_)) == 0:
131 | output_seg[offset_h:offset_h + patch_size, offset_w:offset_w + patch_size] = np.where(
132 | focus_area > 0, focus_area, seg_ * (ins_num))
133 | score_list[ins_num]=score_
134 | ins_num += 1
135 | else:
136 | compared_num, _ = Counter((focus_area * seg_).flatten()).most_common(2)[1]
137 | assert compared_num > 0
138 | compared_num = int(compared_num)
139 | compared_score = score_list[compared_num]
140 | if np.sum(np.logical_and(focus_area == compared_num, seg_)) / np.sum(
141 |                                 np.logical_or(focus_area == compared_num, seg_)) > 0.7:  # IoU > 0.7: same nucleus detected twice across overlapping windows
142 | if compared_score > score_:
143 | pass
144 | else:
145 | focus_area[focus_area==compared_num]=0
146 | output_seg[offset_h:offset_h + patch_size, offset_w:offset_w + patch_size]=focus_area
147 | output_seg[offset_h:offset_h + patch_size,
148 | offset_w:offset_w + patch_size] = np.where(
149 | focus_area > 0, focus_area, seg_ * (ins_num))
150 | score_list[ins_num] = score_
151 | ins_num += 1
152 | else:
153 | output_seg[offset_h:offset_h + patch_size, offset_w:offset_w + patch_size] = np.where(
154 | focus_area > 0, focus_area, seg_ * (ins_num))
155 | score_list[ins_num] = score_
156 | ins_num += 1
157 |
158 | output_seg=output_seg[pad_size:-pad_size, pad_size:-pad_size]
159 | for ui in np.unique(output_seg):
160 | if ui ==0:continue
161 | if np.sum(output_seg==ui)<16:
162 | output_seg[output_seg==ui]=0
163 | test_GT = remap_label(test_GT)
164 |
165 | output_seg = remap_label(output_seg)
166 | aji = get_fast_aji(test_GT.copy(), output_seg.copy())
167 | dice = get_dice_1(test_GT.copy(), output_seg.copy())
168 |
169 | [dq, sq, pq], [paired_true, paired_pred, unpaired_true, unpaired_pred] = get_fast_pq(test_GT.copy(), output_seg.copy())
170 | print(f'dice {round(float(dice), 3)} AJI {round(float(aji), 3)} dq {round(float(dq), 3)} sq {round(float(sq), 3)} pq {round(float(pq), 3)} gt up {len(unpaired_true)} pred up {len(unpaired_pred)}')
171 | title=f'DICE:{round(float(dice), 3)}, AJI:{round(float(aji), 3)},\n DQ:{round(float(dq), 3)}, SQ:{round(float(sq), 3)}, PQ:{round(float(pq), 3)}'
172 | #_imageshow(test_img[pad_size:-pad_size, pad_size:-pad_size],output_seg,test_GT,unpaired_pred,unpaired_true,title=title)
173 | #_imagesave(original_img,output_seg,None,f'consep_pred/{test_file_name}.png')
174 | #_imagesave(original_img, test_GT, None, f'consep_gt/{test_file_name}.png')
175 |
176 | #results={'pred':output_seg, 'gt':test_GT,'paired_true':paired_true,'paired_pred':paired_pred,'unpaired_true':unpaired_true,'unpaired_pred':unpaired_pred}
177 |
178 | #scio.savemat(f'consep_results/{test_file_name}.mat',results)
179 | #scio.savemat(f'consep_pred/{test_file_name}.mat',new_pred)
180 |
181 |
182 | dices.append(dice)
183 | ajis.append(aji)
184 | dqs.append(dq)
185 | sqs.append(sq)
186 | pqs.append(pq)
187 |
188 |
189 |
190 | print(f'dice {round(float(np.mean(dices)),3)} AJI {round(float(np.mean(ajis)),3)} dq{round(float(np.mean(dqs)),3)} sq{round(float(np.mean(sqs)),3)} pq{round(float(np.mean(pqs)),3)}')
191 | print(f'{round(float(np.mean(dices)),3)}\t{round(float(np.mean(ajis)),3)}\t{round(float(np.mean(dqs)),3)}\t{round(float(np.mean(sqs)),3)}\t{round(float(np.mean(pqs)),3)}')
192 |
--------------------------------------------------------------------------------
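The padding and sliding-window arithmetic in the script above is easiest to check with concrete numbers. The 1000x1000 tile size below is an assumed example; the other values are the script's CLI defaults.

# worked example of the padding / windowing done in inference.py
load_size, patch_size, stride = 1024, 256, 128   # CLI defaults
im_size = 1000                                   # assumed test-image size

assert patch_size % stride == 0 and load_size % patch_size == 0

pad_size = (load_size - im_size + patch_size) // 2       # (1024 - 1000 + 256) // 2 = 140
padded = im_size + 2 * pad_size                          # 1280
windows_per_axis = (padded - patch_size) // stride + 1   # (1280 - 256) // 128 + 1 = 9

# a window keeps a detection only if its centre lies in the window's central
# stride x stride region, so neighbouring windows rarely claim the same nucleus
print(pad_size, padded, windows_per_axis)                # 140 1280 9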
/models/backbone/resnetv1b.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | from ..base_ops.DCNv2 import DeformConv
5 |
6 | __all__ = ['ResNetV1b', 'resnet18_v1b', 'resnet34_v1b', 'resnet50_v1b',
7 | 'resnet101_v1b', 'resnet152_v1b', 'resnet152_v1s', 'resnet101_v1s', 'resnet50_v1s']
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | }
16 |
17 |
18 | class BasicBlockV1b(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
22 |                  previous_dilation=1, norm_layer=nn.BatchNorm2d, dcn=False):  # dcn accepted (and ignored) so _make_layer can pass it to both block types
23 | super(BasicBlockV1b, self).__init__()
24 | self.conv1 = nn.Conv2d(inplanes, planes, 3, stride,
25 | dilation, dilation, bias=False)
26 | self.bn1 = norm_layer(planes)
27 | self.relu = nn.ReLU(True)
28 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, previous_dilation,
29 | dilation=previous_dilation, bias=False)
30 | self.bn2 = norm_layer(planes)
31 | self.downsample = downsample
32 | self.stride = stride
33 |
34 | def forward(self, x):
35 | identity = x
36 |
37 | out = self.conv1(x)
38 | out = self.bn1(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv2(out)
42 | out = self.bn2(out)
43 |
44 | if self.downsample is not None:
45 | identity = self.downsample(x)
46 |
47 | out += identity
48 | out = self.relu(out)
49 |
50 | return out
51 |
52 |
53 | class BottleneckV1b(nn.Module):
54 | expansion = 4
55 |
56 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
57 | previous_dilation=1, norm_layer=nn.BatchNorm2d, dcn=False):
58 | super(BottleneckV1b, self).__init__()
59 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
60 | self.bn1 = norm_layer(planes)
61 | self.dcn=dcn
62 | if dcn:
63 | self.conv2 = DeformConv(planes, planes, 3, stride, dilation, dilation, bias=False)
64 | else:
65 | self.conv2 = nn.Conv2d(planes, planes, 3, stride, dilation, dilation, bias=False)
66 | self.bn2 = norm_layer(planes)
67 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
68 | self.bn3 = norm_layer(planes * self.expansion)
69 | self.relu = nn.ReLU(True)
70 | self.downsample = downsample
71 | self.stride = stride
72 |
73 |
74 | def forward(self, x):
75 | identity = x
76 |
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv2(out)
82 | out = self.bn2(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv3(out)
86 | out = self.bn3(out)
87 |
88 | if self.downsample is not None:
89 | identity = self.downsample(x)
90 |
91 | out += identity
92 | out = self.relu(out)
93 |
94 | return out
95 |
96 |
97 | class ResNetV1b(nn.Module):
98 |
99 | def __init__(self, block, layers, num_classes=1000, dilated=True, deep_stem=False,
100 | zero_init_residual=False, norm_layer=nn.BatchNorm2d):
101 | self.inplanes = 128 if deep_stem else 64
102 | super(ResNetV1b, self).__init__()
103 | if deep_stem:
104 | self.conv1 = nn.Sequential(
105 | nn.Conv2d(3, 64, 3, 2, 1, bias=False),
106 | norm_layer(64),
107 | nn.ReLU(True),
108 | nn.Conv2d(64, 64, 3, 1, 1, bias=False),
109 | norm_layer(64),
110 | nn.ReLU(True),
111 | nn.Conv2d(64, 128, 3, 1, 1, bias=False)
112 | )
113 | else:
114 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
115 | use_dcn=[False,False,False,False]
116 |         # set an entry above to True to enable deformable convolution (DCNv2) in that stage
117 | self.bn1 = norm_layer(self.inplanes)
118 | self.relu = nn.ReLU(True)
119 | self.maxpool = nn.MaxPool2d(3, 2, 1)
120 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer,dcn=use_dcn[0])
121 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer,dcn=use_dcn[1])
122 | if dilated:
123 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer,dcn=use_dcn[2])
124 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer,dcn=use_dcn[3])
125 | else:
126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer,dcn=use_dcn[2])
127 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer,dcn=use_dcn[3])
128 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
129 | self.fc = nn.Linear(512 * block.expansion, num_classes)
130 |
131 | for m in self.modules():
132 | if isinstance(m, nn.Conv2d):
133 |                 pass  # conv weights keep the framework default init; pretrained checkpoints overwrite them when loaded
134 |
135 | elif isinstance(m, nn.BatchNorm2d):
136 | nn.init.constant_(m.weight, 1)
137 | nn.init.constant_(m.bias, 0)
138 |
139 | if zero_init_residual:
140 | for m in self.modules():
141 | if isinstance(m, BottleneckV1b):
142 | nn.init.kaiming_normal_(m.conv1.weight, mode='fan_out', nonlinearity='relu')
143 | nn.init.kaiming_normal_(m.conv2.weight, mode='fan_out', nonlinearity='relu')
144 | nn.init.kaiming_normal_(m.conv3.weight, mode='fan_out', nonlinearity='relu')
145 | nn.init.constant_(m.bn3.weight, 0)
146 | elif isinstance(m, BasicBlockV1b):
147 | nn.init.constant_(m.bn2.weight, 0)
148 |
149 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d,dcn=False):
150 | downsample = None
151 | if stride != 1 or self.inplanes != planes * block.expansion:
152 | downsample = nn.Sequential(
153 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
154 | norm_layer(planes * block.expansion),
155 | )
156 |
157 | layers = []
158 | if dilation in (1, 2):
159 | layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
160 | previous_dilation=dilation, norm_layer=norm_layer,dcn=dcn))
161 | elif dilation == 4:
162 | layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
163 | previous_dilation=dilation, norm_layer=norm_layer,dcn=dcn))
164 | else:
165 | raise RuntimeError("=> unknown dilation size: {}".format(dilation))
166 | self.inplanes = planes * block.expansion
167 | for _ in range(1, blocks):
168 | layers.append(block(self.inplanes, planes, dilation=dilation,
169 | previous_dilation=dilation, norm_layer=norm_layer,dcn=dcn))
170 |
171 | return nn.Sequential(*layers)
172 |
173 | def forward(self, x):
174 | x = self.conv1(x)
175 | x = self.bn1(x)
176 | x = self.relu(x)
177 | x = self.maxpool(x)
178 |
179 | x = self.layer1(x)
180 | x = self.layer2(x)
181 | x = self.layer3(x)
182 | x = self.layer4(x)
183 |
184 | x = self.avgpool(x)
185 | x = x.view(x.size(0), -1)
186 | x = self.fc(x)
187 |
188 | return x
189 |
190 |
191 | def resnet18_v1b(pretrained=False, **kwargs):
192 | model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], **kwargs)
193 | if pretrained:
194 | old_dict = model_zoo.load_url(model_urls['resnet18'])
195 | model_dict = model.state_dict()
196 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
197 | model_dict.update(old_dict)
198 | model.load_state_dict(model_dict)
199 | return model
200 |
201 |
202 | def resnet34_v1b(pretrained=False, **kwargs):
203 | model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
204 | if pretrained:
205 | old_dict = model_zoo.load_url(model_urls['resnet34'])
206 | model_dict = model.state_dict()
207 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
208 | model_dict.update(old_dict)
209 | model.load_state_dict(model_dict)
210 | return model
211 |
212 |
213 | def resnet50_v1b(pretrained=False, **kwargs):
214 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], **kwargs)
215 | if pretrained:
216 | old_dict = model_zoo.load_url(model_urls['resnet50'])
217 | model_dict = model.state_dict()
218 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
219 | model_dict.update(old_dict)
220 | model.load_state_dict(model_dict)
221 | return model
222 |
223 |
224 | def resnet101_v1b(pretrained=False, **kwargs):
225 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], **kwargs)
226 | if pretrained:
227 | old_dict = model_zoo.load_url(model_urls['resnet101'])
228 | model_dict = model.state_dict()
229 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
230 | model_dict.update(old_dict)
231 | model.load_state_dict(model_dict)
232 | return model
233 |
234 |
235 | def resnet152_v1b(pretrained=False, **kwargs):
236 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], **kwargs)
237 | if pretrained:
238 | old_dict = model_zoo.load_url(model_urls['resnet152'])
239 | model_dict = model.state_dict()
240 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
241 | model_dict.update(old_dict)
242 | model.load_state_dict(model_dict)
243 | return model
244 |
245 |
246 | def resnet50_v1s(pretrained=False, root='~/.torch/models', **kwargs):
247 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, **kwargs)
248 | if pretrained:
249 | from ..model_store import get_resnet_file
250 | model.load_state_dict(torch.load(get_resnet_file('resnet50', root=root)), strict=False)
251 | return model
252 |
253 | def resnet101_v1s(pretrained=False, root='~/.torch/models', **kwargs):
254 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, **kwargs)
255 | if pretrained:
256 | from ..model_store import get_resnet_file
257 | model.load_state_dict(torch.load(get_resnet_file('resnet101', root=root)), strict=False)
258 | return model
259 |
260 |
261 | def resnet152_v1s(pretrained=False, root='~/.torch/models', **kwargs):
262 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, **kwargs)
263 | if pretrained:
264 | from ..model_store import get_resnet_file
265 | model.load_state_dict(torch.load(get_resnet_file('resnet152', root=root)), strict=False)
266 | return model
267 |
268 |
269 | if __name__ == '__main__':
270 | import torch
271 | img = torch.randn(4, 3, 224, 224)
272 | model = resnet50_v1s(True)
273 |     out = model(img)
274 |     print(out.shape)  # expect torch.Size([4, 1000]) for the default classification head
--------------------------------------------------------------------------------
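A minimal usage sketch for the backbones defined above (not part of the repository files; it assumes the repository root is on PYTHONPATH so that models.backbone.resnetv1b is importable, and uses pretrained=False so no weights need to be downloaded). With dilated=True the last two stages trade stride for dilation, so the feature map entering avgpool stays at 1/8 of the input resolution:

import torch
from models.backbone.resnetv1b import resnet50_v1b

model = resnet50_v1b(pretrained=False, dilated=True).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feat = model.maxpool(model.relu(model.bn1(model.conv1(x))))  # stride 4
    feat = model.layer2(model.layer1(feat))                      # stride 8
    feat = model.layer4(model.layer3(feat))                      # still stride 8 because dilated=True
    logits = model(x)                                            # full classification forward pass
print(feat.shape)    # torch.Size([1, 2048, 28, 28])
print(logits.shape)  # torch.Size([1, 1000])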
/losses/dice_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from scipy.ndimage import distance_transform_edt
7 |
8 | def make_one_hot(input, num_classes=None):
9 | """Convert class index tensor to one hot encoding tensor.
10 | Args:
11 | input: A tensor of shape [N, 1, *]
12 | num_classes: An int of number of class
13 |     Shapes:
14 |         input: A tensor of shape [N, 1, *] holding integer class indices
15 |         output: A tensor of shape [N, num_classes, *]
16 | Returns:
17 | A tensor of shape [N, num_classes, *]
18 | """
19 | if num_classes is None:
20 |         num_classes = int(input.max()) + 1  # infer the number of classes from the largest index
21 | shape = np.array(input.shape)
22 | shape[1] = num_classes
23 | shape = tuple(shape)
24 | result = torch.zeros(shape)
25 | result = result.scatter_(1, input.cpu().long(), 1)
26 | return result
27 | class logcosh_DICE(nn.Module):  # unfinished stub of a log-cosh Dice variant (no forward implemented)
28 | def __init__(self, ignore_index=None, reduction='mean',**kwargs):
29 | super(logcosh_DICE, self).__init__()
30 |
31 | class BinaryDiceLoss(nn.Module):
32 | """Dice loss of binary class
33 | Args:
34 | ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
35 | reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
36 | Shapes:
37 | output: A tensor of shape [N, *] without sigmoid activation function applied
38 | target: A tensor of shape same with output
39 | Returns:
40 | Loss tensor according to arg reduction
41 | Raise:
42 | Exception if unexpected reduction
43 | """
44 |
45 | def __init__(self, ignore_index=None, reduction='mean',**kwargs):
46 | super(BinaryDiceLoss, self).__init__()
47 |         self.smooth = 1  # use a larger value (e.g. 10 or 100) when the target regions are large
48 | self.ignore_index = ignore_index
49 | self.reduction = reduction
50 |         self.batch_dice = False  # when True, compute Dice over the whole batch as a single map
51 | if 'batch_loss' in kwargs.keys():
52 | self.batch_dice = kwargs['batch_loss']
53 |
54 | def forward(self, output, target, use_sigmoid=True):
55 | assert output.shape[0] == target.shape[0], f"output & target batch size don't match {output.shape[0]} {target.shape[0]} "
56 | if use_sigmoid:
57 | output = torch.sigmoid(output)
58 |
59 | if self.ignore_index is not None:
60 | validmask = (target != self.ignore_index).float()
61 |             output = output.mul(validmask)  # avoid in-place ops so backprop still works
62 | target = target.float().mul(validmask)
63 |
64 | dim0= output.shape[0]
65 | if self.batch_dice:
66 | dim0 = 1
67 |
68 | output = output.contiguous().view(dim0, -1)
69 | target = target.contiguous().view(dim0, -1).float()
70 |
71 | num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth
72 | den = torch.sum(output.abs() + target.abs(), dim=1) + self.smooth
73 |
74 | loss = 1 - (num / den)
75 |
76 | if self.reduction == 'mean':
77 | return loss.mean()
78 | elif self.reduction == 'sum':
79 | return loss.sum()
80 | elif self.reduction == 'none':
81 | return loss
82 | else:
83 | raise Exception('Unexpected reduction {}'.format(self.reduction))
84 |
85 |
86 | class DiceLoss(nn.Module):
87 | """Dice loss, need one hot encode input
88 | Args:
89 | weight: An array of shape [num_classes,]
90 | ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
91 | output: A tensor of shape [N, C, *]
92 | target: A tensor of same shape with output
93 | other args pass to BinaryDiceLoss
94 | Return:
95 | same as BinaryDiceLoss
96 | """
97 |
98 | def __init__(self, weight=None, ignore_index=None, **kwargs):
99 | super(DiceLoss, self).__init__()
100 | self.kwargs = kwargs
101 | self.weight = weight
102 | if isinstance(ignore_index, (int, float)):
103 | self.ignore_index = [int(ignore_index)]
104 | elif ignore_index is None:
105 | self.ignore_index = []
106 | elif isinstance(ignore_index, (list, tuple)):
107 | self.ignore_index = ignore_index
108 | else:
109 | raise TypeError("Expect 'int|float|list|tuple', while get '{}'".format(type(ignore_index)))
110 |
111 | def forward(self, output, target):
112 | assert output.shape == target.shape, 'output & target shape do not match'
113 | dice = BinaryDiceLoss(**self.kwargs)
114 | total_loss = 0
115 | output = F.softmax(output, dim=1)
116 | for i in range(target.shape[1]):
117 | if i not in self.ignore_index:
118 | dice_loss = dice(output[:, i], target[:, i], use_sigmoid=False)
119 | if self.weight is not None:
120 | assert self.weight.shape[0] == target.shape[1], \
121 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
122 |                     dice_loss *= self.weight[i]
123 | total_loss += (dice_loss)
124 | loss = total_loss / (target.size(1) - len(self.ignore_index))
125 | return loss
126 |
127 |
128 | class WBCEWithLogitLoss(nn.Module):
129 | """
130 | Weighted Binary Cross Entropy.
131 | `WBCE(p,t)=-β*t*log(p)-(1-t)*log(1-p)`
132 | To decrease the number of false negatives, set β>1.
133 | To decrease the number of false positives, set β<1.
134 | Args:
135 | @param weight: positive sample weight
136 | Shapes:
137 | output: A tensor of shape [N, 1,(d,), h, w] without sigmoid activation function applied
138 | target: A tensor of shape same with output
139 | """
140 |
141 | def __init__(self, weight=1.0, ignore_index=None, reduction='mean'):
142 | super(WBCEWithLogitLoss, self).__init__()
143 | assert reduction in ['none', 'mean', 'sum']
144 | self.ignore_index = ignore_index
145 | weight = float(weight)
146 | self.weight = weight
147 | self.reduction = reduction
148 | self.smooth = 0.01
149 |
150 | def forward(self, output, target):
151 | assert output.shape[0] == target.shape[0], "output & target batch size don't match"
152 |
153 | if self.ignore_index is not None:
154 | valid_mask = (target != self.ignore_index).float()
155 |             output = output.mul(valid_mask)  # avoid in-place ops so backprop still works
156 | target = target.float().mul(valid_mask)
157 |
158 | batch_size = output.size(0)
159 | output = output.view(batch_size, -1)
160 | target = target.view(batch_size, -1)
161 |
162 | output = torch.sigmoid(output)
163 | # avoid `nan` loss
164 | eps = 1e-6
165 | output = torch.clamp(output, min=eps, max=1.0 - eps)
166 | # soft label
167 | target = torch.clamp(target, min=self.smooth, max=1.0 - self.smooth)
168 |
169 | # loss = self.bce(output, target)
170 | loss = -self.weight * target.mul(torch.log(output)) - ((1.0 - target).mul(torch.log(1.0 - output)))
171 | if self.reduction == 'mean':
172 | loss = torch.mean(loss)
173 | elif self.reduction == 'sum':
174 | loss = torch.sum(loss)
175 | elif self.reduction == 'none':
176 | loss = loss
177 | else:
178 | raise NotImplementedError
179 | return loss
180 |
181 |
182 | class WBCE_DiceLoss(nn.Module):
183 | def __init__(self, alpha=1.0, weight=1.0, ignore_index=None, reduction='mean'):
184 | """
185 | combination of Weight Binary Cross Entropy and Binary Dice Loss
186 | Args:
187 | @param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
188 | @param reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
189 | @param alpha: weight between WBCE('Weight Binary Cross Entropy') and binary dice, apply on WBCE
190 | Shapes:
191 | output: A tensor of shape [N, *] without sigmoid activation function applied
192 | target: A tensor of shape same with output
193 | """
194 | super(WBCE_DiceLoss, self).__init__()
195 | assert reduction in ['none', 'mean', 'sum']
196 | assert 0 <= alpha <= 1, '`alpha` should in [0,1]'
197 | self.alpha = alpha
198 | self.ignore_index = ignore_index
199 | self.reduction = reduction
200 |         self.dice = BinaryDiceLoss(ignore_index=ignore_index, reduction=reduction, general=True)  # 'general' is swallowed by **kwargs and has no effect
201 | self.wbce = WBCEWithLogitLoss(weight=weight, ignore_index=ignore_index, reduction=reduction)
202 | self.dice_loss = None
203 | self.wbce_loss = None
204 |
205 | def forward(self, output, target):
206 | self.dice_loss = self.dice(output, target)
207 | self.wbce_loss = self.wbce(output, target)
208 | loss = self.alpha * self.wbce_loss + self.dice_loss
209 | return loss
210 |
211 |
212 |
213 | def compute_edts_forPenalizedLoss(GT):
214 | """
215 | GT.shape = (batch_size, x,y,z)
216 | only for binary segmentation
217 | """
218 | res = np.zeros(GT.shape,dtype=np.float32)
219 | for i in range(GT.shape[0]):
220 |
221 | posmask = GT[i]
222 | negmask = ~posmask
223 | pos_edt = distance_transform_edt(posmask)
224 | pos_edt = (np.max(pos_edt) - pos_edt) * posmask
225 | neg_edt = distance_transform_edt(negmask)
226 | neg_edt = (np.max(neg_edt) - neg_edt) * negmask
227 | res[i] = pos_edt / np.max(pos_edt) + neg_edt / np.max(neg_edt)
228 |
229 | return res
230 |
231 |
232 |
233 | class DistBinaryDiceLoss(nn.Module):
234 | """
235 | Distance map penalized Dice loss
236 | Motivated by: https://openreview.net/forum?id=B1eIcvS45V
237 | Distance Map Loss Penalty Term for Semantic Segmentation
238 | """
239 |
240 | def __init__(self, smooth=1e-5):
241 | super(DistBinaryDiceLoss, self).__init__()
242 | self.smooth = smooth
243 |
244 | def forward(self, net_output, gt):
245 | """
246 | net_output: (batch_size, 2, x,y,z)
247 | target: ground truth, shape: (batch_size, 1, x,y,z)
248 | """
249 | net_output = torch.sigmoid(net_output)
250 | # one hot code for gt
251 | with torch.no_grad():
252 | if len(net_output.shape) != len(gt.shape):
253 | gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))
254 |
255 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
256 | # if this is the case then gt is probably already a one hot encoding
257 | y_onehot = gt
258 | else:
259 | gt = gt.long()
260 | y_onehot = torch.zeros(net_output.shape)
261 | if net_output.device.type == "cuda":
262 | y_onehot = y_onehot.cuda(net_output.device.index)
263 | y_onehot.scatter_(1, gt, 1)
264 |
265 | gt_temp = gt[:, 0, ...].type(torch.float32)
266 | with torch.no_grad():
267 | dist = compute_edts_forPenalizedLoss(gt_temp.cpu().numpy() > 0.5) + 1.0
268 | # print('dist.shape: ', dist.shape)
269 | dist = torch.from_numpy(dist)
270 |
271 | if dist.device != net_output.device:
272 | dist = dist.to(net_output.device).type(torch.float32)
273 |
274 | tp = net_output * y_onehot
275 | tp = torch.sum(tp[:, 1, ...] * dist, (1, 2, 3))
276 |
277 | dc = (2 * tp + self.smooth) / (torch.sum(net_output[:, 1, ...], (1, 2, 3)) + torch.sum(y_onehot[:, 1, ...], (1, 2, 3)) + self.smooth)
278 |
279 | dc = dc.mean()
280 |
281 | return -dc
282 |
--------------------------------------------------------------------------------
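A minimal usage sketch for the losses above (illustrative only; shapes and values are placeholders, and the import path assumes the repository root is on PYTHONPATH). BinaryDiceLoss applies the sigmoid internally, so it expects raw logits; WBCE_DiceLoss combines it with the weighted BCE term, where weight > 1 penalises false negatives more heavily:

import torch
from losses.dice_loss import BinaryDiceLoss, WBCE_DiceLoss

logits = torch.randn(2, 256, 256, requires_grad=True)    # raw scores, no sigmoid applied
target = (torch.rand(2, 256, 256) > 0.5).float()          # binary ground-truth mask

dice = BinaryDiceLoss(reduction='mean')
loss = dice(logits, target)
loss.backward()

combo = WBCE_DiceLoss(alpha=1.0, weight=2.0)
print(float(loss), float(combo(logits.detach(), target)))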
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import os
3 | from torchvision import transforms
4 | from tifffile import TiffFile
5 | from PIL import Image
6 | import albumentations as A
7 | import torch
8 | import numpy as np
9 |
10 | from skimage.util.shape import view_as_windows
11 | import scipy.io as scio
12 | from .io import make_dataset
13 | from ._aug import get_augmentation
14 | from .imop import get_ins_info,gaussian_radius,draw_gaussian
15 |
16 | import matplotlib.pyplot as plt
17 |
18 | class NucleiDataset(data.Dataset):
19 | def __init__(self,config,seed, is_train,output_stride=4):
20 | stain_norm=config['stainnorm']
21 | data_root=config['dataroot']
22 | img_size=256
23 | self.grid_size=img_size//output_stride
24 |
25 | if not config['image_norm_mean']:
26 | _mean=(0.5,0.5,0.5)
27 | _std=(0.5,0.5,0.5)
28 | else:
29 | _mean = tuple(config['image_norm_mean'])
30 | _std = tuple(config['image_norm_std'])
31 | self.transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(_mean, _std)])
32 | print(f'Stain Norm Type: {stain_norm}, Trans mean std: {_mean}, {_std}.')
33 |
34 | self.phase='train' if is_train else 'test'
35 |
36 | self.img_dir_path = os.path.join(data_root,self.phase, 'Images') if stain_norm is None else os.path.join(data_root, self.phase, f'Images_{stain_norm}')
37 | self.img_paths = sorted(make_dataset(self.img_dir_path))
38 | self.gt_dir_path = os.path.join(data_root,self.phase, 'Labels')
39 | self.gt_paths = sorted(make_dataset(self.gt_dir_path))
40 | self.num_class=config['model']['num_classes']
41 | self.use_class= config['model']['num_classes']>2 and 'CoNSeP' in config['dataroot']
42 | print(f'use classes {self.use_class}')
43 |
44 | if 'cpm' in config['dataroot'].lower():
45 | self.images,self.masks,self.labels=self.load_cpm17(self.img_paths,self.gt_paths)
46 | else:
47 | self.images,self.masks,self.labels=self.load_crop_data(self.img_paths,self.gt_paths)
48 |
49 | self._size =self.images.shape[0]
50 | self.setup_augmentor(seed)
51 |
52 | def load_cpm17(self, img_paths, gt_paths, patch_size=384, stride=128):
53 | out_imgs = []
54 | out_masks = []
55 | out_labels = []
56 | for ip, mp in zip(img_paths, gt_paths):
57 | assert os.path.basename(ip)[:-4] == os.path.basename(mp)[:-4]
58 | if '.tif' in ip:
59 | with TiffFile(ip) as t:
60 | im=t.asarray()
61 | else:
62 | im = np.array(Image.open(ip).convert('RGB'))
63 | matfile=scio.loadmat(mp)
64 | mk=matfile['inst_map'].astype(np.int16)
65 |
66 | if self.use_class:
67 | lbl=matfile['inst_type'][:,0].astype(np.uint8)
68 | lbl[lbl == 4] = 3
69 | lbl[lbl == 5] = 4
70 | lbl[lbl == 6] = 4
71 | lbl[lbl == 7] = 4
72 | else:
73 | lbl=[1]*(np.max(mk))
74 | hs=list(range(0,im.shape[0]-patch_size,patch_size))+[im.shape[0]-patch_size]
75 | ws=list(range(0,im.shape[1]-patch_size,patch_size))+[im.shape[1]-patch_size]
76 | ims=[]
77 | mks=[]
78 | for h in hs:
79 | for w in ws:
80 | ims.append(im[h:h+patch_size,w:w+patch_size])
81 | mks.append(mk[h:h + patch_size, w:w + patch_size])
82 | ims=np.stack(ims,0)
83 | mks=np.stack(mks,0)
84 | out_imgs.append(ims)
85 | out_masks.append(mks)
86 | for idx in range(mks.shape[0]):
87 | tmk=mks[idx]
88 | olabel = {}
89 | for ui in np.unique(tmk):
90 | if ui==0:
91 | continue
92 | olabel[ui]=lbl[ui-1]
93 | out_labels.append(olabel)
94 | out_imgs = np.concatenate(out_imgs)
95 | out_masks = np.concatenate(out_masks)
96 | assert len(out_imgs.shape) == 4 and out_imgs.dtype == np.uint8
97 |
98 | print(f'processed data with size {len(out_imgs)} & {len(out_labels)}')
99 | return out_imgs, out_masks, out_labels
100 |
101 | def load_crop_data(self, img_paths, gt_paths, patch_size=384, stride=128):
102 | out_imgs = []
103 | out_masks = []
104 | out_labels = []
105 | resize = A.Resize(p=1, height=1024, width=1024)
106 | for ip, mp in zip(img_paths, gt_paths):
107 | assert os.path.basename(ip)[:-4] == os.path.basename(mp)[:-4]
108 | if '.tif' in ip:
109 | with TiffFile(ip) as t:
110 | im=t.asarray()
111 | else:
112 | im = np.array(Image.open(ip).convert('RGB'))
113 | matfile=scio.loadmat(mp)
114 | mk=matfile['inst_map'].astype(np.int16)
115 | if self.use_class:
116 | lbl=matfile['inst_type'][:,0].astype(np.uint8)
117 | lbl[lbl == 4] = 3
118 | lbl[lbl == 5] = 4
119 | lbl[lbl == 6] = 4
120 | lbl[lbl == 7] = 4
121 | else:
122 | lbl=[1]*(np.max(mk))
123 | augmented = resize(image=im, mask=mk)
124 | im = augmented['image']
125 | mk = augmented['mask']
126 | ims = view_as_windows(im, (patch_size, patch_size, 3), (stride, stride, 3)).reshape((-1, patch_size, patch_size, 3))
127 | mks = view_as_windows(mk, (patch_size, patch_size), (stride, stride)).reshape((-1, patch_size, patch_size))
128 | out_imgs.append(ims)
129 | out_masks.append(mks)
130 | for idx in range(mks.shape[0]):
131 | tmk=mks[idx]
132 | olabel = {}
133 | for ui in np.unique(tmk):
134 | if ui==0:
135 | continue
136 | olabel[ui]=lbl[ui-1]
137 | out_labels.append(olabel)
138 |
139 | out_imgs = np.concatenate(out_imgs)
140 | out_masks = np.concatenate(out_masks)
141 | assert len(out_imgs.shape) == 4 and out_imgs.dtype == np.uint8
142 |
143 | print(f'processed data with size {len(out_imgs)} & {len(out_labels)}')
144 | return out_imgs, out_masks, out_labels
145 |
146 | def setup_augmentor(self, seed):
147 | self.shape_augs, self.input_augs = get_augmentation(self.phase, seed)
148 |
149 | def __getitem__(self, index):
150 | img = self.images[index]
151 | mask = self.masks[index]
152 | label_dic = self.labels[index]
153 |
154 | shape_augs = self.shape_augs.to_deterministic()
155 | img = shape_augs.augment_image(img)
156 | masks = shape_augs.augment_image(mask)
157 |
158 | input_augs = self.input_augs.to_deterministic()
159 | img = input_augs.augment_image(img)
160 |
161 | cate_labels=[]
162 | ins_labels=[]
163 |
164 | for i, ui in enumerate(np.unique(masks)):
165 | if ui ==0:
166 | assert i==ui
167 | continue
168 | tmp_mask=masks==ui
169 | label=label_dic[ui]
170 | ins_labels.append(((tmp_mask)*1).astype(np.int32))
171 | cate_labels.append(label)
172 |
173 | if len(cate_labels)>0:
174 | cate_labels, ins_labels, ins_ind_labels= self.process_label(np.array(cate_labels), np.array(ins_labels))
175 | cate_labels=torch.from_numpy(np.array(cate_labels)).float()
176 | ins_labels=torch.from_numpy(ins_labels)
177 | ins_ind_labels=torch.from_numpy(ins_ind_labels).bool()
178 | else:
179 |             cate_labels = torch.from_numpy(np.zeros([self.num_class - 1, self.grid_size, self.grid_size])).float()
180 | ins_labels=None
181 | ins_ind_labels=None
182 | image=self.transform(img)
183 | output={'image': image, 'cate_labels':cate_labels, 'ins_labels':ins_labels,'ins_ind_labels':ins_ind_labels}
184 | return output
185 |
186 | def process_label(self, gt_labels_raw, gt_masks_raw, iou_threshold=0.3, tau=0.5):
187 | w,h=256,256
188 |
189 |         cate_label = np.zeros([self.num_class - 1, self.grid_size, self.grid_size], dtype=np.float32)
190 |         ins_label = np.zeros([self.grid_size ** 2, w, h], dtype=np.int16)
191 |         ins_ind_label = np.zeros([self.grid_size ** 2], dtype=bool)
192 | if gt_masks_raw is not None:
193 | gt_labels = gt_labels_raw
194 | gt_masks = gt_masks_raw
195 | for seg_mask, gt_label in zip(gt_masks, gt_labels):
196 | center_w, center_h, width, height = get_ins_info(seg_mask, method='bbox')
197 | radius = max(gaussian_radius((width, height), iou_threshold), 0)
198 | coord_h = int((center_h / h) / (1. / self.grid_size))
199 | coord_w = int((center_w / w) / (1. / self.grid_size))
200 | temp = draw_gaussian(cate_label[gt_label - 1], (coord_w, coord_h), (radius / 4))
201 | non_zeros = (temp > tau).nonzero()
202 | label = non_zeros[0] * self.grid_size + non_zeros[1] # label = int(coord_h * grid_size + coord_w)#
203 | ins_label[label, :, :] = seg_mask
204 | ins_ind_label[label] = True
205 | ins_label=np.stack(ins_label[ins_ind_label],0)
206 | return cate_label,ins_label,ins_ind_label
207 |
208 | def __len__(self):
209 | return self._size
210 |
211 | class PannukeDataset(data.Dataset):
212 | """
213 | img_path: original image
214 | masks: one-hot masks
215 |     GT: tiff mask, one channel denotes one instance
216 | """
217 | def __init__(self, data_root, is_train,seed=888,fold=1,output_stride=4):
218 | self.grid_size=256//output_stride
219 |
220 | self.images,self.labels,self.masks= self.load_pannuke(data_root,fold)
221 | self.transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])
222 | self.num_class=6
223 | self.A_size =self.images.shape[0]
224 |
225 | self.mode='train' if is_train else "test"
226 | self.setup_augmentor(seed)
227 |
228 | def setup_augmentor(self, seed):
229 | self.shape_augs, self.input_augs = get_augmentation(self.mode, seed)
230 |
231 | def load_pannuke(self, data_root,fold=1):
232 | out_labels = []
233 | out_masks = []
234 | out_imgs=np.load(os.path.join(data_root,f'images/fold{fold}/images.npy')).astype(np.uint8)#(2523, 256, 256, 3)
235 | masks=np.load(os.path.join(data_root,f'masks/fold{fold}/masks.npy')).astype(np.int16)#(2523, 256, 256, 6)
236 | for i in range(masks.shape[0]):
237 | tmask=masks[i]
238 | olabel={}
239 | omask=np.zeros((256,256),dtype=np.int16)
240 |             for j in range(5):  # channels 0-4 are the nucleus types; channel 5 is background
241 | ids=np.unique(tmask[:,:,j])
242 | if len(ids) ==1:
243 | continue
244 | else:
245 | for id in ids:
246 | if id==0:continue
247 | omask[tmask[:,:,j]==id]=id
248 | olabel[id]=j+1
249 | out_masks.append(omask)
250 | out_labels.append(olabel)
251 | out_masks = np.stack(out_masks,0)
252 | assert len(out_imgs.shape) == 4 and out_imgs.dtype == np.uint8
253 | assert len(out_masks.shape) == 3 and out_masks.dtype == np.int16, f'{out_masks.shape}, {out_masks.dtype}'
254 | assert out_masks.shape[0]==out_imgs.shape[0] and out_imgs.shape[0]==len(out_labels)
255 | print(f'processed data with size {len(out_imgs)}')
256 | return out_imgs, out_labels, out_masks
257 |
258 | def __getitem__(self, index):
259 | img = self.images[index]
260 | mask = self.masks[index]
261 | label_dic = self.labels[index]
262 |
263 | shape_augs = self.shape_augs.to_deterministic()
264 | img = shape_augs.augment_image(img)
265 | masks = shape_augs.augment_image(mask)
266 |
267 | input_augs = self.input_augs.to_deterministic()
268 | img = input_augs.augment_image(img)
269 |
270 | cate_labels=[]
271 | ins_labels=[]
272 |
273 | for i, ui in enumerate(np.unique(masks)):
274 | if ui ==0:
275 | assert i==ui
276 | continue
277 | tmp_mask=masks==ui
278 | label=label_dic[ui]
279 | ins_labels.append(((tmp_mask)*1).astype(np.int32))
280 | cate_labels.append(label)
281 |
282 | if len(cate_labels)>0:
283 | cate_labels, ins_labels, ins_ind_labels= self.process_label(np.array(cate_labels), np.array(ins_labels))
284 | cate_labels=torch.from_numpy(np.array(cate_labels)).float()
285 | ins_labels=torch.from_numpy(ins_labels)
286 | ins_ind_labels=torch.from_numpy(ins_ind_labels).bool()
287 | else:
288 | cate_labels=torch.from_numpy(np.zeros([self.num_class - 1,self.grid_size,self.grid_size])).float()
289 | ins_labels=None
290 | ins_ind_labels=None
291 | image=self.transform(img)
292 | output={'image': image, 'cate_labels':cate_labels, 'ins_labels':ins_labels,'ins_ind_labels':ins_ind_labels}
293 | return output
294 |
295 | def process_label(self, gt_labels_raw, gt_masks_raw, iou_threshold=0.3, tau=0.5):
296 | w,h=256,256
297 |         cate_label = np.zeros([self.num_class - 1, self.grid_size, self.grid_size], dtype=np.float32)
298 |         ins_label = np.zeros([self.grid_size ** 2, w, h], dtype=np.int16)
299 |         ins_ind_label = np.zeros([self.grid_size ** 2], dtype=bool)
300 | if gt_masks_raw is not None:
301 | gt_labels = gt_labels_raw
302 | gt_masks = gt_masks_raw
303 | for seg_mask, gt_label in zip(gt_masks, gt_labels):
304 | center_w, center_h, width, height = get_ins_info(seg_mask, method='bbox')
305 | radius = max(gaussian_radius((width, height), iou_threshold), 0)
306 | coord_h = int((center_h / h) / (1. / self.grid_size))
307 | coord_w = int((center_w / w) / (1. / self.grid_size))
308 | temp = draw_gaussian(cate_label[gt_label - 1], (coord_w, coord_h), (radius / 4))
309 |                 non_zeros = (temp > tau).nonzero()
310 |                 label = non_zeros[0] * self.grid_size + non_zeros[1]  # gaussian-region assignment (overridden below)
311 |                 cate_label[gt_label - 1, coord_h, coord_w] = 1  # hard label at the centre cell
312 |                 label = int(coord_h * self.grid_size + coord_w)  # keep only the single centre cell for PanNuKe
313 | ins_label[label, :, :] = seg_mask
314 | ins_ind_label[label] = True
315 | ins_label=np.stack(ins_label[ins_ind_label],0)
316 | return cate_label, ins_label, ins_ind_label
317 |
318 | def __len__(self):
319 | return self.A_size
320 |
--------------------------------------------------------------------------------
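A minimal usage sketch for the datasets above (the data path, fold number, and collate function are assumptions: the PanNuKe fold layout follows the paths hard-coded in load_pannuke, and collate_func exported from utils/__init__.py is presumably the intended collate_fn, since samples without any nucleus carry ins_labels=None and would break the default collation):

import torch.utils.data as data
from utils.dataloader import PannukeDataset
from utils import collate_func

dataset = PannukeDataset(data_root='./datasets/pannuke', is_train=True, seed=888, fold=1)

sample = dataset[0]
print(sample['image'].shape)        # torch.Size([3, 256, 256])
print(sample['cate_labels'].shape)  # torch.Size([5, 64, 64]): one heat-map per nucleus type on the 64x64 grid

loader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2,
                         collate_fn=collate_func)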