├── __init__.py ├── models ├── __init__.py ├── backbone │ ├── __init__.py │ ├── resnext.py │ ├── resnextdcn.py │ ├── res2next.py │ ├── res2netv1b.py │ ├── res2net.py │ └── resnetv1b.py ├── base_ops │ ├── __init__.py │ ├── CoordConv.py │ └── DCNv2.py ├── necks │ ├── __init__.py │ ├── jpfm.py │ └── fpn.py ├── model_store.py ├── download.py ├── segbase.py └── PointNuNet.py ├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── vcs.xml ├── modules.xml └── PointNu-Net.iml ├── .gitmodules ├── utils ├── __init__.py ├── io.py ├── imsave.py ├── matrix_nms.py ├── imop.py ├── fmix.py ├── _aug.py └── dataloader.py ├── losses ├── __init__.py ├── bce_loss.py ├── focal_loss.py ├── generalise_focal_loss.py ├── lovasz_losses.py └── dice_loss.py ├── requirements.txt ├── eval_pannuke.py ├── LICENSE ├── configs ├── pannuke.yaml ├── pannuke_resnext101dcn.yaml ├── monuseg.yaml ├── cpm17_notype_large.yaml ├── consep_notype_large.yaml ├── consep_notype_middle.yaml ├── consep_notype_tiny.yaml ├── consep_type_large.yaml ├── consep_type_middle.yaml ├── consep_type_tiny.yaml ├── kumar_swin.yaml ├── kumar_resnet.yaml ├── kumar_notype_large.yaml └── kumar_resnext101dcn.yaml ├── train.py ├── train_pannuke.py ├── infer_pannuke.py ├── README.md └── inference.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/base_ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "PanNuKe-metrics"] 2 | path = PanNuKe-metrics 3 | url = https://github.com/TIA-Lab/PanNuke-metrics 4 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .io import prepare_sub_folder,get_config 2 | from .imsave import collate_func,_imageshow,_imagesave 3 | from .imop import gaussian_2d_kernel,insert_image,label_relaxation 4 | 5 | 6 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .focal_loss import BinaryFocalLoss,FocalLoss_Ori,BCELoss 2 | from .dice_loss import BinaryDiceLoss 3 | from .bce_loss import BCEWithLogitsLossWithOHEM 4 | from .lovasz_losses import lovasz_hinge -------------------------------------------------------------------------------- 
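Editor's note (not part of the repository): of the losses exported above, only BCEWithLogitsLossWithOHEM is defined later in this listing (losses/bce_loss.py); dice_loss.py, focal_loss.py and lovasz_losses.py are not shown here. A minimal, hedged usage sketch of the OHEM-weighted BCE loss follows; the tensor shapes and the ohem_ratio value are illustrative assumptions, not values taken from the repo.

# Sketch only: BCEWithLogitsLossWithOHEM keeps the hardest `ohem_ratio` fraction of
# per-element BCE terms via top-k selection (see losses/bce_loss.py further down).
import torch
from losses import BCEWithLogitsLossWithOHEM

criterion = BCEWithLogitsLossWithOHEM(ohem_ratio=0.7)    # assumed ratio: keep hardest 70% of pixels
logits = torch.randn(2, 1, 64, 64, requires_grad=True)   # dummy predicted mask logits
target = torch.randint(0, 2, (2, 1, 64, 64)).float()     # dummy binary ground-truth mask
loss = criterion(logits, target)                         # mean BCE over the retained hard pixels
loss.backward()

With ohem_ratio=1.0 the top-k mask covers every element, so the loss reduces to a plain averaged BCEWithLogitsLoss; smaller ratios focus the gradient on the hardest pixels.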
/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.1.0 2 | imgaug==0.4.0 3 | matplotlib==3.4.3 4 | mmcv==2.0.0 5 | mmcv_full==1.4.8 6 | numpy==1.22.4+mkl 7 | opencv_python_headless==4.5.4.60 8 | Pillow==9.5.0 9 | PyYAML==6.0 10 | requests==2.21.0 11 | scikit_image==0.18.3 12 | scipy==1.7.1 13 | skimage==0.0 14 | tifffile==2021.7.2 15 | timm==0.4.12 16 | torch==1.10.0+cu113 17 | torchsummaryX==1.3.0 18 | torchvision==0.11.0+cu113 19 | tqdm==4.62.3 20 | -------------------------------------------------------------------------------- /.idea/PointNu-Net.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /eval_pannuke.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--name', type=str, default='pannuke_exp') 6 | parser.add_argument('--save_path', type=str, default='pannuke_outputs') 7 | parser.add_argument('--train_fold', type=int, default=1) 8 | parser.add_argument('--test_fold', type=int, default=3) 9 | opts = parser.parse_args() 10 | 11 | pred_path = rf'outputs/{opts.name}/train_{opts.train_fold}_to_test_{opts.test_fold}' 12 | true_path = rf'datasets/PanNuKe/masks/fold{opts.test_fold}' 13 | save_path = opts.save_path 14 | cmd=rf"python PanNuke-metrics/run.py --true_path={true_path} --pred_path={pred_path} --save_path={save_path}" 15 | 16 | os.system(cmd) 17 | -------------------------------------------------------------------------------- /models/base_ops/CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CoordConv2d(nn.Conv2d): 7 | 8 | def __init__(self, in_chan, out_chan, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): 9 | super(CoordConv2d, self).__init__( 10 | in_chan + 2, out_chan, kernel_size, stride=stride, 11 | padding=padding, dilation=dilation, groups=groups, bias=bias) 12 | 13 | def forward(self, x): 14 | batchsize, H, W = x.size(0), x.size(2), x.size(3) 15 | h_range = torch.linspace(-1, 1, H, device=x.device, dtype=x.dtype) 16 | w_range = torch.linspace(-1, 1, W, device=x.device, dtype=x.dtype) 17 | h_chan, w_chan = torch.meshgrid(h_range, w_range) 18 | h_chan = h_chan.expand([batchsize, 1, -1, -1]) 19 | w_chan = w_chan.expand([batchsize, 1, -1, -1]) 20 | 21 | feat = torch.cat([h_chan, w_chan, x], dim=1) 22 | 23 | return F.conv2d(feat, self.weight, self.bias, 24 | self.stride, self.padding, self.dilation, self.groups) 25 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kaiseem 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/necks/jpfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class JPFM(nn.Module): 5 | def __init__(self,in_channel,width=256): 6 | super(JPFM, self).__init__() 7 | 8 | self.out_channel = width*4 9 | self.dilation1 = nn.Sequential( 10 | nn.Conv2d(in_channel, width, 3, padding=1, dilation=1, bias=False), 11 | nn.BatchNorm2d(width), 12 | nn.ReLU(True)) 13 | self.dilation2 = nn.Sequential( 14 | nn.Conv2d(in_channel, width, 3, padding=2, dilation=2, bias=False), 15 | nn.BatchNorm2d(width), 16 | nn.ReLU(True)) 17 | self.dilation3 = nn.Sequential( 18 | nn.Conv2d(in_channel, width, 3, padding=4, dilation=4, bias=False), 19 | nn.BatchNorm2d(width), 20 | nn.ReLU(True)) 21 | self.dilation4 = nn.Sequential( 22 | nn.Conv2d(in_channel, width, 3, padding=8, dilation=8, bias=False), 23 | nn.BatchNorm2d(width), 24 | nn.ReLU(True)) 25 | 26 | def forward(self,feat): 27 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], dim=1) 28 | return feat 29 | 30 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | IMG_EXTENSIONS = [ 4 | '.jpg', '.JPG', '.jpeg', '.JPEG', 5 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 6 | '.tif', '.TIF', '.tiff', '.TIFF','npy','mat' 7 | ] 8 | 9 | def prepare_sub_folder(output_directory): 10 | image_directory = os.path.join(output_directory, 'images') 11 | if not os.path.exists(image_directory): 12 | print("Creating directory: {}".format(image_directory)) 13 | os.makedirs(image_directory) 14 | checkpoint_directory = os.path.join(output_directory, 'checkpoints') 15 | if not os.path.exists(checkpoint_directory): 16 | print("Creating directory: {}".format(checkpoint_directory)) 17 | os.makedirs(checkpoint_directory) 18 | return checkpoint_directory, image_directory 19 | 20 | def get_config(config): 21 | with open(config,'r') as stream: 22 | return yaml.load(stream, Loader=yaml.FullLoader) 23 | 24 | def is_image_file(filename): 25 | return any(filename.endswith(extension) for 
extension in IMG_EXTENSIONS) 26 | 27 | def make_dataset(dir, max_dataset_size=float("inf")): 28 | images = [] 29 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 30 | for root, _, fnames in sorted(os.walk(dir)): 31 | for fname in fnames: 32 | if is_image_file(fname): 33 | path = os.path.join(root, fname) 34 | images.append(path) 35 | return images[:min(max_dataset_size, len(images))] 36 | -------------------------------------------------------------------------------- /configs/pannuke.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet64 16 | pretrain: true 17 | frozen_stages: -1 18 | norm_eval: false 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | num_classes: 6 23 | kernel_size: 1 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 8 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | dataroot: ./datasets/PanNuKe 42 | -------------------------------------------------------------------------------- /configs/pannuke_resnext101dcn.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: resnext101dcn 16 | pretrain: true 17 | frozen_stages: -1 18 | norm_eval: false 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | num_classes: 6 23 | kernel_size: 1 24 | output_stride: 4 25 | 26 | 27 | train: 28 | max_epoch: 100 # maximum number of training iterations 29 | batch_size: 8 # batch size 30 | num_workers: 4 # 31 | optim: adamw 32 | lr: 0.0001 # initial learning rate 5e-4 33 | weight_decay: 0.0001 # weight decay 1e-4 34 | beta1: 0.9 # Adam parameter 35 | beta2: 0.999 # Adam parameter 36 | lr_policy: multistep # learning rate scheduler 37 | gamma: 0.1 # how much to decay learning rate 38 | use_mixed: false 39 | lambda_ins: 1 # weight of image instance segmentation loss 40 | lambda_cate: 1 # weight of image category classification loss 41 | 42 | dataroot: ./datasets/PanNuKe -------------------------------------------------------------------------------- /configs/monuseg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet64 16 | pretrain: true 17 | frozen_stages: -1 18 | norm_eval: false 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | num_classes: 2 23 | kernel_size: 1 24 | output_stride: 4 25 | 26 | 27 | train: 28 | max_epoch: 100 # maximum number of training iterations 29 | batch_size: 8 # batch size 30 | num_workers: 4 # 31 | optim: adamw 32 | lr: 0.0005 # initial learning rate 5e-4 33 | weight_decay: 0.0001 # weight decay 1e-4 34 | beta1: 0.9 # Adam parameter 35 | beta2: 0.999 # Adam parameter 36 | lr_policy: multistep # learning rate scheduler 37 | gamma: 0.1 # how much to decay learning rate 38 | use_mixed: false 39 | lambda_ins: 1 # weight of image instance segmentation loss 40 | lambda_cate: 1 # weight of image category classification loss 41 | 42 | 43 | dataroot: ./datasets/monuseg 44 | stainnorm: normed -------------------------------------------------------------------------------- /configs/cpm17_notype_large.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet64 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | kernel_size: 1 23 | num_classes: 2 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 8 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | 42 | dataroot: ./datasets/cpm17 43 | stainnorm: normed 44 | image_norm_mean: 45 | - 0.7949688 46 | - 0.55828372 47 | - 0.70746591 48 | image_norm_std: 49 | - 0.1847547 50 | - 0.22577311 51 | - 0.17084152 -------------------------------------------------------------------------------- /configs/consep_notype_large.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet64 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | kernel_size: 1 23 | num_classes: 2 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 8 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | 42 | dataroot: ./datasets/CoNSeP 43 | stainnorm: null 44 | image_norm_mean: 45 | - 0.82899987 46 | - 0.70380247 47 | - 0.84849265 48 | image_norm_std: 49 | - 0.16051538 50 | - 0.19952672 51 | - 0.12816858 52 | -------------------------------------------------------------------------------- /configs/consep_notype_middle.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet32 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 4 22 | kernel_size: 1 23 | num_classes: 2 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 8 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | 42 | dataroot: ./datasets/CoNSeP 43 | stainnorm: null 44 | image_norm_mean: 45 | - 0.82899987 46 | - 0.70380247 47 | - 0.84849265 48 | image_norm_std: 49 | - 0.16051538 50 | - 0.19952672 51 | - 0.12816858 52 | -------------------------------------------------------------------------------- /configs/consep_notype_tiny.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet18 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 128 20 | ins_out_channels: 128 21 | stacked_convs: 4 22 | kernel_size: 1 23 | num_classes: 2 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 16 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | 42 | dataroot: ./datasets/CoNSeP 43 | stainnorm: null 44 | image_norm_mean: 45 | - 0.82899987 46 | - 0.70380247 47 | - 0.84849265 48 | image_norm_std: 49 | - 0.16051538 50 | - 0.19952672 51 | - 0.12816858 52 | -------------------------------------------------------------------------------- /configs/consep_type_large.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet64 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | kernel_size: 1 23 | num_classes: 5 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 8 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | 42 | dataroot: ./datasets/CoNSeP 43 | stainnorm: null 44 | image_norm_mean: 45 | - 0.82899987 46 | - 0.70380247 47 | - 0.84849265 48 | image_norm_std: 49 | - 0.16051538 50 | - 0.19952672 51 | - 0.12816858 52 | -------------------------------------------------------------------------------- /configs/consep_type_middle.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet32 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 4 22 | kernel_size: 1 23 | num_classes: 5 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 8 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | 42 | dataroot: ./datasets/CoNSeP 43 | stainnorm: null 44 | image_norm_mean: 45 | - 0.82899987 46 | - 0.70380247 47 | - 0.84849265 48 | image_norm_std: 49 | - 0.16051538 50 | - 0.19952672 51 | - 0.12816858 52 | -------------------------------------------------------------------------------- /configs/consep_type_tiny.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet18 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 128 20 | ins_out_channels: 128 21 | stacked_convs: 4 22 | kernel_size: 1 23 | num_classes: 5 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 16 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | 42 | dataroot: ./datasets/CoNSeP 43 | stainnorm: null 44 | image_norm_mean: 45 | - 0.82899987 46 | - 0.70380247 47 | - 0.84849265 48 | image_norm_std: 49 | - 0.16051538 50 | - 0.19952672 51 | - 0.12816858 52 | -------------------------------------------------------------------------------- /configs/kumar_swin.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: swin 16 | pretrain: true 17 | frozen_stages: -1 18 | norm_eval: false 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | kernel_size: 1 23 | num_classes: 2 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 8 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | dataroot: ./datasets/kumar 42 | 43 | stainnorm: normed 44 | 45 | image_norm_mean: 46 | - 0.71174131 47 | - 0.5287984 48 | - 0.63705888 49 | image_norm_std: 50 | - 0.17307392 51 | - 0.19106038 52 | - 0.14219015 53 | 54 | -------------------------------------------------------------------------------- /configs/kumar_resnet.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: resnet101 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | kernel_size: 1 23 | num_classes: 2 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 8 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | dataroot: ./datasets/kumar 42 | 43 | stainnorm: normed 44 | 45 | image_norm_mean: 46 | - 0.71174131 47 | - 0.5287984 48 | - 0.63705888 49 | image_norm_std: 50 | - 0.17307392 51 | - 0.19106038 52 | - 0.14219015 53 | 54 | -------------------------------------------------------------------------------- /configs/kumar_notype_large.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: hrnet64 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | kernel_size: 1 23 | num_classes: 2 24 | output_stride: 4 25 | 26 | train: 27 | max_epoch: 100 # maximum number of training iterations 28 | batch_size: 8 # batch size 29 | num_workers: 4 # 30 | optim: adamw 31 | lr: 0.0001 # initial learning rate 5e-4 32 | weight_decay: 0.0001 # weight decay 1e-4 33 | beta1: 0.9 # Adam parameter 34 | beta2: 0.999 # Adam parameter 35 | lr_policy: multistep # learning rate scheduler 36 | gamma: 0.1 # how much to decay learning rate 37 | use_mixed: false 38 | lambda_ins: 1 # weight of image instance segmentation loss 39 | lambda_cate: 1 # weight of image category classification loss 40 | 41 | dataroot: ./datasets/kumar 42 | 43 | stainnorm: normed 44 | 45 | image_norm_mean: 46 | - 0.71174131 47 | - 0.5287984 48 | - 0.63705888 49 | image_norm_std: 50 | - 0.17307392 51 | - 0.19106038 52 | - 0.14219015 53 | 54 | -------------------------------------------------------------------------------- /configs/kumar_resnext101dcn.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_epoch: 5 # How often do you want to save output images during training 6 | image_display_epoch: 5 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_epoch: 5 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | mask_rescoring: false 12 | 13 | # model options 14 | model: 15 | backbone: resnext101dcn 16 | pretrain: true 17 | frozen_stages: 1 18 | norm_eval: true 19 | seg_feat_channels: 256 20 | ins_out_channels: 256 21 | stacked_convs: 7 22 | kernel_size: 1 23 | num_classes: 2 24 | output_stride: 4 25 | 26 | 27 | train: 28 | max_epoch: 100 # maximum number of training iterations 29 | batch_size: 8 # batch size 30 | num_workers: 4 # 31 | optim: adamw 32 | lr: 0.0001 # initial learning rate 5e-4 33 | weight_decay: 0.0001 # weight decay 1e-4 34 | beta1: 0.9 # Adam parameter 35 | beta2: 0.999 # Adam parameter 36 | lr_policy: multistep # learning rate scheduler 37 | gamma: 0.1 # how much to decay learning rate 38 | use_mixed: false 39 | lambda_ins: 1 # weight of image instance segmentation loss 40 | lambda_cate: 1 # weight of image category classification loss 41 | 42 | dataroot: ./datasets/kumar 43 | 44 | stainnorm: normed 45 | 46 | image_norm_mean: 47 | - 0.71174131 48 | - 0.5287984 49 | - 0.63705888 50 | image_norm_std: 51 | - 0.17307392 52 | - 0.19106038 53 | - 0.14219015 54 | 55 | -------------------------------------------------------------------------------- /losses/bce_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | 8 | 9 | class BCEWithLogitsLossWithOHEM(nn.Module): 10 | 11 | def __init__(self, ohem_ratio=1.0, pos_weight=None, eps=1e-7): 12 | super(BCEWithLogitsLossWithOHEM, self).__init__() 13 | self.criterion = nn.BCEWithLogitsLoss(reduction='none', 14 | pos_weight=pos_weight) 15 | self.ohem_ratio = ohem_ratio 16 | self.eps = eps 17 | 18 | def forward(self, pred, target): 19 | loss = self.criterion(pred, target) 20 | mask = _ohem_mask(loss, self.ohem_ratio) 21 | loss = loss * mask 22 | return loss.sum() / (mask.sum() + self.eps) 23 | 24 | def set_ohem_ratio(self, ohem_ratio): 25 | self.ohem_ratio = ohem_ratio 26 | 27 | def _ohem_mask(loss, ohem_ratio): 28 | with torch.no_grad(): 29 | values, _ = torch.topk(loss.reshape(-1), 30 | int(loss.nelement() * ohem_ratio)) 31 | mask = loss >= values[-1] 32 | return mask.float() 33 | class CrossEntropyLossWithOHEM(nn.Module): 34 | 35 | def __init__(self, ohem_ratio=1.0, weight=None, ignore_index=-100, 36 | eps=1e-7): 37 | super(CrossEntropyLossWithOHEM, self).__init__() 38 | self.criterion = nn.CrossEntropyLoss(weight=weight, 39 | ignore_index=ignore_index, 40 | reduction='none') 41 | self.ohem_ratio = ohem_ratio 42 | self.eps = eps 43 | 44 | def forward(self, pred, target): 45 | loss = self.criterion(pred, target) 46 | mask = _ohem_mask(loss, self.ohem_ratio) 47 | loss = loss * mask 48 | return loss.sum() / (mask.sum() + self.eps) 49 | 50 | def set_ohem_ratio(self, ohem_ratio): 51 | self.ohem_ratio = ohem_ratio 52 | -------------------------------------------------------------------------------- /models/necks/fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class FPN(nn.Module): 5 | def 
__init__(self,channels=[256,512,1024,2048]): 6 | super(FPN, self).__init__() 7 | self.in_planes = 64 8 | # Top layer 9 | self.toplayer = nn.Conv2d(channels[3], 256, kernel_size=1, stride=1, padding=0) # Reduce channels 10 | 11 | # Smooth layers 12 | self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 13 | self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 14 | self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 15 | 16 | # Lateral layers 17 | self.latlayer1 = nn.Conv2d(channels[2], 256, kernel_size=1, stride=1, padding=0) 18 | self.latlayer2 = nn.Conv2d(channels[1], 256, kernel_size=1, stride=1, padding=0) 19 | self.latlayer3 = nn.Conv2d(channels[0], 256, kernel_size=1, stride=1, padding=0) 20 | 21 | def _upsample_add(self, x, y): 22 | '''Upsample and add two feature maps. 23 | Args: 24 | x: (Variable) top feature map to be upsampled. 25 | y: (Variable) lateral feature map. 26 | Returns: 27 | (Variable) added feature map. 28 | Note in PyTorch, when input size is odd, the upsampled feature map 29 | with `F.upsample(..., scale_factor=2, mode='nearest')` 30 | maybe not equal to the lateral feature map size. 31 | e.g. 32 | original input size: [N,_,15,15] -> 33 | conv2d feature map size: [N,_,8,8] -> 34 | upsampled feature map size: [N,_,16,16] 35 | So we choose bilinear upsample which supports arbitrary output sizes. 36 | ''' 37 | _, _, H, W = y.size() 38 | return F.interpolate(x, size=(H, W), mode='bilinear',align_corners=True) + y 39 | 40 | def forward(self, c2, c3, c4, c5): 41 | # Top-down 42 | p5 = self.toplayer(c5) 43 | p4 = self._upsample_add(p5, self.latlayer1(c4)) 44 | p3 = self._upsample_add(p5, self.latlayer2(c3)) 45 | p2 = self._upsample_add(p4, self.latlayer3(c2)) 46 | # Smooth 47 | p4 = self.smooth1(p4) 48 | p3 = self.smooth2(p3) 49 | p2 = self.smooth3(p2) 50 | return p2, p3, p4, p5 -------------------------------------------------------------------------------- /models/model_store.py: -------------------------------------------------------------------------------- 1 | """Model store which provides pretrained models.""" 2 | from __future__ import print_function 3 | 4 | import os 5 | import zipfile 6 | 7 | from .download import download, check_sha1 8 | 9 | __all__ = ['get_model_file', 'get_resnet_file'] 10 | 11 | _model_sha1 = {name: checksum for checksum, name in [ 12 | ('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'), 13 | ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'), 14 | ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'), 15 | ]} 16 | 17 | encoding_repo_url = 'https://hangzh.s3.amazonaws.com/' 18 | _url_format = '{repo_url}encoding/models/{file_name}.zip' 19 | 20 | 21 | def short_hash(name): 22 | if name not in _model_sha1: 23 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 24 | return _model_sha1[name][:8] 25 | 26 | 27 | def get_resnet_file(name, root='~/.torch/models'): 28 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name)) 29 | root = os.path.expanduser(root) 30 | 31 | file_path = os.path.join(root, file_name + '.pth') 32 | sha1_hash = _model_sha1[name] 33 | if os.path.exists(file_path): 34 | if check_sha1(file_path, sha1_hash): 35 | return file_path 36 | else: 37 | print('Mismatch in the content of model file {} detected.' + 38 | ' Downloading again.'.format(file_path)) 39 | else: 40 | print('Model file {} is not found. 
Downloading.'.format(file_path)) 41 | 42 | if not os.path.exists(root): 43 | os.makedirs(root) 44 | 45 | zip_file_path = os.path.join(root, file_name + '.zip') 46 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url) 47 | if repo_url[-1] != '/': 48 | repo_url = repo_url + '/' 49 | download(_url_format.format(repo_url=repo_url, file_name=file_name), 50 | path=zip_file_path, 51 | overwrite=True) 52 | with zipfile.ZipFile(zip_file_path) as zf: 53 | zf.extractall(root) 54 | os.remove(zip_file_path) 55 | 56 | if check_sha1(file_path, sha1_hash): 57 | return file_path 58 | else: 59 | raise ValueError('Downloaded file has different hash. Please try again.') 60 | 61 | 62 | def get_model_file(name, root='~/.torch/models'): 63 | root = os.path.expanduser(root) 64 | file_path = os.path.join(root, name + '.pth') 65 | if os.path.exists(file_path): 66 | return file_path 67 | else: 68 | raise ValueError('Model file is not found. Downloading or trainning.') -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os,shutil 2 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 3 | 4 | from utils.dataloader import NucleiDataset 5 | from trainer import Trainer 6 | from torch.utils.data import DataLoader 7 | import sys 8 | import torch.nn as nn 9 | from utils import prepare_sub_folder,get_config,collate_func 10 | import torch 11 | import numpy as np 12 | import math 13 | import argparse 14 | import random 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--config', type=str, default='configs/kumar_notype_large.yaml') 17 | parser.add_argument('--name', type=str, default='tmp') 18 | parser.add_argument('--output_dir', type=str, default='outputs') 19 | parser.add_argument('--seed', type=int, default=10) 20 | opts = parser.parse_args() 21 | 22 | def check_manual_seed(seed): 23 | """ If manual seed is not specified, choose a 24 | random one and communicate it to the user. 
25 | Args: 26 | seed: seed to check 27 | """ 28 | seed = seed or random.randint(1, 10000) 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | # ia.random.seed(seed) 34 | 35 | print("Using manual seed: {seed}".format(seed=seed)) 36 | return 37 | 38 | if __name__ == '__main__': 39 | config=get_config(opts.config) 40 | train_dataset=NucleiDataset(config,opts.seed,is_train=True) 41 | check_manual_seed(opts.seed) 42 | train_loader=DataLoader(dataset=train_dataset, batch_size=config['train']['batch_size'], shuffle=True, drop_last=True, num_workers=config['train']['num_workers'],collate_fn=collate_func,pin_memory=True) 43 | 44 | output_directory = os.path.join(opts.output_dir, opts.name) 45 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory) 46 | shutil.copy(opts.config,os.path.join(output_directory,'config.yaml')) 47 | 48 | trainer = Trainer(config) 49 | trainer.cuda() 50 | 51 | iteration=0 52 | iter_per_epoch=len(train_loader) 53 | 54 | for epoch in range(config['train']['max_epoch']): 55 | for train_data in train_loader: 56 | for k in train_data.keys(): 57 | if not isinstance(train_data[k], list): 58 | train_data[k] = train_data[k].cuda().detach() 59 | else: 60 | train_data[k] = [s.cuda().detach() if s is not None else s for s in train_data[k]] 61 | 62 | ins_loss, cate_loss, maskiou_loss = trainer.seg_updata_FMIX(train_data) 63 | 64 | sys.stdout.write( 65 | f'\r epoch:{epoch} step:{iteration}/{iter_per_epoch} ins_loss: {ins_loss} cate_loss: {cate_loss} maskiou_loss: {maskiou_loss}') 66 | iteration += 1 67 | if (epoch + 1) % 20 == 0: 68 | trainer.save(checkpoint_directory, epoch) 69 | trainer.scheduler.step() 70 | 71 | trainer.save(checkpoint_directory, epoch) 72 | 73 | -------------------------------------------------------------------------------- /train_pannuke.py: -------------------------------------------------------------------------------- 1 | from utils.dataloader import NucleiDataset,PannukeDataset 2 | from trainer import Trainer 3 | from torch.utils.data import DataLoader 4 | import sys 5 | from utils import prepare_sub_folder,get_config,collate_func 6 | import torch 7 | import numpy as np 8 | import os,shutil 9 | import argparse 10 | import random 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--config', type=str, default='configs/pannuke.yaml') 14 | parser.add_argument('--name', type=str, default='pannuke_experiment') 15 | parser.add_argument('--train_fold', type=int, default=2) 16 | parser.add_argument('--val_fold', type=int, default=1) 17 | parser.add_argument('--test_fold', type=int, default=3) 18 | parser.add_argument('--output_dir', type=str, default='outputs') 19 | parser.add_argument('--seed', type=int, default=10) 20 | opts = parser.parse_args() 21 | 22 | def check_manual_seed(seed): 23 | """ If manual seed is not specified, choose a 24 | random one and communicate it to the user. 
25 | Args: 26 | seed: seed to check 27 | """ 28 | seed = seed or random.randint(1, 10000) 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | # ia.random.seed(seed) 34 | 35 | print("Using manual seed: {seed}".format(seed=seed)) 36 | return 37 | 38 | if __name__ == '__main__': 39 | config=get_config(opts.config) 40 | check_manual_seed(opts.seed) 41 | train_dataset=PannukeDataset(data_root=config['dataroot'], seed=opts.seed, is_train=True, fold=opts.train_fold,output_stride=config['model']['output_stride']) 42 | train_loader=DataLoader(dataset=train_dataset, batch_size=config['train']['batch_size'], shuffle=True, drop_last=True, num_workers=config['train']['num_workers'],persistent_workers=True,collate_fn=collate_func,pin_memory=True) 43 | 44 | output_directory = os.path.join(opts.output_dir, opts.name, 'train_{}_to_test_{}'.format( opts.train_fold,opts.test_fold)) 45 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory) 46 | shutil.copy(opts.config,os.path.join(output_directory,'config.yaml')) 47 | 48 | trainer = Trainer(config) 49 | trainer.cuda() 50 | 51 | iteration=0 52 | iter_per_epoch=len(train_loader) 53 | for epoch in range(config['train']['max_epoch']): 54 | for train_data in train_loader: 55 | for k in train_data.keys(): 56 | if not isinstance(train_data[k],list): 57 | train_data[k]=train_data[k].cuda().detach() 58 | else: 59 | train_data[k] = [s.cuda().detach() if s is not None else s for s in train_data[k]] 60 | ins_loss, cate_loss,maskiou_loss=trainer.seg_updata(train_data) 61 | sys.stdout.write(f'\r epoch:{epoch} step:{iteration}/{iter_per_epoch} ins_loss: {ins_loss} cate_loss: {cate_loss} maskiou_loss: {maskiou_loss}') 62 | iteration+=1 63 | trainer.scheduler.step() 64 | 65 | if (epoch+1)%50==0: 66 | trainer.save(checkpoint_directory, epoch) 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /infer_pannuke.py: -------------------------------------------------------------------------------- 1 | from trainer import Trainer 2 | try: 3 | from itertools import izip as zip 4 | except ImportError: # will be 3.x series 5 | pass 6 | import scipy.io as scio 7 | from utils.dataloader import NucleiDataset,PannukeDataset 8 | import torch 9 | import os 10 | from torch.utils.data import DataLoader 11 | import argparse 12 | import numpy as np 13 | from utils import get_config,collate_func 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--output_dir', type=str, default='outputs') 16 | parser.add_argument('--name', type=str, default='pannuke227') 17 | parser.add_argument('--train_fold', type=int, default=2) 18 | parser.add_argument('--test_fold', type=int, default=3) 19 | parser.add_argument('--epoch',type=int,default=100) 20 | opts = parser.parse_args() 21 | 22 | def stack_prediction(seg_masks,cate_labels): 23 | out_seg=np.zeros((256,256,6)) 24 | idx_num=1 25 | for mask,label in zip(seg_masks,cate_labels): 26 | assert label!=5 27 | out_seg[:,:,label]=np.maximum(out_seg[:,:,label],mask*idx_num) 28 | idx_num+=1 29 | out_seg[:,:,5]=np.sum(out_seg[:,:,:5],axis=-1)==0 30 | return out_seg 31 | 32 | if __name__ == '__main__': 33 | opts.config=os.path.join(opts.output_dir,'{}'.format(opts.name),'train_{}_to_test_{}/config.yaml'.format( opts.train_fold,opts.test_fold)) 34 | config=get_config(opts.config) 35 | 36 | #train_dataset=NucleiDataset(data_root=config['dataroot'],is_train=True,stain_norm=stain_norm_type) 37 | 
test_dataset=PannukeDataset(data_root=config['dataroot'], is_train=False, fold=opts.test_fold,output_stride=config['model']['output_stride']) 38 | test_loader=DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=0,collate_fn=collate_func,pin_memory=True) 39 | config['model']['kernel_size']=1 40 | config['train']['use_mixed']=False 41 | trainer = Trainer(config) 42 | trainer.cuda() 43 | 44 | state_path = os.path.join(opts.output_dir,opts.name,'train_{}_to_test_{}'.format( opts.train_fold,opts.test_fold),'checkpoints/model_{}.pt'.format('%04d' % (opts.epoch))) 45 | #state_path = os.path.join(opts.output_dir,opts.name,'train_{}_to_test_{}'.format( opts.train_fold,opts.test_fold),'checkpoints/model_ema.pt') 46 | state_dict = torch.load(state_path) 47 | 48 | trainer.model.load_state_dict(state_dict['seg']) 49 | predictions=[] 50 | for test_data in test_loader: 51 | for k in test_data.keys(): 52 | if not isinstance(test_data[k], list): 53 | test_data[k] = test_data[k].cuda().detach() 54 | else: 55 | test_data[k] = [s.cuda().detach() if s is not None else s for s in test_data[k]] 56 | with torch.no_grad(): 57 | img=test_data['image'] 58 | output = trainer.prediction(img, score_thr=0.4, update_thr=0.2) 59 | if output is not None: 60 | seg_masks, cate_labels, cate_scores = output 61 | seg_masks = seg_masks.cpu().numpy() 62 | cate_labels = cate_labels.cpu().numpy() 63 | cate_scores = cate_scores.cpu().numpy() 64 | predictions.append(stack_prediction(seg_masks, cate_labels)) 65 | else: 66 | predictions.append(np.zeros((256,256,6))) 67 | 68 | predictions=np.stack(predictions,0).astype(np.int16) 69 | save_fp= os.path.join(opts.output_dir,opts.name,'train_{}_to_test_{}/masks.npy'.format( opts.train_fold,opts.test_fold)) 70 | 71 | np.save(save_fp,predictions) 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /models/download.py: -------------------------------------------------------------------------------- 1 | """Download files with progress bar.""" 2 | import os 3 | import hashlib 4 | import requests 5 | from tqdm import tqdm 6 | 7 | def check_sha1(filename, sha1_hash): 8 | """Check whether the sha1 hash of the file content matches the expected hash. 9 | Parameters 10 | ---------- 11 | filename : str 12 | Path to the file. 13 | sha1_hash : str 14 | Expected sha1 hash in hexadecimal digits. 15 | Returns 16 | ------- 17 | bool 18 | Whether the file content matches the expected hash. 19 | """ 20 | sha1 = hashlib.sha1() 21 | with open(filename, 'rb') as f: 22 | while True: 23 | data = f.read(1048576) 24 | if not data: 25 | break 26 | sha1.update(data) 27 | 28 | sha1_file = sha1.hexdigest() 29 | l = min(len(sha1_file), len(sha1_hash)) 30 | return sha1.hexdigest()[0:l] == sha1_hash[0:l] 31 | 32 | def download(url, path=None, overwrite=False, sha1_hash=None): 33 | """Download an given URL 34 | Parameters 35 | ---------- 36 | url : str 37 | URL to download 38 | path : str, optional 39 | Destination path to store downloaded file. By default stores to the 40 | current directory with same name as in url. 41 | overwrite : bool, optional 42 | Whether to overwrite destination file if already exists. 43 | sha1_hash : str, optional 44 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 45 | but doesn't match. 46 | Returns 47 | ------- 48 | str 49 | The file path of the downloaded file. 
50 | """ 51 | if path is None: 52 | fname = url.split('/')[-1] 53 | else: 54 | path = os.path.expanduser(path) 55 | if os.path.isdir(path): 56 | fname = os.path.join(path, url.split('/')[-1]) 57 | else: 58 | fname = path 59 | 60 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 61 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 62 | if not os.path.exists(dirname): 63 | os.makedirs(dirname) 64 | 65 | print('Downloading %s from %s...'%(fname, url)) 66 | r = requests.get(url, stream=True) 67 | if r.status_code != 200: 68 | raise RuntimeError("Failed downloading url %s"%url) 69 | total_length = r.headers.get('content-length') 70 | with open(fname, 'wb') as f: 71 | if total_length is None: # no content length header 72 | for chunk in r.iter_content(chunk_size=1024): 73 | if chunk: # filter out keep-alive new chunks 74 | f.write(chunk) 75 | else: 76 | total_length = int(total_length) 77 | for chunk in tqdm(r.iter_content(chunk_size=1024), 78 | total=int(total_length / 1024. + 0.5), 79 | unit='KB', unit_scale=False, dynamic_ncols=True): 80 | f.write(chunk) 81 | 82 | if sha1_hash and not check_sha1(fname, sha1_hash): 83 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 84 | 'The repo may be outdated or download may be incomplete. ' \ 85 | 'If the "repo_url" is overridden, consider switching to ' \ 86 | 'the default repo.'.format(fname)) 87 | 88 | return fname 89 | -------------------------------------------------------------------------------- /models/base_ops/DCNv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision.ops import DeformConv2d, deform_conv2d 4 | import math 5 | class DeformConv(DeformConv2d): 6 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,dilation=1,groups=1, bias=None): 7 | super(DeformConv,self).__init__(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size, stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias) 8 | channels_ = groups * 3 * self.kernel_size[0] * self.kernel_size[1] 9 | self.conv_offset_mask = nn.Conv2d(self.in_channels, 10 | channels_, 11 | kernel_size=self.kernel_size, 12 | stride=self.stride, 13 | padding=self.padding, 14 | bias=True) 15 | self.init_offset() 16 | 17 | def init_offset(self): 18 | self.conv_offset_mask.weight.data.zero_() 19 | self.conv_offset_mask.bias.data.zero_() 20 | 21 | def forward(self, input): 22 | out = self.conv_offset_mask(input) 23 | o1, o2, mask = torch.chunk(out, 3, dim=1) 24 | offset = torch.cat((o1, o2), dim=1) 25 | mask = torch.sigmoid(mask) 26 | return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride, 27 | padding=self.padding, dilation=self.dilation, mask=mask) 28 | 29 | class DeformConv123(DeformConv2d): 30 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=None): 31 | super(DeformConv, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 32 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 33 | channels_ = self.groups * 2 * self.kernel_size[0] * self.kernel_size[1] 34 | self.conv_offset = nn.Conv2d(self.in_channels, 35 | channels_, 36 | kernel_size=self.groups * 2 * self.kernel_size[0] * self.kernel_size[1], 37 | stride=self.stride, 38 | padding=self.padding, 39 | bias=True) 40 | 
nn.init.constant_(self.conv_offset.weight, 0) 41 | self.conv_offset.register_backward_hook(self._set_lr) 42 | self.conv_mask = nn.Conv2d(self.in_channels, 43 | self.groups * 1 * self.kernel_size[0] * self.kernel_size[1], 44 | kernel_size=self.kernel_size, 45 | stride=self.stride, 46 | padding=self.padding, 47 | bias=True) 48 | nn.init.constant_(self.conv_mask.weight, 0) 49 | self.conv_mask.register_backward_hook(self._set_lr) 50 | 51 | @staticmethod 52 | def _set_lr(module, grad_input, grad_output): 53 | grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input))) 54 | grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output))) 55 | 56 | def init_offset(self): 57 | n = self.in_channels 58 | for k in self.kernel_size: 59 | n *= k 60 | stdv = 1. / math.sqrt(n) 61 | self.conv_offset.weight.data.uniform_(-stdv, stdv) 62 | self.conv_offset.bias.data.zero_() 63 | 64 | def forward(self, input): 65 | offset = self.conv_offset(input) 66 | mask = torch.sigmoid(self.conv_mask(input)) 67 | return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride, 68 | padding=self.padding, dilation=self.dilation, mask=mask) -------------------------------------------------------------------------------- /utils/imsave.py: -------------------------------------------------------------------------------- 1 | import torchvision.utils as vutils 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import random 6 | import time 7 | from PIL import Image 8 | import colorsys 9 | import cv2 10 | def write_2images(image_outputs, display_image_num, image_directory, postfix): 11 | __write_images(image_outputs[:], display_image_num, '%s/gen_%s.png' % (image_directory, postfix)) 12 | 13 | def __write_images(image_outputs, display_image_num, file_name): 14 | image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels 15 | image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0) 16 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True) 17 | vutils.save_image(image_grid, file_name, nrow=1) 18 | 19 | def collate_func(batch_dic): 20 | output={} 21 | #for k in ['labels','ins_masks']: 22 | for k in ['cate_labels', 'ins_labels','ins_ind_labels']: 23 | output[k]=[dic[k] if dic[k] is not None else None for dic in batch_dic] 24 | output['image']=torch.stack([dic['image'] for dic in batch_dic]) 25 | return output 26 | 27 | def convert_labels(arr): 28 | w,h=arr.shape 29 | output=np.zeros([w,h,3],dtype=np.uint8) 30 | for i in np.unique(arr): 31 | if i ==0:continue 32 | output[arr==i]=[random.randint(64,255),random.randint(64,255),random.randint(64,255)] 33 | return output 34 | 35 | def _imageshow(image,pred,gt,unpaired_pred_ind,unpaired_gt_ind,title=None,cmap='tab20b'): 36 | plt.rcParams['figure.figsize'] = (8.0, 6.0) # 设置figure_size尺寸 37 | plt.rcParams['image.interpolation'] = 'nearest' # 设置 interpolation style 38 | plt.rcParams['savefig.dpi'] = 300 39 | plt.rcParams['figure.dpi'] = 300 40 | 41 | plt.subplot(231) 42 | plt.imshow(image) 43 | if title is not None: 44 | plt.title(title) 45 | plt.axis('off') 46 | plt.subplot(232) 47 | plt.imshow(convert_labels(pred)) 48 | plt.title(f'pred_num {len(np.unique(pred))}') 49 | plt.axis('off') 50 | plt.subplot(233) 51 | plt.imshow(convert_labels(gt)) 52 | plt.title(f'gt_num {len(np.unique(gt))}') 53 | plt.axis('off') 54 | plt.subplot(224) 55 | plt.axis('off') 56 | unmatched_pred=np.zeros_like(pred) 57 | for i in 
np.unique(unpaired_pred_ind): 58 | if i ==0: 59 | continue 60 | unmatched_pred[pred==i]=i 61 | plt.subplot(235) 62 | plt.imshow(convert_labels(unmatched_pred)) 63 | plt.title(f'unmatched_pred {len(np.unique(unpaired_pred_ind))}') 64 | plt.axis('off') 65 | unmatched_gt=np.zeros_like(gt) 66 | for i in np.unique(unpaired_gt_ind): 67 | if i ==0: 68 | continue 69 | unmatched_gt[gt==i]=i 70 | plt.subplot(236) 71 | plt.imshow(convert_labels(unmatched_gt)) 72 | plt.title(f'unmatched_gt {len(np.unique(unpaired_gt_ind))}') 73 | plt.axis('off') 74 | plt.subplots_adjust(top=0.95, bottom=0.05, right=0.99, left=0.01, hspace=0.1, wspace=0.1) 75 | plt.tight_layout() 76 | #plt.savefig('D:/{}.png'.format(time.time())) 77 | plt.show() 78 | 79 | 80 | def random_colors(N, bright=True): 81 | """Generate random colors. 82 | 83 | To get visually distinct colors, generate them in HSV space then 84 | convert to RGB. 85 | """ 86 | brightness = 1.0 if bright else 0.7 87 | hsv = [(i / N, 1, brightness) for i in range(N)] 88 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 89 | random.shuffle(colors) 90 | return colors 91 | 92 | 93 | def _imagesave(image, mask, label,save_path): 94 | colors={ 95 | 1: [255, 0, 0], 96 | 2:[0, 255, 0], 97 | 3: [0, 0, 255], 98 | 4: [255, 255, 0], 99 | 5 : [255, 165, 0] 100 | } 101 | 102 | inst_rng_colors = random_colors(len(np.unique(mask))) 103 | inst_rng_colors = np.clip(inst_rng_colors,0,255) * 255 104 | inst_rng_colors = inst_rng_colors.astype(np.uint8).tolist() 105 | 106 | image=image.copy() 107 | if image.shape[2]==4: 108 | image=image[:,:,0:3].copy() 109 | 110 | 111 | for ui in np.unique(mask): 112 | if ui == 0: continue 113 | binary = ((mask == ui) * 255).astype(np.uint8) 114 | contours, hierarchy = cv2.findContours(binary.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 115 | if label is not None: 116 | image = cv2.drawContours(image, contours, -1, tuple(colors[label[ui-1]]), 2) 117 | else: 118 | image = cv2.drawContours(image, contours, -1, tuple(inst_rng_colors[ui-1]), 2) 119 | 120 | Image.fromarray(image).save(save_path) 121 | -------------------------------------------------------------------------------- /losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def binary_focal_loss(x, y, alpha=0.25, gamma=2., reduction='none'): 7 | pt = x.detach() * (y.detach() * 2 - 1) 8 | w = (1 - pt).pow(gamma) 9 | w[y == 0] *= (1 - alpha) 10 | w[y > 0] *= alpha 11 | # a = torch.where(y < 0, alpha, (1 - alpha)) 12 | loss = F.binary_cross_entropy(x, y, w, reduction=reduction) 13 | return loss 14 | 15 | class BinaryFocalLoss(nn.Module): 16 | def __init__(self, alpha=0.25, gamma=2): 17 | super(BinaryFocalLoss, self).__init__() 18 | self.alpha = alpha 19 | self.gamma = gamma 20 | self.smooth = 1e-6 # set '1e-4' when train with FP16 21 | 22 | def forward(self, output, target): 23 | prob = torch.sigmoid(output) 24 | prob = torch.clamp(prob, self.smooth, 1.0 - self.smooth) 25 | loss=-target*(1-self.alpha)*((1-prob)**self.gamma)*torch.log(prob)-(1-target)*self.alpha*(prob**self.gamma)*torch.log(1-prob) 26 | return loss 27 | 28 | class BCELoss(nn.Module): 29 | def __init__(self): 30 | super(BCELoss, self).__init__() 31 | self.smooth = 1e-6 # set '1e-4' when train with FP16 32 | 33 | def forward(self, output, target): 34 | prob = torch.sigmoid(output) 35 | prob = torch.clamp(prob, self.smooth, 1.0 - self.smooth) 36 | 
loss=-target*torch.log(prob)-(1-target)*torch.log(1-prob) 37 | return loss 38 | 39 | class FocalLoss_Ori(nn.Module): 40 | """ 41 | This is an implementation of Focal Loss with smooth label cross entropy support, as proposed in 42 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 43 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 44 | :param num_class: 45 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 46 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 47 | focus on hard misclassified examples 48 | :param smooth: (float,double) smooth value when cross entropy 49 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 50 | """ 51 | 52 | def __init__(self, num_class, alpha=[0.25, 0.75], gamma=2, balance_index=-1, size_average=True): 53 | super(FocalLoss_Ori, self).__init__() 54 | self.num_class = num_class 55 | self.alpha = alpha 56 | self.gamma = gamma 57 | self.size_average = size_average 58 | self.eps = 1e-6 59 | 60 | if isinstance(self.alpha, (list, tuple)): 61 | assert len(self.alpha) == self.num_class 62 | self.alpha = torch.Tensor(list(self.alpha)) 63 | elif isinstance(self.alpha, (float, int)): 64 | assert 0 < self.alpha < 1.0, 'alpha should be in `(0,1)`' 65 | assert balance_index > -1 66 | alpha = torch.ones((self.num_class)) 67 | alpha *= 1 - self.alpha 68 | alpha[balance_index] = self.alpha 69 | self.alpha = alpha 70 | elif isinstance(self.alpha, torch.Tensor): 71 | self.alpha = self.alpha 72 | else: 73 | raise TypeError('Unsupported alpha type, expect `int|float|list|tuple|torch.Tensor`') 74 | 75 | def forward(self, logit, target): 76 | 77 | if logit.dim() > 2: 78 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 79 | logit = logit.view(logit.size(0), logit.size(1), -1) 80 | logit = logit.transpose(1, 2).contiguous() # [N,C,d1*d2..] -> [N,d1*d2..,C] 81 | logit = logit.view(-1, logit.size(-1)) # [N,d1*d2..,C]-> [N*d1*d2..,C] 82 | target = target.view(-1, 1) # [N,d1,d2,...]->[N*d1*d2*...,1] 83 | 84 | # -----------legacy way------------ 85 | # idx = target.cpu().long() 86 | # one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_() 87 | # one_hot_key = one_hot_key.scatter_(1, idx, 1) 88 | # if one_hot_key.device != logit.device: 89 | # one_hot_key = one_hot_key.to(logit.device) 90 | # pt = (one_hot_key * logit).sum(1) + epsilon 91 | 92 | # ----------memory saving way-------- 93 | pt = logit.gather(1, target).view(-1) + self.eps # add eps to avoid log(0) 94 | logpt = pt.log() 95 | 96 | if self.alpha.device != logpt.device: 97 | self.alpha = self.alpha.to(logpt.device) 98 | alpha_class = self.alpha.gather(0, target.view(-1)) 99 | logpt = alpha_class * logpt 100 | loss = -1 * torch.pow(torch.sub(1.0, pt), self.gamma) * logpt 101 | 102 | if self.size_average: 103 | loss = loss.mean() 104 | else: 105 | loss = loss.sum() 106 | return loss 107 | 108 | -------------------------------------------------------------------------------- /utils/matrix_nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None): 5 | """Matrix NMS for multi-class masks.
6 | Args: 7 | seg_masks (Tensor): shape (n, h, w) 8 | cate_labels (Tensor): shape (n), mask labels in descending order 9 | cate_scores (Tensor): shape (n), mask scores in descending order 10 | kernel (str): 'linear' or 'gauss' 11 | sigma (float): std in gaussian method 12 | sum_masks (Tensor): The sum of seg_masks 13 | Returns: 14 | Tensor: cate_scores_update, tensors of shape (n) 15 | """ 16 | n_samples = len(cate_labels) 17 | if n_samples == 0: 18 | return [] 19 | if sum_masks is None: 20 | sum_masks = seg_masks.sum((1, 2)).float() 21 | seg_masks = seg_masks.reshape(n_samples, -1).float() 22 | # inter. 23 | inter_matrix = torch.mm(seg_masks, seg_masks.transpose(1, 0)) 24 | # union. 25 | sum_masks_x = sum_masks.expand(n_samples, n_samples) 26 | # iou. 27 | iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1) 28 | # label_specific matrix. 29 | cate_labels_x = cate_labels.expand(n_samples, n_samples) 30 | label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1) 31 | 32 | # IoU compensation 33 | compensate_iou, _ = (iou_matrix * label_matrix).max(0) 34 | compensate_iou = compensate_iou.expand(n_samples, n_samples).transpose(1, 0) 35 | 36 | # IoU decay 37 | decay_iou = iou_matrix * label_matrix 38 | 39 | # matrix nms 40 | if kernel == 'gaussian': 41 | decay_matrix = torch.exp(-1 * sigma * (decay_iou ** 2)) 42 | compensate_matrix = torch.exp(-1 * sigma * (compensate_iou ** 2)) 43 | decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0) 44 | elif kernel == 'linear': 45 | decay_matrix = (1-decay_iou)/(1-compensate_iou) 46 | decay_coefficient, _ = decay_matrix.min(0) 47 | else: 48 | raise NotImplementedError 49 | 50 | # update the score. 51 | cate_scores_update = cate_scores * decay_coefficient 52 | return cate_scores_update 53 | 54 | 55 | def multiclass_nms(multi_bboxes, 56 | multi_scores, 57 | score_thr, 58 | nms_cfg, 59 | max_num=-1, 60 | score_factors=None): 61 | """NMS for multi-class bboxes. 62 | Args: 63 | multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) 64 | multi_scores (Tensor): shape (n, #class), where the 0th column 65 | contains scores of the background class, but this will be ignored. 66 | score_thr (float): bbox threshold, bboxes with scores lower than it 67 | will not be considered. 68 | nms_thr (float): NMS IoU threshold 69 | max_num (int): if there are more than max_num bboxes after NMS, 70 | only top max_num will be kept. 71 | score_factors (Tensor): The factors multiplied to scores before 72 | applying NMS 73 | Returns: 74 | tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels 75 | are 0-based. 
76 | """ 77 | num_classes = multi_scores.shape[1] 78 | bboxes, labels = [], [] 79 | nms_cfg_ = nms_cfg.copy() 80 | nms_type = nms_cfg_.pop('type', 'nms') 81 | nms_op = getattr(nms_wrapper, nms_type) 82 | for i in range(1, num_classes): 83 | cls_inds = multi_scores[:, i] > score_thr 84 | if not cls_inds.any(): 85 | continue 86 | # get bboxes and scores of this class 87 | if multi_bboxes.shape[1] == 4: 88 | _bboxes = multi_bboxes[cls_inds, :] 89 | else: 90 | _bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4] 91 | _scores = multi_scores[cls_inds, i] 92 | if score_factors is not None: 93 | _scores *= score_factors[cls_inds] 94 | cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1) 95 | cls_dets, _ = nms_op(cls_dets, **nms_cfg_) 96 | cls_labels = multi_bboxes.new_full((cls_dets.shape[0], ), 97 | i - 1, 98 | dtype=torch.long) 99 | bboxes.append(cls_dets) 100 | labels.append(cls_labels) 101 | if bboxes: 102 | bboxes = torch.cat(bboxes) 103 | labels = torch.cat(labels) 104 | if bboxes.shape[0] > max_num: 105 | _, inds = bboxes[:, -1].sort(descending=True) 106 | inds = inds[:max_num] 107 | bboxes = bboxes[inds] 108 | labels = labels[inds] 109 | else: 110 | bboxes = multi_bboxes.new_zeros((0, 5)) 111 | labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) 112 | 113 | return bboxes, labels -------------------------------------------------------------------------------- /models/segbase.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from .backbone.resnetv1b import resnet50_v1s,resnet101_v1s 3 | from .backbone.res2netv1b import res2net50_v1b 4 | from .backbone.resnext import resnext50_32x4d,resnext101_32x8d 5 | from .backbone.seg_hrnet import hrnet_w18_v2,hrnet_w32,hrnet_w44,hrnet_w48,hrnet_w64 6 | from .backbone.resnextdcn import resnext101_32x8d_dcn 7 | from .backbone.swim_transformer import swim_large 8 | from torch.nn.modules.batchnorm import _BatchNorm 9 | __all__ = ['SegBaseModel'] 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | class SegBaseModel(nn.Module): 14 | r"""Base Model for Semantic Segmentation 15 | Parameters 16 | ---------- 17 | backbone : string 18 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 19 | 'resnet101' or 'resnet152'). 
20 | resnest 21 | resnext 22 | res2net 23 | DLA 24 | """ 25 | def __init__(self, backbone='enc', pretrained_base=False, frozen_stages=-1,norm_eval=False, **kwargs): 26 | super(SegBaseModel, self).__init__() 27 | self.norm_eval=norm_eval 28 | self.frozen_stages=frozen_stages 29 | if backbone == 'resnext101dcn': 30 | self.pretrained = resnext101_32x8d_dcn(pretrained=pretrained_base, dilated=False, **kwargs) 31 | elif backbone == 'resnet50': 32 | self.pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=False, **kwargs) 33 | elif backbone == 'hrnet18': 34 | self.pretrained = hrnet_w18_v2(pretrained=pretrained_base, dilated=False, **kwargs) 35 | elif backbone =='hrnet32': 36 | self.pretrained=hrnet_w32(pretrained=pretrained_base, **kwargs) 37 | elif backbone =='hrnet64': 38 | self.pretrained=hrnet_w64(pretrained=pretrained_base, **kwargs) 39 | elif backbone == 'resnet101': 40 | self.pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=False, **kwargs) 41 | elif backbone == 'resnext101': 42 | self.pretrained = resnext101_32x8d(pretrained=pretrained_base, dilated=False, **kwargs) 43 | elif backbone == 'res2net50': 44 | self.pretrained = res2net50_v1b(pretrained=pretrained_base, **kwargs) 45 | elif backbone == 'resnext50': 46 | self.pretrained = resnext50_32x4d(pretrained=pretrained_base, dilated=False, **kwargs) 47 | elif backbone == 'swin': 48 | self.pretrained = swim_large(pretrained=False) 49 | else: 50 | raise RuntimeError('unknown backbone: {}'.format(backbone)) 51 | self._train() 52 | 53 | 54 | def set_requires_grad(self, nets, requires_grad=False): 55 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 56 | Parameters: 57 | nets (network list) -- a list of networks 58 | requires_grad (bool) -- whether the networks require gradients or not 59 | """ 60 | if not isinstance(nets, list): 61 | nets = [nets] 62 | for net in nets: 63 | if net is not None: 64 | for param in net.parameters(): 65 | param.requires_grad = requires_grad 66 | 67 | def freeze(self): 68 | self.set_requires_grad(self.pretrained,False) 69 | self.set_requires_grad([self.pretrained.conv1,self.pretrained.bn1],True) 70 | 71 | def unfreeze(self): 72 | self.set_requires_grad([self.pretrained],True) 73 | 74 | def _train(self, mode=True): 75 | super(SegBaseModel, self).train(mode) 76 | self._freeze_stages() 77 | if mode and self.norm_eval: 78 | print('Freeze backbone BN using running mean and std') 79 | for m in self.modules(): 80 | # trick: eval have effect on BatchNorm only 81 | if isinstance(m, _BatchNorm): 82 | m.eval() 83 | 84 | def _freeze_stages(self): 85 | if self.frozen_stages >= 0: 86 | self.pretrained.bn1.eval() 87 | for m in [self.pretrained.conv1, self.pretrained.bn1]: 88 | for param in m.parameters(): 89 | param.requires_grad = False 90 | if hasattr(self.pretrained, 'conv2'): 91 | self.pretrained.bn2.eval() 92 | for m in [self.pretrained.conv2, self.pretrained.bn2]: 93 | for param in m.parameters(): 94 | param.requires_grad = False 95 | 96 | print(f'Freezing backbone stage {self.frozen_stages}') 97 | for i in range(1, self.frozen_stages + 1): 98 | m = getattr(self.pretrained, 'layer{}'.format(i)) 99 | m.eval() 100 | for param in m.parameters(): 101 | param.requires_grad = False 102 | 103 | # trick: only train conv1 since not use ImageNet mean and std image norm 104 | if hasattr(self.pretrained,'conv1'): 105 | print('active train conv1 and bn1') 106 | self.set_requires_grad([self.pretrained.conv1,self.pretrained.bn1],True) 107 | 108 | def base_forward(self, x): 109 | 
"""forwarding pre-trained network""" 110 | x = self.pretrained.conv1(x) 111 | x = self.pretrained.bn1(x) 112 | x = self.pretrained.relu(x) 113 | c1 = self.pretrained.maxpool(x) 114 | c2 = self.pretrained.layer1(c1) 115 | c3 = self.pretrained.layer2(c2) 116 | c4 = self.pretrained.layer3(c3) 117 | c5 = self.pretrained.layer4(c4) 118 | return c1, c2, c3, c4, c5 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /losses/generalise_focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def reduce_loss(loss, reduction): 5 | """Reduce loss as specified. 6 | Args: 7 | loss (Tensor): Elementwise loss tensor. 8 | reduction (str): Options are "none", "mean" and "sum". 9 | Return: 10 | Tensor: Reduced loss tensor. 11 | """ 12 | reduction_enum = F._Reduction.get_enum(reduction) 13 | # none: 0, elementwise_mean:1, sum: 2 14 | if reduction_enum == 0: 15 | return loss 16 | elif reduction_enum == 1: 17 | return loss.mean() 18 | elif reduction_enum == 2: 19 | return loss.sum() 20 | 21 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 22 | """Apply element-wise weight and reduce loss. 23 | Args: 24 | loss (Tensor): Element-wise loss. 25 | weight (Tensor): Element-wise weights. 26 | reduction (str): Same as built-in losses of PyTorch. 27 | avg_factor (float): Avarage factor when computing the mean of losses. 28 | Returns: 29 | Tensor: Processed loss values. 30 | """ 31 | # if weight is specified, apply element-wise weight 32 | if weight is not None: 33 | loss = loss * weight 34 | 35 | # if avg_factor is not specified, just reduce the loss 36 | if avg_factor is None: 37 | loss = reduce_loss(loss, reduction) 38 | else: 39 | # if reduction is mean, then average the loss by avg_factor 40 | if reduction == 'mean': 41 | loss = loss.sum() / avg_factor 42 | # if reduction is 'none', then do nothing, otherwise raise an error 43 | elif reduction != 'none': 44 | raise ValueError('avg_factor can not be used with reduction="sum"') 45 | return loss 46 | 47 | 48 | def quality_focal_loss( 49 | pred, # (n, 80) 50 | label, # (n) 0, 1-80: 0 is neg, 1-80 is positive 51 | score, # (n) reg target 0-1, only positive is good 52 | weight=None, 53 | beta=2.0, 54 | reduction='mean', 55 | avg_factor=None): 56 | # all goes to 0 57 | pred_sigmoid = pred.sigmoid() 58 | pt = pred_sigmoid 59 | zerolabel = pt.new_zeros(pred.shape) 60 | loss = F.binary_cross_entropy_with_logits( 61 | pred, zerolabel, reduction='none') * pt.pow(beta) 62 | 63 | label = label - 1 64 | pos = (label >= 0).nonzero().squeeze(1) 65 | a = pos 66 | b = label[pos].long() 67 | 68 | # positive goes to bbox quality 69 | pt = score[a] - pred_sigmoid[a, b] 70 | loss[a, b] = F.binary_cross_entropy_with_logits( 71 | pred[a, b], score[a], reduction='none') * pt.pow(beta) 72 | 73 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 74 | return loss 75 | 76 | 77 | def distribution_focal_loss( 78 | pred, 79 | label, 80 | weight=None, 81 | reduction='mean', 82 | avg_factor=None): 83 | disl = label.long() 84 | disr = disl + 1 85 | 86 | wl = disr.float() - label 87 | wr = label - disl.float() 88 | 89 | loss = F.cross_entropy(pred, disl, reduction='none') * wl \ 90 | + F.cross_entropy(pred, disr, reduction='none') * wr 91 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 92 | return loss 93 | 94 | class QualityFocalLoss(nn.Module): 95 | def __init__(self, 96 | 
use_sigmoid=True, 97 | beta=2.0, 98 | reduction='mean', 99 | loss_weight=1.0): 100 | super(QualityFocalLoss, self).__init__() 101 | assert use_sigmoid is True, 'Only sigmoid in QFL supported now.' 102 | self.use_sigmoid = use_sigmoid 103 | self.beta = beta 104 | self.reduction = reduction 105 | self.loss_weight = loss_weight 106 | 107 | def forward(self, 108 | pred, 109 | target, 110 | score, 111 | weight=None, 112 | avg_factor=None, 113 | reduction_override=None): 114 | assert reduction_override in (None, 'none', 'mean', 'sum') 115 | reduction = ( 116 | reduction_override if reduction_override else self.reduction) 117 | if self.use_sigmoid: 118 | loss_cls = self.loss_weight * quality_focal_loss( 119 | pred, 120 | target, 121 | score, 122 | weight, 123 | beta=self.beta, 124 | reduction=reduction, 125 | avg_factor=avg_factor) 126 | else: 127 | raise NotImplementedError 128 | return loss_cls 129 | 130 | class DistributionFocalLoss(nn.Module): 131 | 132 | def __init__(self, 133 | reduction='mean', 134 | loss_weight=1.0): 135 | super(DistributionFocalLoss, self).__init__() 136 | self.reduction = reduction 137 | self.loss_weight = loss_weight 138 | 139 | def forward(self, 140 | pred, 141 | target, 142 | weight=None, 143 | avg_factor=None, 144 | reduction_override=None): 145 | assert reduction_override in (None, 'none', 'mean', 'sum') 146 | reduction = ( 147 | reduction_override if reduction_override else self.reduction) 148 | loss_cls = self.loss_weight * distribution_focal_loss( 149 | pred, 150 | target, 151 | weight, 152 | reduction=reduction, 153 | avg_factor=avg_factor) 154 | return loss_cls -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PointNu-Net 2 | 3 | ## PointNu-Net: Keypoint-assisted Convolutional Neural Network for Simultaneous Multi-tissue Histology Nuclei Segmentation and Classification. [Paper](https://ieeexplore.ieee.org/abstract/document/10148651) [ArXiv](https://arxiv.org/pdf/2111.01557.pdf) 4 | Kai Yao, Kaizhu Huang, Jie Sun, Amir Hussain, Curran Jude \ 5 | Both University of Liverpool and Xi'an Jiaotong-Liverpool University 6 | 7 | **Abstract** 8 | 9 | Automatic nuclei segmentation and classification play a vital role in digital pathology. However, previous works are mostly built on data with limited diversity and small sizes, making the results questionable or misleading in actual downstream tasks. In this paper, we aim to build a reliable and robust method capable of dealing with data from ‘the clinical wild’. Specifically, we study and design a new method to simultaneously detect, segment, and classify nuclei from Haematoxylin and Eosin (H\&E) stained histopathology data, and evaluate our approach using the recent largest dataset: PanNuke. We address the detection and classification of each nucleus as a novel semantic keypoint estimation problem to determine the center point of each nucleus. Next, the corresponding class-agnostic masks for nuclei center points are obtained using dynamic instance segmentation. Meanwhile, we propose a novel Joint Pyramid Fusion Module (JPFM) to model the cross-scale dependencies, thus enhancing the local features for better nuclei detection and classification. By decoupling two simultaneous challenging tasks and taking advantage of JPFM, our method can benefit from class-aware detection and class-agnostic segmentation, thus leading to a significant performance boost.
We demonstrate the superior performance of our proposed approach for nuclei segmentation and classification across 19 different tissue types, delivering new benchmark results. 10 | 11 | ## News: 12 | \[2023/12/25\] We add additional experimental results on CPM17 (PQ 0.706, AJI 0.712). \ 13 | \[2023/5/15\] Our paper has been accepted by TETCI. \ 14 | \[2023/5/1\] We release the training and inference code, and the training instructions. 15 | 16 | 17 | ## 1. Installation 18 | 19 | Clone this repo. 20 | ```bash 21 | git clone https://github.com/Kaiseem/PointNu-Net.git 22 | cd PointNu-Net/ 23 | ``` 24 | 25 | This code requires PyTorch 1.10+ and Python 3+. Please install the dependencies with 26 | ```bash 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | 31 | ## 2. Data preparation 32 | 33 | For the small datasets Kumar and CoNSeP, we prepare the data following [Hover-Net](https://github.com/vqdang/hover_net). 34 | 35 | We provide the [processed Kumar and CoNSeP datasets](https://drive.google.com/file/d/1_eI_ii6xcNe_77NWx7Qo8_KndK5UwPBO/view?usp=sharing). Also, we provide the [processed CPM17 dataset](https://drive.google.com/file/d/1igsZ2oBUmylPSzhsXn_3TCOWZNe_Zfd8/view?usp=sharing). 36 | 37 | The [PanNuKe](https://arxiv.org/pdf/2003.10778v7.pdf) dataset can be found [here](https://warwick.ac.uk/fac/sci/dcs/research/tia/data/pannuke). 38 | 39 | Download and unzip all the files so that the folder structure looks like this: 40 | 41 | ```none 42 | PointNu-Net 43 | ├── ... 44 | ├── datasets 45 | │ ├── kumar 46 | │ │ ├── train 47 | │ │ ├── test 48 | │ ├── CoNSeP 49 | │ │ ├── train 50 | │ │ ├── test 51 | │ ├── PanNuKe 52 | │ │ ├── images 53 | │ │ │ ├── fold1 54 | │ │ │ │ ├── images.npy 55 | │ │ │ │ ├── types.npy 56 | │ │ │ ├── fold2 57 | │ │ │ │ ├── images.npy 58 | │ │ │ │ ├── types.npy 59 | │ │ │ ├── fold3 60 | │ │ │ │ ├── images.npy 61 | │ │ │ │ ├── types.npy 62 | │ │ ├── masks 63 | │ │ │ ├── fold1 64 | │ │ │ │ ├── masks.npy 65 | │ │ │ ├── fold2 66 | │ │ │ │ ├── masks.npy 67 | │ │ │ ├── fold3 68 | │ │ │ │ ├── masks.npy 69 | ├── ... 70 | ``` 71 | 72 | ## 3. Training and Inference 73 | To reproduce the performance, you need at least one 3090 GPU. 74 | 75 | Download the ImageNet pretrained checkpoints from the official [HRNet](https://github.com/HRNet/HRNet-Image-Classification) repository. 76 |
78 | 79 | 1) Kumar Dataset 80 | 81 | 82 | run the command to train the model 83 | ```bash 84 | python train.py --name=kumar_exp --seed=888 --config=configs/kumar_notype_large.yaml 85 | ``` 86 | 87 | run the command for inference 88 | ```bash 89 | python inference.py --name=kumar_exp 90 | ``` 91 |
92 | 93 |
94 | 95 | 2) CoNSeP Dataset 96 | 97 | 98 | run the command to train the model 99 | ```bash 100 | python train.py --name=consep_exp --seed=888 --config=configs/consep_type_large.yaml 101 | ``` 102 | 103 | run the command for inference 104 | ```bash 105 | python inference.py --name=consep_exp 106 | ``` 107 |
108 | 109 | 110 |
111 | 112 | 3) PanNuKe Dataset 113 | 114 | 115 | run the command to train the model 116 | ```bash 117 | python train_pannuke.py --name=pannuke_exp --seed=888 --train_fold={} --val_fold={} --test_fold={} 118 | ``` 119 | [train_fold, val_fold, test_fold] should be selected from {[1, 2, 3], [2, 1, 3], [3, 2, 1]} (see the worked example below) 120 | 121 | run the command for inference 122 | ```bash 123 | python infer_pannuke.py --name=pannuke_exp --train_fold={} --test_fold={} 124 | ``` 125 | 126 | run the command to evaluate the performance 127 | ```bash 128 | python eval_pannuke.py --name=pannuke_exp --train_fold={} --test_fold={} 129 | ``` 130 | 131 |
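For example, assuming the first fold combination (train on fold 1, validate on fold 2, test on fold 3), the three steps above would be run as follows; repeat with the other two combinations listed above for the remaining splits:

```bash
python train_pannuke.py --name=pannuke_exp --seed=888 --train_fold=1 --val_fold=2 --test_fold=3
python infer_pannuke.py --name=pannuke_exp --train_fold=1 --test_fold=3
python eval_pannuke.py --name=pannuke_exp --train_fold=1 --test_fold=3
```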
132 | 133 | 134 | ## Citation 135 | If our work or code helps you, please consider to cite our paper. Thank you! 136 | 137 | ``` 138 | @article{yao2023pointnu, 139 | title={PointNu-Net: Keypoint-Assisted Convolutional Neural Network for Simultaneous Multi-Tissue Histology Nuclei Segmentation and Classification}, 140 | author={Yao, Kai and Huang, Kaizhu and Sun, Jie and Hussain, Amir}, 141 | journal={IEEE Transactions on Emerging Topics in Computational Intelligence}, 142 | year={2023}, 143 | } 144 | ``` 145 | -------------------------------------------------------------------------------- /utils/imop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage.interpolation import shift 3 | import cv2 4 | from collections import Counter 5 | import math 6 | from scipy import ndimage 7 | def get_ins_info(seg_mask,method='bbox'): 8 | methods=['bbox','circle','area'] 9 | assert method in methods, f'instance segmentation information should in {methods}' 10 | if method=='circle': 11 | contours, hierachy = cv2.findContours((seg_mask * 255).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 12 | (center_w, center_h), EC_radius = cv2.minEnclosingCircle(contours[0]) 13 | return center_w,center_h,EC_radius*2,EC_radius*2 14 | elif method=='bbox': 15 | bbox_x, bbox_y, bbox_w, bbox_h = cv2.boundingRect(np.array(seg_mask).astype(np.uint8)) 16 | center_w = bbox_x + bbox_w / 2 17 | center_h = bbox_y + bbox_h / 2 18 | return center_w, center_h, bbox_w, bbox_h 19 | elif method=='area': 20 | center_h, center_w = ndimage.measurements.center_of_mass(seg_mask) 21 | equal_diameter=(np.sum(seg_mask)/3.1415)**0.5*2 22 | return center_w,center_h,equal_diameter,equal_diameter 23 | else: 24 | raise NotImplementedError 25 | 26 | def gaussian_radius(det_size, min_overlap=0.7): 27 | 28 | #https://github.com/princeton-vl/CornerNet/blob/master/sample/utils.py 29 | height, width = det_size 30 | 31 | a1 = 1 32 | b1 = (height + width) 33 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 34 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) 35 | r1 = (b1 - sq1) / 2 36 | 37 | a2 = 4 38 | b2 = 2 * (height + width) 39 | c2 = (1 - min_overlap) * width * height 40 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) 41 | r2 = (b2 - sq2) / 2 42 | 43 | a3 = 4 * min_overlap 44 | b3 = -2 * min_overlap * (height + width) 45 | c3 = (min_overlap - 1) * width * height 46 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) 47 | r3 = (b3 + sq3) / 2 48 | 49 | return min(r1, r2, r3) 50 | 51 | def gaussian_radius_new(det_size, min_overlap=0.7): 52 | #https://github.com/princeton-vl/CornerNet/blob/master/sample/utils.py 53 | height, width = det_size 54 | 55 | a1 = 1 56 | b1 = (height + width) 57 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 58 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) 59 | r1 = (b1 - sq1) / (2 * a1) 60 | 61 | a2 = 4 62 | b2 = 2 * (height + width) 63 | c2 = (1 - min_overlap) * width * height 64 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) 65 | r2 = (b2 - sq2) / (2 * a2) 66 | 67 | a3 = 4 * min_overlap 68 | b3 = -2 * min_overlap * (height + width) 69 | c3 = (min_overlap - 1) * width * height 70 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) 71 | r3 = (b3 + sq3) / (2 * a3) 72 | 73 | return min(r1, r2, r3) 74 | 75 | def gaussian2D(shape, sigma=1.): 76 | m, n = [(ss - 1.) / 2. 
for ss in shape] 77 | y, x = np.ogrid[-m:m+1,-n:n+1] 78 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 79 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 80 | return h 81 | 82 | def draw_gaussian(heatmap, center, radius): 83 | diameter = float(2 * (int(radius)+1) + 1) 84 | gaussian = gaussian2D((diameter, diameter), sigma = radius/3)#gaussian_2d_kernel(int(diameter),radius/3)# 85 | coord_w, coord_h = center 86 | height, width = heatmap.shape 87 | temp=np.zeros((height,width), dtype=np.float) 88 | temp = insert_image(temp, gaussian, coord_h, coord_w) 89 | np.maximum(heatmap, temp, out=heatmap) 90 | return temp 91 | 92 | def gaussian_2d_kernel(kernel_size=3, sigma=0): 93 | kernel = np.zeros([kernel_size, kernel_size]) 94 | center = kernel_size // 2 95 | if sigma == 0: 96 | sigma = ((kernel_size - 1) * 0.5 - 1) * 0.3 + 0.8 97 | s = 2 * (sigma ** 2) 98 | 99 | for i in range(0, kernel_size): 100 | for j in range(0, kernel_size): 101 | x = i - center 102 | y = j - center 103 | kernel[i, j] = np.exp(-(x ** 2 + y ** 2) / s) 104 | return kernel 105 | 106 | def insert_image(img, kernel,h,w): 107 | ks=kernel.shape[0] 108 | if ks !=0: 109 | half_ks=ks//2 110 | img=np.pad(img,((half_ks,half_ks),(half_ks,half_ks))) 111 | img[h:h+ks,w:w+ks]=kernel 112 | return img[half_ks:-half_ks,half_ks:-half_ks] 113 | else: 114 | img[h:h+1,w:w+1]=kernel 115 | return img 116 | 117 | def label_relaxation(msk, border_window=3): 118 | msk=msk.astype(np.float) 119 | border=border_window//2 120 | output=np.zeros_like(msk) 121 | for i in range(-border,border+1,1): 122 | for j in range(-border, border + 1, 1): 123 | output+=shift(msk,shift=[i,j],cval=0) 124 | output/=border_window**2 125 | return output 126 | 127 | def ensemble_img(ins_list,cate_list,score_list): 128 | N,w,h=ins_list.shape 129 | output_img=np.zeros((w,h),dtype=np.int16) 130 | #print(ins_list.shape,cate_list.shape,score_list.shape) 131 | for i in range(N): 132 | ins_num = i+1 133 | ins_=ins_list[i] 134 | score_=score_list[i] 135 | if np.sum(np.logical_and(output_img>0,ins_))==0: 136 | output_img=np.where(output_img>0,output_img,ins_*ins_num) 137 | else: 138 | compared_num,_ = Counter((output_img*ins_).flatten()).most_common(2)[1] 139 | assert compared_num>0 140 | #print( Counter((output_img*ins_).flatten()).most_common(2)) 141 | compared_num=int(compared_num) 142 | compared_score=score_list[compared_num-1] 143 | if np.sum(np.logical_and(output_img==compared_num,ins_))/np.sum(np.logical_or(output_img==compared_num,ins_))>0.5: 144 | if compared_score>score_: 145 | pass 146 | else: 147 | output_img[output_img==compared_num]=0 148 | output_img=np.where(output_img>0,output_img,ins_*ins_num) 149 | else: 150 | output_img = np.where(output_img > 0, output_img, ins_ * ins_num) 151 | return output_img -------------------------------------------------------------------------------- /models/backbone/resnext.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | __all__ = ['ResNext', 'resnext50_32x4d', 'resnext101_32x8d'] 5 | 6 | model_urls = { 7 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 8 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 9 | } 10 | 11 | 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 16 | base_width=64, dilation=1, norm_layer=None, **kwargs): 17 | super(Bottleneck, 
self).__init__() 18 | width = int(planes * (base_width / 64.)) * groups 19 | 20 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False) 21 | self.bn1 = norm_layer(width) 22 | self.conv2 = nn.Conv2d(width, width, 3, stride, dilation, dilation, groups, bias=False) 23 | self.bn2 = norm_layer(width) 24 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False) 25 | self.bn3 = norm_layer(planes * self.expansion) 26 | self.relu = nn.ReLU(True) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | identity = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv3(out) 42 | out = self.bn3(out) 43 | 44 | if self.downsample is not None: 45 | identity = self.downsample(x) 46 | 47 | out += identity 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class ResNext(nn.Module): 54 | 55 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, 56 | width_per_group=64, dilated=False, norm_layer=nn.BatchNorm2d,deep_stem=False, **kwargs): 57 | super(ResNext, self).__init__() 58 | self.inplanes = 64 59 | self.groups = groups 60 | self.base_width = width_per_group 61 | if deep_stem: 62 | self.conv1 = nn.Sequential( 63 | nn.Conv2d(3, 64, 3, 2, 1, bias=False), 64 | norm_layer(64), 65 | nn.ReLU(True), 66 | nn.Conv2d(64, 64, 3, 1, 1, bias=False), 67 | norm_layer(64), 68 | nn.ReLU(True), 69 | nn.Conv2d(64, 128, 3, 1, 1, bias=False) 70 | ) 71 | else: 72 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False) 73 | self.bn1 = norm_layer(self.inplanes) 74 | self.relu = nn.ReLU(True) 75 | self.maxpool = nn.MaxPool2d(3, 2, 1) 76 | 77 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 78 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 79 | if dilated: 80 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer) 81 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer) 82 | else: 83 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 84 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 85 | 86 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 87 | self.fc = nn.Linear(512 * block.expansion, num_classes) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | elif isinstance(m, nn.BatchNorm2d): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | if zero_init_residual: 97 | for m in self.modules(): 98 | if isinstance(m, Bottleneck): 99 | nn.init.constant_(m.bn3.weight, 0) 100 | 101 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d): 102 | downsample = None 103 | if stride != 1 or self.inplanes != planes * block.expansion: 104 | downsample = nn.Sequential( 105 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False), 106 | norm_layer(planes * block.expansion) 107 | ) 108 | 109 | layers = list() 110 | if dilation in (1, 2): 111 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 112 | self.base_width, norm_layer=norm_layer)) 113 | elif dilation == 4: 114 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 115 | self.base_width, dilation=2, 
norm_layer=norm_layer)) 116 | else: 117 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 118 | self.inplanes = planes * block.expansion 119 | for _ in range(1, blocks): 120 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, 121 | dilation=dilation, norm_layer=norm_layer)) 122 | 123 | return nn.Sequential(*layers) 124 | 125 | def forward(self, x): 126 | x = self.conv1(x) 127 | x = self.bn1(x) 128 | x = self.relu(x) 129 | x = self.maxpool(x) 130 | 131 | x = self.layer1(x) 132 | x = self.layer2(x) 133 | x = self.layer3(x) 134 | x = self.layer4(x) 135 | 136 | x = self.avgpool(x) 137 | x = x.view(x.size(0), -1) 138 | x = self.fc(x) 139 | 140 | return x 141 | 142 | 143 | def resnext50_32x4d(pretrained=False, **kwargs): 144 | kwargs['groups'] = 32 145 | kwargs['width_per_group'] = 4 146 | model = ResNext(Bottleneck, [3, 4, 6, 3], **kwargs) 147 | if pretrained: 148 | state_dict = model_zoo.load_url(model_urls['resnext50_32x4d']) 149 | model.load_state_dict(state_dict) 150 | return model 151 | 152 | 153 | def resnext101_32x8d(pretrained=False, **kwargs): 154 | kwargs['groups'] = 32 155 | kwargs['width_per_group'] = 8 156 | model = ResNext(Bottleneck, [3, 4, 23, 3], **kwargs) 157 | if pretrained: 158 | state_dict = model_zoo.load_url(model_urls['resnext101_32x8d']) 159 | model.load_state_dict(state_dict) 160 | return model 161 | 162 | 163 | if __name__ == '__main__': 164 | model = resnext101_32x8d() -------------------------------------------------------------------------------- /models/backbone/resnextdcn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | from ..base_ops.DCNv2 import DeformConv 4 | __all__ = ['ResNext', 'resnext101_32x8d_dcn'] 5 | 6 | model_urls = { 7 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 8 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 9 | } 10 | 11 | class Bottleneck(nn.Module): 12 | expansion = 4 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 15 | base_width=64, dilation=1, norm_layer=None, use_dcn=False, **kwargs): 16 | super(Bottleneck, self).__init__() 17 | width = int(planes * (base_width / 64.)) * groups 18 | 19 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False) 20 | self.bn1 = norm_layer(width) 21 | if use_dcn: 22 | self.conv2 = DeformConv(width, width, kernel_size=3, stride=stride, 23 | padding=dilation,dilation=dilation, groups=groups) 24 | 25 | # self.conv2.conv_offset.weight.data.zero_() 26 | # self.conv2.conv_offset.bias.data.zero_() 27 | else: 28 | self.conv2 = nn.Conv2d(width, width, 3, stride, dilation, dilation, groups, bias=False) 29 | self.bn2 = norm_layer(width) 30 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False) 31 | self.bn3 = norm_layer(planes * self.expansion) 32 | self.relu = nn.ReLU(True) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | identity = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv3(out) 48 | out = self.bn3(out) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class ResNext(nn.Module): 60 | 61 | def 
__init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, 62 | width_per_group=64, norm_layer=nn.BatchNorm2d,deep_stem=False, **kwargs): 63 | super(ResNext, self).__init__() 64 | 65 | self.inplanes = 64 66 | self.groups = groups 67 | self.base_width = width_per_group 68 | if deep_stem: 69 | self.conv1 = nn.Sequential( 70 | nn.Conv2d(3, 64, 3, 2, 1, bias=False), 71 | norm_layer(64), 72 | nn.ReLU(True), 73 | nn.Conv2d(64, 64, 3, 1, 1, bias=False), 74 | norm_layer(64), 75 | nn.ReLU(True), 76 | nn.Conv2d(64, 128, 3, 1, 1, bias=False) 77 | ) 78 | else: 79 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False) 80 | self.bn1 = norm_layer(self.inplanes) 81 | self.relu = nn.ReLU(True) 82 | self.maxpool = nn.MaxPool2d(3, 2, 1) 83 | 84 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer,use_dcn=False) 85 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer,use_dcn=True) 86 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer,use_dcn=True) 87 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer,use_dcn=True) 88 | 89 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 90 | self.fc = nn.Linear(512 * block.expansion, num_classes) 91 | 92 | for m in self.modules(): 93 | if isinstance(m, nn.Conv2d): 94 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 95 | elif isinstance(m, nn.BatchNorm2d): 96 | nn.init.constant_(m.weight, 1) 97 | nn.init.constant_(m.bias, 0) 98 | 99 | if zero_init_residual: 100 | for m in self.modules(): 101 | if isinstance(m, Bottleneck): 102 | nn.init.constant_(m.bn3.weight, 0) 103 | 104 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d,use_dcn=False): 105 | downsample = None 106 | if stride != 1 or self.inplanes != planes * block.expansion: 107 | downsample = nn.Sequential( 108 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False), 109 | norm_layer(planes * block.expansion) 110 | ) 111 | 112 | layers = list() 113 | if dilation in (1, 2): 114 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 115 | self.base_width, norm_layer=norm_layer,use_dcn=use_dcn)) 116 | elif dilation == 4: 117 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 118 | self.base_width, dilation=2, norm_layer=norm_layer,use_dcn=use_dcn)) 119 | else: 120 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 121 | 122 | self.inplanes = planes * block.expansion 123 | for _ in range(1, blocks): 124 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, 125 | dilation=1, norm_layer=norm_layer,use_dcn=use_dcn)) 126 | 127 | return nn.Sequential(*layers) 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | x = self.maxpool(x) 134 | 135 | x = self.layer1(x) 136 | x = self.layer2(x) 137 | x = self.layer3(x) 138 | x = self.layer4(x) 139 | 140 | x = self.avgpool(x) 141 | x = x.view(x.size(0), -1) 142 | x = self.fc(x) 143 | 144 | return x 145 | 146 | def resnext101_32x8d_dcn(pretrained=False, **kwargs): 147 | kwargs['groups'] = 32 148 | kwargs['width_per_group'] = 8 149 | model = ResNext(Bottleneck, [3, 4, 23, 3], **kwargs) 150 | if pretrained: 151 | old_dict = model_zoo.load_url(model_urls['resnext101_32x8d']) 152 | 153 | model_dict = model.state_dict() 154 | for i,(k,v) in enumerate(old_dict.items()): 155 | if k not in model_dict: 
156 | print(k) 157 | pass 158 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 159 | model_dict.update(old_dict) 160 | model.load_state_dict(model_dict) 161 | 162 | return model 163 | 164 | 165 | if __name__ == '__main__': 166 | model = resnext101_32x8d_dcn() -------------------------------------------------------------------------------- /models/backbone/res2next.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import init 6 | import torch 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | __all__ = ['res2next50'] 10 | model_urls = { 11 | 'res2next50': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2next50_4s-6ef7e7bf.pth', 12 | } 13 | 14 | class Bottle2neckX(nn.Module): 15 | expansion = 4 16 | 17 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None, scale = 4, stype='normal'): 18 | """ Constructor 19 | Args: 20 | inplanes: input channel dimensionality 21 | planes: output channel dimensionality 22 | baseWidth: base width. 23 | cardinality: num of convolution groups. 24 | stride: conv stride. Replaces pooling layer. 25 | scale: number of scale. 26 | type: 'normal': normal set. 'stage': frist blokc of a new stage. 27 | """ 28 | super(Bottle2neckX, self).__init__() 29 | 30 | D = int(math.floor(planes * (baseWidth/64.0))) 31 | C = cardinality 32 | 33 | self.conv1 = nn.Conv2d(inplanes, D*C*scale, kernel_size=1, stride=1, padding=0, bias=False) 34 | self.bn1 = nn.BatchNorm2d(D*C*scale) 35 | 36 | if scale == 1: 37 | self.nums = 1 38 | else: 39 | self.nums = scale -1 40 | if stype == 'stage': 41 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) 42 | convs = [] 43 | bns = [] 44 | for i in range(self.nums): 45 | convs.append(nn.Conv2d(D*C, D*C, kernel_size=3, stride = stride, padding=1, groups=C, bias=False)) 46 | bns.append(nn.BatchNorm2d(D*C)) 47 | self.convs = nn.ModuleList(convs) 48 | self.bns = nn.ModuleList(bns) 49 | 50 | self.conv3 = nn.Conv2d(D*C*scale, planes * 4, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * 4) 52 | self.relu = nn.ReLU(inplace=True) 53 | 54 | self.downsample = downsample 55 | self.width = D*C 56 | self.stype = stype 57 | self.scale = scale 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | spx = torch.split(out, self.width, 1) 67 | for i in range(self.nums): 68 | if i==0 or self.stype=='stage': 69 | sp = spx[i] 70 | else: 71 | sp = sp + spx[i] 72 | sp = self.convs[i](sp) 73 | sp = self.relu(self.bns[i](sp)) 74 | if i==0: 75 | out = sp 76 | else: 77 | out = torch.cat((out, sp), 1) 78 | if self.scale != 1 and self.stype=='normal': 79 | out = torch.cat((out, spx[self.nums]),1) 80 | elif self.scale != 1 and self.stype=='stage': 81 | out = torch.cat((out, self.pool(spx[self.nums])),1) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class Res2NeXt(nn.Module): 96 | def __init__(self, block, baseWidth, cardinality, layers, num_classes, scale=4): 97 | """ Constructor 98 | Args: 99 | baseWidth: baseWidth for ResNeXt. 100 | cardinality: number of convolution groups. 
101 | layers: config of layers, e.g., [3, 4, 6, 3] 102 | num_classes: number of classes 103 | scale: scale in res2net 104 | """ 105 | super(Res2NeXt, self).__init__() 106 | 107 | self.cardinality = cardinality 108 | self.baseWidth = baseWidth 109 | self.num_classes = num_classes 110 | self.inplanes = 64 111 | self.output_size = 64 112 | self.scale = scale 113 | 114 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False) 115 | self.bn1 = nn.BatchNorm2d(64) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 118 | self.layer1 = self._make_layer(block, 64, layers[0]) 119 | self.layer2 = self._make_layer(block, 128, layers[1], 2) 120 | self.layer3 = self._make_layer(block, 256, layers[2], 2) 121 | self.layer4 = self._make_layer(block, 512, layers[3], 2) 122 | self.avgpool = nn.AdaptiveAvgPool2d(1) 123 | self.fc = nn.Linear(512 * block.expansion, num_classes) 124 | 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 128 | m.weight.data.normal_(0, math.sqrt(2. / n)) 129 | elif isinstance(m, nn.BatchNorm2d): 130 | m.weight.data.fill_(1) 131 | m.bias.data.zero_() 132 | 133 | def _make_layer(self, block, planes, blocks, stride=1): 134 | downsample = None 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample = nn.Sequential( 137 | nn.Conv2d(self.inplanes, planes * block.expansion, 138 | kernel_size=1, stride=stride, bias=False), 139 | nn.BatchNorm2d(planes * block.expansion), 140 | ) 141 | 142 | layers = [] 143 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample, scale=self.scale, stype='stage')) 144 | self.inplanes = planes * block.expansion 145 | for i in range(1, blocks): 146 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, scale=self.scale)) 147 | 148 | return nn.Sequential(*layers) 149 | 150 | def forward(self, x): 151 | x = self.conv1(x) 152 | x = self.bn1(x) 153 | x = self.relu(x) 154 | x = self.maxpool1(x) 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x = self.layer4(x) 159 | x = self.avgpool(x) 160 | x = x.view(x.size(0), -1) 161 | x = self.fc(x) 162 | 163 | return x 164 | def res2next50(pretrained=False, **kwargs): 165 | """ Construct Res2NeXt-50. 166 | The default scale is 4. 
167 | Args: 168 | pretrained (bool): If True, returns a model pre-trained on ImageNet 169 | """ 170 | model = Res2NeXt(Bottle2neckX, layers = [3, 4, 6, 3], baseWidth = 4, cardinality=8, scale = 4, num_classes=1000) 171 | if pretrained: 172 | model.load_state_dict(model_zoo.load_url(model_urls['res2next50'])) 173 | return model 174 | 175 | if __name__ == '__main__': 176 | images = torch.rand(1, 3, 224, 224).cuda(0) 177 | model = res2next50(pretrained=True) 178 | model = model.cuda(0) 179 | print(model(images).size()) 180 | -------------------------------------------------------------------------------- /utils/fmix.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | from scipy.stats import beta 6 | 7 | 8 | def fftfreqnd(h, w=None, z=None): 9 | """ Get bin values for discrete fourier transform of size (h, w, z) 10 | :param h: Required, first dimension size 11 | :param w: Optional, second dimension size 12 | :param z: Optional, third dimension size 13 | """ 14 | fz = fx = 0 15 | fy = np.fft.fftfreq(h) 16 | 17 | if w is not None: 18 | fy = np.expand_dims(fy, -1) 19 | 20 | if w % 2 == 1: 21 | fx = np.fft.fftfreq(w)[: w // 2 + 2] 22 | else: 23 | fx = np.fft.fftfreq(w)[: w // 2 + 1] 24 | 25 | if z is not None: 26 | fy = np.expand_dims(fy, -1) 27 | if z % 2 == 1: 28 | fz = np.fft.fftfreq(z)[:, None] 29 | else: 30 | fz = np.fft.fftfreq(z)[:, None] 31 | 32 | return np.sqrt(fx * fx + fy * fy + fz * fz) 33 | 34 | 35 | def get_spectrum(freqs, decay_power, ch, h, w=0, z=0): 36 | """ Samples a fourier image with given size and frequencies decayed by decay power 37 | :param freqs: Bin values for the discrete fourier transform 38 | :param decay_power: Decay power for frequency decay prop 1/f**d 39 | :param ch: Number of channels for the resulting mask 40 | :param h: Required, first dimension size 41 | :param w: Optional, second dimension size 42 | :param z: Optional, third dimension size 43 | """ 44 | scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)])) ** decay_power) 45 | 46 | param_size = [ch] + list(freqs.shape) + [2] 47 | param = np.random.randn(*param_size) 48 | 49 | scale = np.expand_dims(scale, -1)[None, :] 50 | 51 | return scale * param 52 | 53 | 54 | def make_low_freq_image(decay, shape, ch=1): 55 | """ Sample a low frequency image from fourier space 56 | :param decay_power: Decay power for frequency decay prop 1/f**d 57 | :param shape: Shape of desired mask, list up to 3 dims 58 | :param ch: Number of channels for desired mask 59 | """ 60 | freqs = fftfreqnd(*shape) 61 | spectrum = get_spectrum(freqs, decay, ch, *shape)#.reshape((1, *shape[:-1], -1)) 62 | spectrum = spectrum[:, 0] + 1j * spectrum[:, 1] 63 | mask = np.real(np.fft.irfftn(spectrum, shape)) 64 | 65 | if len(shape) == 1: 66 | mask = mask[:1, :shape[0]] 67 | if len(shape) == 2: 68 | mask = mask[:1, :shape[0], :shape[1]] 69 | if len(shape) == 3: 70 | mask = mask[:1, :shape[0], :shape[1], :shape[2]] 71 | 72 | mask = mask 73 | mask = (mask - mask.min()) 74 | mask = mask / mask.max() 75 | return mask 76 | 77 | 78 | def sample_lam(alpha, reformulate=False): 79 | """ Sample a lambda from symmetric beta distribution with given alpha 80 | :param alpha: Alpha value for beta distribution 81 | :param reformulate: If True, uses the reformulation of [1]. 
82 | """ 83 | if reformulate: 84 | lam = beta.rvs(alpha+1, alpha) 85 | else: 86 | lam = beta.rvs(alpha, alpha) 87 | 88 | return lam 89 | 90 | 91 | def binarise_mask(mask, lam, in_shape, max_soft=0.0): 92 | """ Binarises a given low frequency image such that it has mean lambda. 93 | :param mask: Low frequency image, usually the result of `make_low_freq_image` 94 | :param lam: Mean value of final mask 95 | :param in_shape: Shape of inputs 96 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. 97 | :return: 98 | """ 99 | idx = mask.reshape(-1).argsort()[::-1] 100 | mask = mask.reshape(-1) 101 | num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(lam * mask.size) 102 | 103 | eff_soft = max_soft 104 | if max_soft > lam or max_soft > (1-lam): 105 | eff_soft = min(lam, 1-lam) 106 | 107 | soft = int(mask.size * eff_soft) 108 | num_low = num - soft 109 | num_high = num + soft 110 | 111 | mask[idx[:num_high]] = 1 112 | mask[idx[num_low:]] = 0 113 | mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low)) 114 | 115 | mask = mask.reshape((1, *in_shape)) 116 | return mask 117 | 118 | 119 | def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False): 120 | """ Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises 121 | it based on this lambda 122 | :param alpha: Alpha value for beta distribution from which to sample mean of mask 123 | :param decay_power: Decay power for frequency decay prop 1/f**d 124 | :param shape: Shape of desired mask, list up to 3 dims 125 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. 126 | :param reformulate: If True, uses the reformulation of [1]. 127 | """ 128 | if isinstance(shape, int): 129 | shape = (shape,) 130 | 131 | # Choose lambda 132 | lam = sample_lam(alpha, reformulate) 133 | 134 | # Make mask, get mean / std 135 | mask = make_low_freq_image(decay_power, shape) 136 | mask = binarise_mask(mask, lam, shape, max_soft) 137 | 138 | return lam, mask 139 | 140 | 141 | def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False): 142 | """ 143 | :param x: Image batch on which to apply fmix of shape [b, c, shape*] 144 | :param alpha: Alpha value for beta distribution from which to sample mean of mask 145 | :param decay_power: Decay power for frequency decay prop 1/f**d 146 | :param shape: Shape of desired mask, list up to 3 dims 147 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. 148 | :param reformulate: If True, uses the reformulation of [1]. 149 | :return: mixed input, permutation indices, lambda value of mix, 150 | """ 151 | lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate) 152 | index = np.random.permutation(x.shape[0]) 153 | 154 | x1, x2 = x * mask, x[index] * (1-mask) 155 | return x1+x2, index, lam 156 | 157 | 158 | class FMixBase: 159 | r""" FMix augmentation 160 | Args: 161 | decay_power (float): Decay power for frequency decay prop 1/f**d 162 | alpha (float): Alpha value for beta distribution from which to sample mean of mask 163 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims 164 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask. 165 | reformulate (bool): If True, uses the reformulation of [1]. 
166 | """ 167 | 168 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False): 169 | super().__init__() 170 | self.decay_power = decay_power 171 | self.reformulate = reformulate 172 | self.size = size 173 | self.alpha = alpha 174 | self.max_soft = max_soft 175 | self.index = None 176 | self.lam = None 177 | 178 | def __call__(self, x): 179 | raise NotImplementedError 180 | 181 | def loss(self, *args, **kwargs): 182 | raise NotImplementedError 183 | 184 | 185 | -------------------------------------------------------------------------------- /utils/_aug.py: -------------------------------------------------------------------------------- 1 | from imgaug import augmenters as iaa 2 | import cv2 3 | import numpy as np 4 | from scipy.ndimage import measurements 5 | """ 6 | from Hover-Net augmentation 7 | https://github.com/vqdang/hover_net/blob/master/dataloader/augs.py 8 | """ 9 | 10 | #### 11 | def fix_mirror_padding(ann): 12 | """Deal with duplicated instances due to mirroring in interpolation 13 | during shape augmentation (scale, rotation etc.). 14 | 15 | """ 16 | current_max_id = np.amax(ann) 17 | inst_list = list(np.unique(ann)) 18 | inst_list.remove(0) # 0 is background 19 | for inst_id in inst_list: 20 | inst_map = np.array(ann == inst_id, np.uint8) 21 | remapped_ids = measurements.label(inst_map)[0] 22 | remapped_ids[remapped_ids > 1] += current_max_id 23 | ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1] 24 | current_max_id = np.amax(ann) 25 | return ann 26 | 27 | 28 | #### 29 | def gaussian_blur(images, random_state, parents, hooks, max_ksize=3): 30 | """Apply Gaussian blur to input images.""" 31 | img = images[0] # aleju input batch as default (always=1 in our case) 32 | ksize = random_state.randint(0, max_ksize, size=(2,)) 33 | ksize = tuple((ksize * 2 + 1).tolist()) 34 | 35 | ret = cv2.GaussianBlur( 36 | img, ksize, sigmaX=0, sigmaY=0, borderType=cv2.BORDER_REPLICATE 37 | ) 38 | ret = np.reshape(ret, img.shape) 39 | ret = ret.astype(np.uint8) 40 | return [ret] 41 | 42 | 43 | #### 44 | def median_blur(images, random_state, parents, hooks, max_ksize=3): 45 | """Apply median blur to input images.""" 46 | img = images[0] # aleju input batch as default (always=1 in our case) 47 | ksize = random_state.randint(0, max_ksize) 48 | ksize = ksize * 2 + 1 49 | ret = cv2.medianBlur(img, ksize) 50 | ret = ret.astype(np.uint8) 51 | return [ret] 52 | 53 | 54 | #### 55 | def add_to_hue(images, random_state, parents, hooks, range=(-8, 8)): 56 | """Perturbe the hue of input images.""" 57 | img = images[0] # aleju input batch as default (always=1 in our case) 58 | hue = random_state.uniform(*range) 59 | hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) 60 | if hsv.dtype.itemsize == 1: 61 | # OpenCV uses 0-179 for 8-bit images 62 | hsv[..., 0] = (hsv[..., 0] + hue) % 180 63 | else: 64 | # OpenCV uses 0-360 for floating point images 65 | hsv[..., 0] = (hsv[..., 0] + 2 * hue) % 360 66 | ret = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) 67 | ret = ret.astype(np.uint8) 68 | return [ret] 69 | 70 | 71 | #### 72 | def add_to_saturation(images, random_state, parents, hooks, range=(-0.2, 0.2)): 73 | """Perturbe the saturation of input images.""" 74 | img = images[0] # aleju input batch as default (always=1 in our case) 75 | value = 1 + random_state.uniform(*range) 76 | gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 77 | ret = img * value + (gray * (1 - value))[:, :, np.newaxis] 78 | ret = np.clip(ret, 0, 255) 79 | ret = ret.astype(np.uint8) 80 | return [ret] 81 | 82 | 83 | #### 84 | def 
add_to_contrast(images, random_state, parents, hooks, range=(0.75, 1.25)): 85 | """Perturbe the contrast of input images.""" 86 | img = images[0] # aleju input batch as default (always=1 in our case) 87 | value = random_state.uniform(*range) 88 | mean = np.mean(img, axis=(0, 1), keepdims=True) 89 | ret = img * value + mean * (1 - value) 90 | ret = np.clip(img, 0, 255) 91 | ret = ret.astype(np.uint8) 92 | return [ret] 93 | 94 | 95 | #### 96 | def add_to_brightness(images, random_state, parents, hooks, range=(-26, 26)): 97 | """Perturbe the brightness of input images.""" 98 | img = images[0] # aleju input batch as default (always=1 in our case) 99 | value = random_state.uniform(*range) 100 | ret = np.clip(img + value, 0, 255) 101 | ret = ret.astype(np.uint8) 102 | return [ret] 103 | 104 | 105 | def get_augmentation(mode, rng,input_shape=(256,256)): 106 | if mode == "train": 107 | print('Using train augmentation') 108 | shape_augs = [ 109 | # * order = ``0`` -> ``cv2.INTER_NEAREST`` 110 | # * order = ``1`` -> ``cv2.INTER_LINEAR`` 111 | # * order = ``2`` -> ``cv2.INTER_CUBIC`` 112 | # * order = ``3`` -> ``cv2.INTER_CUBIC`` 113 | # * order = ``4`` -> ``cv2.INTER_CUBIC`` 114 | # ! for pannuke v0, no rotation or translation, just flip to avoid mirror padding 115 | iaa.Affine( 116 | # scale images to 80-120% of their size, individually per axis 117 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, 118 | # translate by -A to +A percent (per axis) 119 | translate_percent={"x": (-0.01, 0.01), "y": (-0.01, 0.01)}, 120 | shear=(-5, 5), # shear by -5 to +5 degrees 121 | rotate=(-179, 179), # rotate by -179 to +179 degrees 122 | order=0, # default 0 use nearest neighbour 123 | backend="cv2", # opencv for fast processing 124 | seed=rng, 125 | ), 126 | # set position to 'center' for center crop 127 | # else 'uniform' for random crop 128 | iaa.CropToFixedSize( 129 | input_shape[0], input_shape[1], position="center" 130 | ), 131 | iaa.Fliplr(0.5, seed=rng), 132 | iaa.Flipud(0.5, seed=rng), 133 | ] 134 | 135 | input_augs = [ 136 | iaa.OneOf( 137 | [ 138 | iaa.Lambda( 139 | seed=rng, 140 | func_images= gaussian_blur, 141 | ), 142 | iaa.Lambda( 143 | seed=rng, 144 | func_images= median_blur, 145 | ), 146 | iaa.AdditiveGaussianNoise( 147 | loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5 148 | ), 149 | ] 150 | ), 151 | iaa.Sequential( 152 | [ 153 | iaa.Lambda( 154 | seed=rng, 155 | func_images=add_to_hue, 156 | ), 157 | iaa.Lambda( 158 | seed=rng, 159 | func_images=add_to_saturation 160 | , 161 | ), 162 | iaa.Lambda( 163 | seed=rng, 164 | func_images=add_to_brightness 165 | , 166 | ), 167 | iaa.Lambda( 168 | seed=rng, 169 | func_images= add_to_contrast, 170 | ), 171 | ], 172 | random_order=True, 173 | ), 174 | ] 175 | elif mode == "test": 176 | print('Using test augmentation') 177 | shape_augs = [ 178 | # set position to 'center' for center crop 179 | # else 'uniform' for random crop 180 | iaa.CropToFixedSize( 181 | input_shape[0], input_shape[1], position="center" 182 | ) 183 | ] 184 | input_augs = [] 185 | 186 | return iaa.Sequential(shape_augs), iaa.Sequential(input_augs) -------------------------------------------------------------------------------- /models/PointNuNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | from .necks.jpfm import JPFM 8 | from .base_ops.CoordConv import CoordConv2d 9 | from .necks.fpn import FPN 10 | 11 | class 
PointNuNet(SegBaseModel): 12 | def __init__(self, nclass, backbone='resnet50', pretrained_base=True, frozen_stages=-1,norm_eval=False, seg_feat_channels=256, stacked_convs=7,ins_out_channels=256, kernel_size=1,output_stride=4, **kwargs): 13 | super(PointNuNet, self).__init__(backbone, pretrained_base=pretrained_base, frozen_stages=frozen_stages,norm_eval=norm_eval, **kwargs) 14 | 15 | if 'res' in backbone: 16 | self.fpn=FPN() 17 | self.forward=self.forward_res 18 | elif 'hrnet' in backbone: 19 | if '32' in backbone: 20 | c=480 21 | elif '64' in backbone: 22 | c=960 23 | elif '18' in backbone: 24 | c=270 25 | else: 26 | raise NotImplementedError 27 | self.jpfm_1=JPFM(in_channel=c) 28 | self.jpfm_2=JPFM(in_channel=c) 29 | self.jpfm_3=JPFM(in_channel=c) 30 | self.forward=self.forward_hrnet 31 | elif 'swin' in backbone: 32 | self.fpn=FPN(channels=[192,384,768,1536]) 33 | self.forward=self.forward_swin 34 | else: 35 | raise NotImplementedError 36 | self.output_stride=output_stride 37 | self.heads=_PointNuNetHead(num_classes=nclass, 38 | in_channels=1024, 39 | seg_feat_channels=seg_feat_channels, 40 | stacked_convs=stacked_convs, 41 | ins_out_channels=ins_out_channels, 42 | kernel_size=kernel_size) 43 | 44 | def forward_swin(self, x): 45 | c2, c3, c4, c5 = self.pretrained(x) 46 | c2, c3, c4, c5=self.fpn(c2, c3, c4, c5) 47 | x0_h, x0_w = c2.size(2), c2.size(3) 48 | c3 = F.interpolate(c3, size=(x0_h, x0_w), mode='bilinear', align_corners=True) 49 | c4 = F.interpolate(c4, size=(x0_h, x0_w), mode='bilinear', align_corners=True) 50 | c5 = F.interpolate(c5, size=(x0_h, x0_w), mode='bilinear', align_corners=True) 51 | cat_x = torch.cat([c2,c3,c4,c5], 1) 52 | 53 | output=self.heads(cat_x, cat_x, cat_x) 54 | return output 55 | 56 | def forward_res(self, x): 57 | c1, c2, c3, c4, c5 = self.base_forward(x) 58 | c2, c3, c4, c5=self.fpn(c2, c3, c4, c5) 59 | x0_h, x0_w = c2.size(2), c2.size(3) 60 | c3 = F.interpolate(c3, size=(x0_h, x0_w), mode='bilinear', align_corners=True) 61 | c4 = F.interpolate(c4, size=(x0_h, x0_w), mode='bilinear', align_corners=True) 62 | c5 = F.interpolate(c5, size=(x0_h, x0_w), mode='bilinear', align_corners=True) 63 | cat_x = torch.cat([c2,c3,c4,c5], 1) 64 | 65 | output=self.heads(cat_x, cat_x, cat_x) 66 | return output 67 | 68 | def forward_hrnet(self, x): 69 | c2,c3,c4,c5 = self.pretrained(x) 70 | x0_h, x0_w = c2.size(2), c2.size(3) 71 | c3 = F.interpolate(c3, size=(x0_h, x0_w), mode='bilinear', align_corners=True) 72 | c4 = F.interpolate(c4, size=(x0_h, x0_w), mode='bilinear', align_corners=True) 73 | c5 = F.interpolate(c5, size=(x0_h, x0_w), mode='bilinear', align_corners=True) 74 | cat_x = torch.cat([c2,c3,c4,c5], 1) 75 | 76 | f1=self.jpfm_1(cat_x) 77 | f2=self.jpfm_2(cat_x) 78 | f3=self.jpfm_3(cat_x) 79 | if self.output_stride!=4: 80 | f2=F.interpolate(f2, size=(256//self.output_stride, 256//self.output_stride), mode='bilinear', align_corners=True) 81 | f3=F.interpolate(f3, size=(256//self.output_stride, 256//self.output_stride), mode='bilinear', align_corners=True) 82 | output=self.heads(f1,f2,f3) 83 | 84 | return output 85 | 86 | class _PointNuNetHead(nn.Module): 87 | def __init__(self,num_classes, 88 | in_channels=256*4, 89 | seg_feat_channels=256, 90 | stacked_convs=7, 91 | ins_out_channels=256, 92 | kernel_size=1 93 | ): 94 | 95 | super(_PointNuNetHead,self).__init__() 96 | self.num_classes = num_classes 97 | self.cate_out_channels = self.num_classes - 1 98 | self.in_channels = in_channels 99 | self.stacked_convs = stacked_convs 100 | self.seg_feat_channels = 
seg_feat_channels 101 | self.seg_out_channels = ins_out_channels 102 | self.ins_out_channels = ins_out_channels 103 | self.kernel_out_channels = (self.ins_out_channels * kernel_size * kernel_size) 104 | 105 | self._init_layers() 106 | self.init_weight() 107 | 108 | def _init_layers(self): 109 | self.mask_convs = nn.ModuleList() 110 | self.kernel_convs = nn.ModuleList() 111 | self.cate_convs = nn.ModuleList() 112 | 113 | for i in range(self.stacked_convs): 114 | chn = self.in_channels if i == 0 else self.seg_feat_channels 115 | conv = CoordConv2d if i ==0 else nn.Conv2d 116 | self.kernel_convs.append(nn.Sequential( 117 | conv(chn, self.seg_feat_channels, 3, 1, 1, bias=False), 118 | nn.BatchNorm2d(self.seg_feat_channels), 119 | nn.ReLU(True), 120 | )) 121 | 122 | chn = self.in_channels if i == 0 else self.seg_feat_channels 123 | self.cate_convs.append(nn.Sequential( 124 | nn.Conv2d(chn, self.seg_feat_channels, 3, 1, 1, bias=False), 125 | nn.BatchNorm2d(self.seg_feat_channels), 126 | nn.ReLU(True), 127 | )) 128 | 129 | self.head_kernel = nn.Conv2d(self.seg_feat_channels, self.kernel_out_channels, 1,padding=0) 130 | self.head_cate = nn.Conv2d(self.seg_feat_channels, self.cate_out_channels, 3, padding=1) 131 | 132 | self.mask_convs.append(nn.Sequential( 133 | nn.Conv2d(self.in_channels, self.seg_feat_channels, 3, 1, 1, bias=False), 134 | nn.BatchNorm2d(self.seg_feat_channels), 135 | nn.ReLU(True), 136 | nn.Conv2d(self.seg_feat_channels, self.seg_feat_channels, 3, 1, 1, bias=False), 137 | nn.BatchNorm2d(self.seg_feat_channels), 138 | nn.ReLU(True),)) 139 | 140 | self.mask_convs.append(nn.Sequential( 141 | nn.ConvTranspose2d(self.seg_feat_channels, self.seg_feat_channels, 4, 2, padding=1, output_padding=0,bias=False), 142 | nn.BatchNorm2d(self.seg_feat_channels), 143 | nn.ReLU(True), 144 | nn.Conv2d(self.seg_feat_channels, self.seg_feat_channels, 3, 1, 1, bias=False), 145 | nn.BatchNorm2d(self.seg_feat_channels), 146 | nn.ReLU(True),)) 147 | 148 | self.mask_convs.append(nn.Sequential( 149 | nn.ConvTranspose2d(self.seg_feat_channels, self.seg_feat_channels, 4, 2, padding=1, output_padding=0,bias=False), 150 | nn.BatchNorm2d(self.seg_feat_channels), 151 | nn.ReLU(True), 152 | nn.Conv2d(self.seg_feat_channels, self.seg_feat_channels, 3, 1, 1, bias=False), 153 | nn.BatchNorm2d(self.seg_feat_channels), 154 | nn.ReLU(True))) 155 | 156 | self.head_mask = nn.Sequential( 157 | nn.Conv2d(self.seg_feat_channels, self.seg_out_channels, 1, padding=0, bias=False), 158 | nn.BatchNorm2d(self.seg_out_channels), 159 | nn.ReLU(True)) 160 | 161 | def init_weight(self): 162 | prior_prob = 0.01 163 | bias_init = float(-math.log((1 - prior_prob) / prior_prob)) 164 | torch.nn.init.normal_(self.head_cate.weight, std=0.01) 165 | torch.nn.init.constant_(self.head_cate.bias, bias_init) 166 | 167 | def forward(self, feats,f2,f3): 168 | # feature branch 169 | mask_feat=feats 170 | for i, mask_layer in enumerate(self.mask_convs): 171 | mask_feat = mask_layer(mask_feat) 172 | feature_pred = self.head_mask(mask_feat) 173 | 174 | # kernel branch 175 | kernel_feat=f2 176 | for i, kernel_layer in enumerate(self.kernel_convs): 177 | kernel_feat = kernel_layer(kernel_feat) 178 | kernel_pred = self.head_kernel(kernel_feat) 179 | 180 | # cate branch 181 | cate_feat=f3 182 | for i, cate_layer in enumerate(self.cate_convs): 183 | cate_feat = cate_layer(cate_feat) 184 | cate_pred = self.head_cate(cate_feat) 185 | return feature_pred, kernel_pred, cate_pred 186 | 187 | 188 | 
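# Illustrative shape-check sketch for _PointNuNetHead (assumptions: a 256x256 input at
# output stride 4, so the concatenated FPN features form a 1024-channel, 64x64 map;
# the batch size and num_classes=6 are arbitrary values chosen only for this example).
if __name__ == '__main__':
    head = _PointNuNetHead(num_classes=6, in_channels=1024, seg_feat_channels=256,
                           stacked_convs=7, ins_out_channels=256, kernel_size=1)
    feats = torch.randn(1, 1024, 64, 64)
    feature_pred, kernel_pred, cate_pred = head(feats, feats, feats)
    print(feature_pred.shape)  # [1, 256, 256, 256]: two stride-2 transposed convs upsample 4x
    print(kernel_pred.shape)   # [1, 256, 64, 64]: ins_out_channels * kernel_size**2 channels
    print(cate_pred.shape)     # [1, 5, 64, 64]: num_classes - 1 category logits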
-------------------------------------------------------------------------------- /models/backbone/res2netv1b.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | import torch.nn.functional as F 7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b'] 8 | 9 | 10 | model_urls = { 11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 13 | } 14 | 15 | 16 | class Bottle2neck(nn.Module): 17 | expansion = 4 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'): 20 | """ Constructor 21 | Args: 22 | inplanes: input channel dimensionality 23 | planes: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 25 | downsample: None when stride = 1 26 | baseWidth: basic width of conv3x3 27 | scale: number of scale. 28 | type: 'normal': normal set. 'stage': first block of a new stage. 29 | """ 30 | super(Bottle2neck, self).__init__() 31 | 32 | width = int(math.floor(planes * (baseWidth/64.0))) 33 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(width*scale) 35 | 36 | if scale == 1: 37 | self.nums = 1 38 | else: 39 | self.nums = scale -1 40 | if stype == 'stage': 41 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) 42 | convs = [] 43 | bns = [] 44 | for i in range(self.nums): 45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False)) 46 | bns.append(nn.BatchNorm2d(width)) 47 | self.convs = nn.ModuleList(convs) 48 | self.bns = nn.ModuleList(bns) 49 | 50 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 52 | 53 | self.relu = nn.ReLU(inplace=True) 54 | self.downsample = downsample 55 | self.stype = stype 56 | self.scale = scale 57 | self.width = width 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | spx = torch.split(out, self.width, 1) 67 | for i in range(self.nums): 68 | if i==0 or self.stype=='stage': 69 | sp = spx[i] 70 | else: 71 | sp = sp + spx[i] 72 | sp = self.convs[i](sp) 73 | sp = self.relu(self.bns[i](sp)) 74 | if i==0: 75 | out = sp 76 | else: 77 | out = torch.cat((out, sp), 1) 78 | if self.scale != 1 and self.stype=='normal': 79 | out = torch.cat((out, spx[self.nums]),1) 80 | elif self.scale != 1 and self.stype=='stage': 81 | out = torch.cat((out, self.pool(spx[self.nums])),1) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class Res2Net(nn.Module): 95 | 96 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): 97 | self.inplanes = 64 98 | super(Res2Net, self).__init__() 99 | self.baseWidth = baseWidth 100 | self.scale = scale 101 | self.conv1 = nn.Sequential( 102 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 103 | nn.BatchNorm2d(32), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 106 | nn.BatchNorm2d(32), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(32, 64, 3, 1, 
1, bias=False) 109 | ) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU() 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0]) 114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 117 | self.avgpool = nn.AdaptiveAvgPool2d(1) 118 | self.fc = nn.Linear(512 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | elif isinstance(m, nn.BatchNorm2d): 124 | nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | nn.AvgPool2d(kernel_size=stride, stride=stride, 132 | ceil_mode=True, count_include_pad=False), 133 | nn.Conv2d(self.inplanes, planes * block.expansion, 134 | kernel_size=1, stride=1, bias=False), 135 | nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 140 | stype='stage', baseWidth = self.baseWidth, scale=self.scale)) 141 | self.inplanes = planes * block.expansion 142 | for i in range(1, blocks): 143 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale)) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def forward(self, x): 148 | x = self.conv1(x) 149 | x = self.bn1(x) 150 | x = self.relu(x) 151 | x = self.maxpool(x) 152 | 153 | x = self.layer1(x) 154 | x = self.layer2(x) 155 | x = self.layer3(x) 156 | x = self.layer4(x) 157 | 158 | x = self.avgpool(x) 159 | x = x.view(x.size(0), -1) 160 | x = self.fc(x) 161 | 162 | return x 163 | 164 | 165 | def res2net50_v1b(pretrained=False, **kwargs): 166 | """Constructs a Res2Net-50_v1b model. 167 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | """ 171 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) 172 | if pretrained: 173 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 174 | return model 175 | 176 | def res2net101_v1b(pretrained=False, **kwargs): 177 | """Constructs a Res2Net-50_v1b_26w_4s model. 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs) 182 | if pretrained: 183 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 184 | return model 185 | 186 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 187 | """Constructs a Res2Net-50_v1b_26w_4s model. 188 | Args: 189 | pretrained (bool): If True, returns a model pre-trained on ImageNet 190 | """ 191 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) 192 | if pretrained: 193 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 194 | return model 195 | 196 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 197 | """Constructs a Res2Net-50_v1b_26w_4s model. 
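    Builds the 101-layer v1b variant: [3, 4, 23, 3] blocks with baseWidth 26 and scale 4.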
198 | Args: 199 | pretrained (bool): If True, returns a model pre-trained on ImageNet 200 | """ 201 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs) 202 | if pretrained: 203 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 204 | return model 205 | 206 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 207 | """Constructs a Res2Net-50_v1b_26w_4s model. 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth = 26, scale = 4, **kwargs) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s'])) 214 | return model 215 | 216 | 217 | 218 | 219 | 220 | if __name__ == '__main__': 221 | images = torch.rand(1, 3, 224, 224).cuda(0) 222 | model = res2net50_v1b_26w_4s(pretrained=True) 223 | model = model.cuda(0) 224 | print(model(images).size()) 225 | -------------------------------------------------------------------------------- /losses/lovasz_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | """ 5 | 6 | from __future__ import print_function, division 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import numpy as np 12 | 13 | try: 14 | from itertools import ifilterfalse 15 | except ImportError: # py3k 16 | from itertools import filterfalse as ifilterfalse 17 | 18 | 19 | def lovasz_grad(gt_sorted): 20 | """ 21 | Computes gradient of the Lovasz extension w.r.t sorted errors 22 | See Alg. 1 in paper 23 | """ 24 | p = len(gt_sorted) 25 | gts = gt_sorted.sum() 26 | intersection = gts - gt_sorted.float().cumsum(0) 27 | union = gts + (1 - gt_sorted).float().cumsum(0) 28 | jaccard = 1. 
- intersection / union 29 | if p > 1: # cover 1-pixel case 30 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 31 | return jaccard 32 | 33 | 34 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 35 | """ 36 | IoU for foreground class 37 | binary: 1 foreground, 0 background 38 | """ 39 | if not per_image: 40 | preds, labels = (preds,), (labels,) 41 | ious = [] 42 | for pred, label in zip(preds, labels): 43 | intersection = ((label == 1) & (pred == 1)).sum() 44 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 45 | if not union: 46 | iou = EMPTY 47 | else: 48 | iou = float(intersection) / float(union) 49 | ious.append(iou) 50 | iou = mean(ious) # mean accross images if per_image 51 | return 100 * iou 52 | 53 | 54 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 55 | """ 56 | Array of IoU for each (non ignored) class 57 | """ 58 | if not per_image: 59 | preds, labels = (preds,), (labels,) 60 | ious = [] 61 | for pred, label in zip(preds, labels): 62 | iou = [] 63 | for i in range(C): 64 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 65 | intersection = ((label == i) & (pred == i)).sum() 66 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 67 | if not union: 68 | iou.append(EMPTY) 69 | else: 70 | iou.append(float(intersection) / float(union)) 71 | ious.append(iou) 72 | ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image 73 | return 100 * np.array(ious) 74 | 75 | 76 | # --------------------------- BINARY LOSSES --------------------------- 77 | 78 | 79 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 80 | """ 81 | Binary Lovasz hinge loss 82 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 83 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 84 | per_image: compute the loss per image instead of per batch 85 | ignore: void class id 86 | """ 87 | if per_image: 88 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 89 | for log, lab in zip(logits, labels)) 90 | else: 91 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 92 | return loss 93 | 94 | 95 | def lovasz_hinge_flat(logits, labels): 96 | """ 97 | Binary Lovasz hinge loss 98 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 99 | labels: [P] Tensor, binary ground truth labels (0 or 1) 100 | ignore: label to ignore 101 | """ 102 | if len(labels) == 0: 103 | # only void pixels, the gradients should be 0 104 | return logits.sum() * 0. 105 | signs = 2. * labels.float() - 1. 106 | errors = (1. 
- logits * Variable(signs)) 107 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 108 | perm = perm.data 109 | gt_sorted = labels[perm] 110 | grad = lovasz_grad(gt_sorted) 111 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 112 | return loss 113 | 114 | 115 | def flatten_binary_scores(scores, labels, ignore=None): 116 | """ 117 | Flattens predictions in the batch (binary case) 118 | Remove labels equal to 'ignore' 119 | """ 120 | scores = scores.view(-1) 121 | labels = labels.view(-1) 122 | if ignore is None: 123 | return scores, labels 124 | valid = (labels != ignore) 125 | vscores = scores[valid] 126 | vlabels = labels[valid] 127 | return vscores, vlabels 128 | 129 | 130 | class StableBCELoss(torch.nn.modules.Module): 131 | def __init__(self): 132 | super(StableBCELoss, self).__init__() 133 | 134 | def forward(self, input, target): 135 | neg_abs = - input.abs() 136 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 137 | return loss.mean() 138 | 139 | 140 | def binary_xloss(logits, labels, ignore=None): 141 | """ 142 | Binary Cross entropy loss 143 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 144 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 145 | ignore: void class id 146 | """ 147 | logits, labels = flatten_binary_scores(logits, labels, ignore) 148 | loss = StableBCELoss()(logits, Variable(labels.float())) 149 | return loss 150 | 151 | 152 | # --------------------------- MULTICLASS LOSSES --------------------------- 153 | 154 | 155 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 156 | """ 157 | Multi-class Lovasz-Softmax loss 158 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 159 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 160 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 161 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 162 | per_image: compute the loss per image instead of per batch 163 | ignore: void class labels 164 | """ 165 | if per_image: 166 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 167 | for prob, lab in zip(probas, labels)) 168 | else: 169 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 170 | return loss 171 | 172 | 173 | def lovasz_softmax_flat(probas, labels, classes='present'): 174 | """ 175 | Multi-class Lovasz-Softmax loss 176 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 177 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 178 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 179 | """ 180 | if probas.numel() == 0: 181 | # only void pixels, the gradients should be 0 182 | return probas * 0. 
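    # Per-class Lovasz term (computed below): the errors |fg - class_pred| are sorted in
    # descending order and dotted with lovasz_grad of the correspondingly sorted ground
    # truth, then the per-class losses are averaged by mean().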
183 | C = probas.size(1) 184 | losses = [] 185 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 186 | for c in class_to_sum: 187 | fg = (labels == c).float() # foreground for class c 188 | if (classes is 'present' and fg.sum() == 0): 189 | continue 190 | if C == 1: 191 | if len(classes) > 1: 192 | raise ValueError('Sigmoid output possible only with 1 class') 193 | class_pred = probas[:, 0] 194 | else: 195 | class_pred = probas[:, c] 196 | errors = (Variable(fg) - class_pred).abs() 197 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 198 | perm = perm.data 199 | fg_sorted = fg[perm] 200 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 201 | return mean(losses) 202 | 203 | 204 | def flatten_probas(probas, labels, ignore=None): 205 | """ 206 | Flattens predictions in the batch 207 | """ 208 | if probas.dim() == 3: 209 | # assumes output of a sigmoid layer 210 | B, H, W = probas.size() 211 | probas = probas.view(B, 1, H, W) 212 | B, C, H, W = probas.size() 213 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 214 | labels = labels.view(-1) 215 | if ignore is None: 216 | return probas, labels 217 | valid = (labels != ignore) 218 | vprobas = probas[valid.nonzero().squeeze()] 219 | vlabels = labels[valid] 220 | return vprobas, vlabels 221 | 222 | 223 | def xloss(logits, labels, ignore=None): 224 | """ 225 | Cross entropy loss 226 | """ 227 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 228 | 229 | 230 | # --------------------------- HELPER FUNCTIONS --------------------------- 231 | def isnan(x): 232 | return x != x 233 | 234 | 235 | def mean(l, ignore_nan=False, empty=0): 236 | """ 237 | nanmean compatible with generators. 238 | """ 239 | l = iter(l) 240 | if ignore_nan: 241 | l = ifilterfalse(isnan, l) 242 | try: 243 | n = 1 244 | acc = next(l) 245 | except StopIteration: 246 | if empty == 'raise': 247 | raise ValueError('Empty mean') 248 | return empty 249 | for n, v in enumerate(l, 2): 250 | acc += v 251 | if n == 1: 252 | return acc 253 | return acc / n -------------------------------------------------------------------------------- /models/backbone/res2net.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | import torch.nn.functional as F 7 | __all__ = ['Res2Net', 'res2net50'] 8 | 9 | 10 | model_urls = { 11 | 'res2net50_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_4s-06e79181.pth', 12 | 'res2net50_48w_2s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_48w_2s-afed724a.pth', 13 | 'res2net50_14w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_14w_8s-6527dddc.pth', 14 | 'res2net50_26w_6s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_6s-19041792.pth', 15 | 'res2net50_26w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_8s-2c7c9f12.pth', 16 | 'res2net101_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_26w_4s-02a759a1.pth', 17 | } 18 | 19 | class Bottle2neck(nn.Module): 20 | expansion = 4 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'): 23 | """ Constructor 24 | Args: 25 | inplanes: input channel dimensionality 26 | planes: output channel dimensionality 27 | stride: conv stride. Replaces pooling layer. 
28 | downsample: None when stride = 1 29 | baseWidth: basic width of conv3x3 30 | scale: number of scale. 31 | type: 'normal': normal set. 'stage': first block of a new stage. 32 | """ 33 | super(Bottle2neck, self).__init__() 34 | 35 | width = int(math.floor(planes * (baseWidth/64.0))) 36 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False) 37 | self.bn1 = nn.BatchNorm2d(width*scale) 38 | 39 | if scale == 1: 40 | self.nums = 1 41 | else: 42 | self.nums = scale -1 43 | if stype == 'stage': 44 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) 45 | convs = [] 46 | bns = [] 47 | for i in range(self.nums): 48 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False)) 49 | bns.append(nn.BatchNorm2d(width)) 50 | self.convs = nn.ModuleList(convs) 51 | self.bns = nn.ModuleList(bns) 52 | 53 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 55 | 56 | self.relu = nn.ReLU(inplace=True) 57 | self.downsample = downsample 58 | self.stype = stype 59 | self.scale = scale 60 | self.width = width 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | spx = torch.split(out, self.width, 1) 70 | for i in range(self.nums): 71 | if i==0 or self.stype=='stage': 72 | sp = spx[i] 73 | else: 74 | sp = sp + spx[i] 75 | sp = self.convs[i](sp) 76 | sp = self.relu(self.bns[i](sp)) 77 | if i==0: 78 | out = sp 79 | else: 80 | out = torch.cat((out, sp), 1) 81 | if self.scale != 1 and self.stype=='normal': 82 | out = torch.cat((out, spx[self.nums]),1) 83 | elif self.scale != 1 and self.stype=='stage': 84 | out = torch.cat((out, self.pool(spx[self.nums])),1) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | class Res2Net(nn.Module): 98 | 99 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): 100 | self.inplanes = 64 101 | super(Res2Net, self).__init__() 102 | self.baseWidth = baseWidth 103 | self.scale = scale 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 105 | bias=False) 106 | self.bn1 = nn.BatchNorm2d(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 109 | self.layer1 = self._make_layer(block, 64, layers[0]) 110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 113 | self.avgpool = nn.AdaptiveAvgPool2d(1) 114 | self.fc = nn.Linear(512 * block.expansion, num_classes) 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 119 | elif isinstance(m, nn.BatchNorm2d): 120 | nn.init.constant_(m.weight, 1) 121 | nn.init.constant_(m.bias, 0) 122 | 123 | def _make_layer(self, block, planes, blocks, stride=1): 124 | downsample = None 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | nn.Conv2d(self.inplanes, planes * block.expansion, 128 | kernel_size=1, stride=stride, bias=False), 129 | nn.BatchNorm2d(planes * block.expansion), 130 | ) 131 | 132 | layers = [] 133 | 
layers.append(block(self.inplanes, planes, stride, downsample=downsample, 134 | stype='stage', baseWidth = self.baseWidth, scale=self.scale)) 135 | self.inplanes = planes * block.expansion 136 | for i in range(1, blocks): 137 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.relu(x) 145 | x = self.maxpool(x) 146 | 147 | x = self.layer1(x) 148 | x = self.layer2(x) 149 | x = self.layer3(x) 150 | x = self.layer4(x) 151 | 152 | x = self.avgpool(x) 153 | x = x.view(x.size(0), -1) 154 | x = self.fc(x) 155 | 156 | return x 157 | 158 | 159 | def res2net50(pretrained=False, **kwargs): 160 | """Constructs a Res2Net-50 model. 161 | Res2Net-50 refers to the Res2Net-50_26w_4s. 162 | Args: 163 | pretrained (bool): If True, returns a model pre-trained on ImageNet 164 | """ 165 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) 166 | if pretrained: 167 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s'])) 168 | return model 169 | 170 | def res2net50_26w_4s(pretrained=False, **kwargs): 171 | """Constructs a Res2Net-50_26w_4s model. 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) 176 | if pretrained: 177 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_4s'])) 178 | return model 179 | 180 | def res2net101_26w_4s(pretrained=False, **kwargs): 181 | """Constructs a Res2Net-50_26w_4s model. 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs) 186 | if pretrained: 187 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_26w_4s'])) 188 | return model 189 | 190 | def res2net50_26w_6s(pretrained=False, **kwargs): 191 | """Constructs a Res2Net-50_26w_4s model. 192 | Args: 193 | pretrained (bool): If True, returns a model pre-trained on ImageNet 194 | """ 195 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 6, **kwargs) 196 | if pretrained: 197 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_6s'])) 198 | return model 199 | 200 | def res2net50_26w_8s(pretrained=False, **kwargs): 201 | """Constructs a Res2Net-50_26w_4s model. 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 8, **kwargs) 206 | if pretrained: 207 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_26w_8s'])) 208 | return model 209 | 210 | def res2net50_48w_2s(pretrained=False, **kwargs): 211 | """Constructs a Res2Net-50_48w_2s model. 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 48, scale = 2, **kwargs) 216 | if pretrained: 217 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_48w_2s'])) 218 | return model 219 | 220 | def res2net50_14w_8s(pretrained=False, **kwargs): 221 | """Constructs a Res2Net-50_14w_8s model. 
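    Uses the same [3, 4, 6, 3] block layout as res2net50, but with baseWidth 14 and scale 8.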
222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | """ 225 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs) 226 | if pretrained: 227 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_14w_8s'])) 228 | return model 229 | 230 | 231 | 232 | if __name__ == '__main__': 233 | images = torch.rand(1, 3, 224, 224).cuda(0) 234 | model = res2net101_26w_4s(pretrained=True) 235 | model = model.cuda(0) 236 | print(model(images).size()) 237 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from trainer import Trainer 7 | 8 | try: 9 | from itertools import izip as zip 10 | except ImportError: # will be 3.x series 11 | pass 12 | import scipy.io as scio 13 | 14 | import torch 15 | import os 16 | import tifffile 17 | import numpy as np 18 | from PIL import Image 19 | from utils import get_config,_imageshow,_imagesave 20 | from skimage.util.shape import view_as_windows 21 | from utils.imop import get_ins_info 22 | from utils.metrics import get_fast_aji,get_fast_pq, get_dice_1,remap_label 23 | from torchvision import transforms 24 | import argparse 25 | import cv2 26 | from collections import Counter 27 | import time 28 | import scipy.io as scio 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--output_dir', type=str, default='outputs') 32 | parser.add_argument('--name', type=str, default='tmp') 33 | parser.add_argument('--epoch',type=int,default=100) 34 | parser.add_argument('--load_size',type=int,default=1024) 35 | parser.add_argument('--patch_size',type=int,default=256) 36 | parser.add_argument('--stride',type=int,default=128) 37 | 38 | opts = parser.parse_args() 39 | 40 | if __name__ == '__main__': 41 | 42 | opts.config=os.path.join(opts.output_dir,'{}/config.yaml'.format(opts.name)) 43 | 44 | config=get_config(opts.config) 45 | trainer = Trainer(config) 46 | trainer.cuda() 47 | 48 | load_size = opts.load_size 49 | patch_size = opts.patch_size 50 | stride = opts.stride 51 | 52 | state_path = os.path.join(opts.output_dir,'{}/checkpoints/model_{}.pt'.format(opts.name, '%04d' % (opts.epoch))) 53 | state_dict = torch.load(state_path) 54 | 55 | trainer.model.load_state_dict(state_dict['seg']) 56 | trainer.model.eval() 57 | if not config['image_norm_mean']: 58 | _mean = (0.5, 0.5, 0.5) 59 | _std = (0.5, 0.5, 0.5) 60 | else: 61 | _mean = tuple(config['image_norm_mean']) 62 | _std = tuple(config['image_norm_std']) 63 | im_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(_mean, _std)]) 64 | 65 | ajis = [] 66 | dices = [] 67 | pqs = [] 68 | dqs = [] 69 | sqs = [] 70 | root = os.path.join(config['dataroot'], 'test') 71 | stain_norm_type=config['stainnorm'] 72 | 73 | test_dir_fp=os.path.join(root, 'Images') if stain_norm_type is None else os.path.join(root, f'Images_{stain_norm_type}') 74 | print(test_dir_fp,stain_norm_type,_mean,_std) 75 | test_img_fp=os.listdir(test_dir_fp) 76 | 77 | test_meta = [os.path.splitext(p)[0] for p in test_img_fp] 78 | test_img_fp=[os.path.join(test_dir_fp,f) for f in test_img_fp] 79 | for i, (test_fp, test_file_name) in enumerate(zip(test_img_fp,test_meta)): 80 | if 'tif' in test_fp: 81 | with tifffile.TiffFile(test_fp) as f: 82 | 
test_img = f.asarray() 83 | else: 84 | test_img = np.array(Image.open(test_fp)) 85 | original_img=test_img.copy()#tifffile.TiffFile(os.path.join(root,'Images',f'{test_file_name}.tif')).asarray()# 86 | 87 | test_gt_path = os.path.join(root, 'Labels', test_file_name + '.mat') 88 | test_GT = scio.loadmat(test_gt_path)['inst_map'].astype(np.int32) 89 | 90 | im_size = test_img.shape[0] 91 | 92 | assert patch_size % stride == 0 and load_size % patch_size == 0 93 | pad_size = (load_size - im_size + patch_size) // 2 94 | 95 | test_img = np.pad(test_img, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)), "reflect") 96 | crop_test_imgs = view_as_windows(test_img, (patch_size, patch_size, 3), (stride, stride, 3))[:, :, 0] 97 | 98 | 99 | pred_crop_test_imgs = [] 100 | ins_num = 1 101 | 102 | output_seg = np.zeros((load_size + patch_size, load_size + patch_size), dtype=np.int32) 103 | score_list={} 104 | for i in range(crop_test_imgs.shape[0]): 105 | for j in range(crop_test_imgs.shape[1]): 106 | crop_test_img = crop_test_imgs[i, j] 107 | crop_test_img = im_transform(crop_test_img).unsqueeze(0).cuda() 108 | with torch.no_grad(): 109 | output = trainer.prediction_single(crop_test_img) 110 | #output = trainer.prediction_fast(crop_test_img) 111 | if output is not None: 112 | seg_masks, cate_labels, cate_scores = output 113 | else: 114 | continue 115 | seg_masks = seg_masks.cpu().numpy() 116 | cate_labels = cate_labels.cpu().numpy() 117 | cate_scores = cate_scores.cpu().numpy() 118 | for ins_id in range(seg_masks.shape[0]): 119 | seg_ = seg_masks[ins_id] 120 | label_ = cate_labels[ins_id] 121 | score_ = cate_scores[ins_id] 122 | center_w, center_h, width, height = get_ins_info(seg_, method='bbox') 123 | 124 | center_h = np.ceil(center_h) 125 | center_w = np.ceil(center_w) 126 | offset_h = i * stride 127 | offset_w = j * stride 128 | if center_h >= patch_size // 2 - stride // 2 and center_h <= patch_size // 2 + stride // 2 and center_w >= patch_size // 2 - stride // 2 and center_w <= patch_size // 2 + stride // 2: 129 | focus_area = output_seg[offset_h:offset_h + patch_size, offset_w:offset_w + patch_size].copy() 130 | if np.sum(np.logical_and(focus_area > 0, seg_)) == 0: 131 | output_seg[offset_h:offset_h + patch_size, offset_w:offset_w + patch_size] = np.where( 132 | focus_area > 0, focus_area, seg_ * (ins_num)) 133 | score_list[ins_num]=score_ 134 | ins_num += 1 135 | else: 136 | compared_num, _ = Counter((focus_area * seg_).flatten()).most_common(2)[1] 137 | assert compared_num > 0 138 | compared_num = int(compared_num) 139 | compared_score = score_list[compared_num] 140 | if np.sum(np.logical_and(focus_area == compared_num, seg_)) / np.sum( 141 | np.logical_or(focus_area == compared_num, seg_)) > 0.7:#IoU>0.1判断重叠 142 | if compared_score > score_: 143 | pass 144 | else: 145 | focus_area[focus_area==compared_num]=0 146 | output_seg[offset_h:offset_h + patch_size, offset_w:offset_w + patch_size]=focus_area 147 | output_seg[offset_h:offset_h + patch_size, 148 | offset_w:offset_w + patch_size] = np.where( 149 | focus_area > 0, focus_area, seg_ * (ins_num)) 150 | score_list[ins_num] = score_ 151 | ins_num += 1 152 | else: 153 | output_seg[offset_h:offset_h + patch_size, offset_w:offset_w + patch_size] = np.where( 154 | focus_area > 0, focus_area, seg_ * (ins_num)) 155 | score_list[ins_num] = score_ 156 | ins_num += 1 157 | 158 | output_seg=output_seg[pad_size:-pad_size, pad_size:-pad_size] 159 | for ui in np.unique(output_seg): 160 | if ui ==0:continue 161 | if np.sum(output_seg==ui)<16: 162 | 
output_seg[output_seg==ui]=0 163 | test_GT = remap_label(test_GT) 164 | 165 | output_seg = remap_label(output_seg) 166 | aji = get_fast_aji(test_GT.copy(), output_seg.copy()) 167 | dice = get_dice_1(test_GT.copy(), output_seg.copy()) 168 | 169 | [dq, sq, pq], [paired_true, paired_pred, unpaired_true, unpaired_pred] = get_fast_pq(test_GT.copy(), output_seg.copy()) 170 | print(f'dice {round(float(dice), 3)} AJI {round(float(aji), 3)} dq {round(float(dq), 3)} sq {round(float(sq), 3)} pq {round(float(pq), 3)} gt up {len(unpaired_true)} pred up {len(unpaired_pred)}') 171 | title=f'DICE:{round(float(dice), 3)}, AJI:{round(float(aji), 3)},\n DQ:{round(float(dq), 3)}, SQ:{round(float(sq), 3)}, PQ:{round(float(pq), 3)}' 172 | #_imageshow(test_img[pad_size:-pad_size, pad_size:-pad_size],output_seg,test_GT,unpaired_pred,unpaired_true,title=title) 173 | #_imagesave(original_img,output_seg,None,f'consep_pred/{test_file_name}.png') 174 | #_imagesave(original_img, test_GT, None, f'consep_gt/{test_file_name}.png') 175 | 176 | #results={'pred':output_seg, 'gt':test_GT,'paired_true':paired_true,'paired_pred':paired_pred,'unpaired_true':unpaired_true,'unpaired_pred':unpaired_pred} 177 | 178 | #scio.savemat(f'consep_results/{test_file_name}.mat',results) 179 | #scio.savemat(f'consep_pred/{test_file_name}.mat',new_pred) 180 | 181 | 182 | dices.append(dice) 183 | ajis.append(aji) 184 | dqs.append(dq) 185 | sqs.append(sq) 186 | pqs.append(pq) 187 | 188 | 189 | 190 | print(f'dice {round(float(np.mean(dices)),3)} AJI {round(float(np.mean(ajis)),3)} dq{round(float(np.mean(dqs)),3)} sq{round(float(np.mean(sqs)),3)} pq{round(float(np.mean(pqs)),3)}') 191 | print(f'{round(float(np.mean(dices)),3)}\t{round(float(np.mean(ajis)),3)}\t{round(float(np.mean(dqs)),3)}\t{round(float(np.mean(sqs)),3)}\t{round(float(np.mean(pqs)),3)}') 192 | -------------------------------------------------------------------------------- /models/backbone/resnetv1b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from ..base_ops.DCNv2 import DeformConv 5 | 6 | __all__ = ['ResNetV1b', 'resnet18_v1b', 'resnet34_v1b', 'resnet50_v1b', 7 | 'resnet101_v1b', 'resnet152_v1b', 'resnet152_v1s', 'resnet101_v1s', 'resnet50_v1s'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | class BasicBlockV1b(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, 22 | previous_dilation=1, norm_layer=nn.BatchNorm2d): 23 | super(BasicBlockV1b, self).__init__() 24 | self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 25 | dilation, dilation, bias=False) 26 | self.bn1 = norm_layer(planes) 27 | self.relu = nn.ReLU(True) 28 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, previous_dilation, 29 | dilation=previous_dilation, bias=False) 30 | self.bn2 = norm_layer(planes) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | identity = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = 
self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | identity = self.downsample(x) 46 | 47 | out += identity 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class BottleneckV1b(nn.Module): 54 | expansion = 4 55 | 56 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, 57 | previous_dilation=1, norm_layer=nn.BatchNorm2d, dcn=False): 58 | super(BottleneckV1b, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 60 | self.bn1 = norm_layer(planes) 61 | self.dcn=dcn 62 | if dcn: 63 | self.conv2 = DeformConv(planes, planes, 3, stride, dilation, dilation, bias=False) 64 | else: 65 | self.conv2 = nn.Conv2d(planes, planes, 3, stride, dilation, dilation, bias=False) 66 | self.bn2 = norm_layer(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 68 | self.bn3 = norm_layer(planes * self.expansion) 69 | self.relu = nn.ReLU(True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | 74 | def forward(self, x): 75 | identity = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | identity = self.downsample(x) 90 | 91 | out += identity 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNetV1b(nn.Module): 98 | 99 | def __init__(self, block, layers, num_classes=1000, dilated=True, deep_stem=False, 100 | zero_init_residual=False, norm_layer=nn.BatchNorm2d): 101 | self.inplanes = 128 if deep_stem else 64 102 | super(ResNetV1b, self).__init__() 103 | if deep_stem: 104 | self.conv1 = nn.Sequential( 105 | nn.Conv2d(3, 64, 3, 2, 1, bias=False), 106 | norm_layer(64), 107 | nn.ReLU(True), 108 | nn.Conv2d(64, 64, 3, 1, 1, bias=False), 109 | norm_layer(64), 110 | nn.ReLU(True), 111 | nn.Conv2d(64, 128, 3, 1, 1, bias=False) 112 | ) 113 | else: 114 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False) 115 | use_dcn=[False,False,False,False] 116 | #use_dcn = [False, False, False, False] 117 | self.bn1 = norm_layer(self.inplanes) 118 | self.relu = nn.ReLU(True) 119 | self.maxpool = nn.MaxPool2d(3, 2, 1) 120 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer,dcn=use_dcn[0]) 121 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer,dcn=use_dcn[1]) 122 | if dilated: 123 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer,dcn=use_dcn[2]) 124 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer,dcn=use_dcn[3]) 125 | else: 126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer,dcn=use_dcn[2]) 127 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer,dcn=use_dcn[3]) 128 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 129 | self.fc = nn.Linear(512 * block.expansion, num_classes) 130 | 131 | for m in self.modules(): 132 | if isinstance(m, nn.Conv2d): 133 | pass 134 | 135 | elif isinstance(m, nn.BatchNorm2d): 136 | nn.init.constant_(m.weight, 1) 137 | nn.init.constant_(m.bias, 0) 138 | 139 | if zero_init_residual: 140 | for m in self.modules(): 141 | if isinstance(m, BottleneckV1b): 142 | nn.init.kaiming_normal_(m.conv1.weight, mode='fan_out', nonlinearity='relu') 143 | nn.init.kaiming_normal_(m.conv2.weight, mode='fan_out', nonlinearity='relu') 144 | 
nn.init.kaiming_normal_(m.conv3.weight, mode='fan_out', nonlinearity='relu') 145 | nn.init.constant_(m.bn3.weight, 0) 146 | elif isinstance(m, BasicBlockV1b): 147 | nn.init.constant_(m.bn2.weight, 0) 148 | 149 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d,dcn=False): 150 | downsample = None 151 | if stride != 1 or self.inplanes != planes * block.expansion: 152 | downsample = nn.Sequential( 153 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False), 154 | norm_layer(planes * block.expansion), 155 | ) 156 | 157 | layers = [] 158 | if dilation in (1, 2): 159 | layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, 160 | previous_dilation=dilation, norm_layer=norm_layer,dcn=dcn)) 161 | elif dilation == 4: 162 | layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, 163 | previous_dilation=dilation, norm_layer=norm_layer,dcn=dcn)) 164 | else: 165 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 166 | self.inplanes = planes * block.expansion 167 | for _ in range(1, blocks): 168 | layers.append(block(self.inplanes, planes, dilation=dilation, 169 | previous_dilation=dilation, norm_layer=norm_layer,dcn=dcn)) 170 | 171 | return nn.Sequential(*layers) 172 | 173 | def forward(self, x): 174 | x = self.conv1(x) 175 | x = self.bn1(x) 176 | x = self.relu(x) 177 | x = self.maxpool(x) 178 | 179 | x = self.layer1(x) 180 | x = self.layer2(x) 181 | x = self.layer3(x) 182 | x = self.layer4(x) 183 | 184 | x = self.avgpool(x) 185 | x = x.view(x.size(0), -1) 186 | x = self.fc(x) 187 | 188 | return x 189 | 190 | 191 | def resnet18_v1b(pretrained=False, **kwargs): 192 | model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], **kwargs) 193 | if pretrained: 194 | old_dict = model_zoo.load_url(model_urls['resnet18']) 195 | model_dict = model.state_dict() 196 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 197 | model_dict.update(old_dict) 198 | model.load_state_dict(model_dict) 199 | return model 200 | 201 | 202 | def resnet34_v1b(pretrained=False, **kwargs): 203 | model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs) 204 | if pretrained: 205 | old_dict = model_zoo.load_url(model_urls['resnet34']) 206 | model_dict = model.state_dict() 207 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 208 | model_dict.update(old_dict) 209 | model.load_state_dict(model_dict) 210 | return model 211 | 212 | 213 | def resnet50_v1b(pretrained=False, **kwargs): 214 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], **kwargs) 215 | if pretrained: 216 | old_dict = model_zoo.load_url(model_urls['resnet50']) 217 | model_dict = model.state_dict() 218 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 219 | model_dict.update(old_dict) 220 | model.load_state_dict(model_dict) 221 | return model 222 | 223 | 224 | def resnet101_v1b(pretrained=False, **kwargs): 225 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], **kwargs) 226 | if pretrained: 227 | old_dict = model_zoo.load_url(model_urls['resnet101']) 228 | model_dict = model.state_dict() 229 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 230 | model_dict.update(old_dict) 231 | model.load_state_dict(model_dict) 232 | return model 233 | 234 | 235 | def resnet152_v1b(pretrained=False, **kwargs): 236 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], **kwargs) 237 | if pretrained: 238 | old_dict = model_zoo.load_url(model_urls['resnet152']) 239 | model_dict = 
model.state_dict() 240 | old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)} 241 | model_dict.update(old_dict) 242 | model.load_state_dict(model_dict) 243 | return model 244 | 245 | 246 | def resnet50_v1s(pretrained=False, root='~/.torch/models', **kwargs): 247 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, **kwargs) 248 | if pretrained: 249 | from ..model_store import get_resnet_file 250 | model.load_state_dict(torch.load(get_resnet_file('resnet50', root=root)), strict=False) 251 | return model 252 | 253 | def resnet101_v1s(pretrained=False, root='~/.torch/models', **kwargs): 254 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, **kwargs) 255 | if pretrained: 256 | from ..model_store import get_resnet_file 257 | model.load_state_dict(torch.load(get_resnet_file('resnet101', root=root)), strict=False) 258 | return model 259 | 260 | 261 | def resnet152_v1s(pretrained=False, root='~/.torch/models', **kwargs): 262 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, **kwargs) 263 | if pretrained: 264 | from ..model_store import get_resnet_file 265 | model.load_state_dict(torch.load(get_resnet_file('resnet152', root=root)), strict=False) 266 | return model 267 | 268 | 269 | if __name__ == '__main__': 270 | import torch 271 | img = torch.randn(4, 3, 224, 224) 272 | model = resnet50_v1s(True) 273 | 274 | -------------------------------------------------------------------------------- /losses/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from scipy.ndimage import distance_transform_edt 7 | 8 | def make_one_hot(input, num_classes=None): 9 | """Convert class index tensor to one hot encoding tensor. 
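    For example, a [N, 1, H, W] tensor of integer class indices becomes a
    [N, num_classes, H, W] tensor of 0/1 indicators, scattered along dim 1.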
10 | Args: 11 | input: A tensor of shape [N, 1, *] 12 | num_classes: An int of number of class 13 | Shapes: 14 | predict: A tensor of shape [N, *] without sigmoid activation function applied 15 | target: A tensor of shape same with predict 16 | Returns: 17 | A tensor of shape [N, num_classes, *] 18 | """ 19 | if num_classes is None: 20 | num_classes = input.max() + 1 21 | shape = np.array(input.shape) 22 | shape[1] = num_classes 23 | shape = tuple(shape) 24 | result = torch.zeros(shape) 25 | result = result.scatter_(1, input.cpu().long(), 1) 26 | return result 27 | class logcosh_DICE(nn.Module): 28 | def __init__(self, ignore_index=None, reduction='mean',**kwargs): 29 | super(logcosh_DICE, self).__init__() 30 | 31 | class BinaryDiceLoss(nn.Module): 32 | """Dice loss of binary class 33 | Args: 34 | ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient 35 | reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum' 36 | Shapes: 37 | output: A tensor of shape [N, *] without sigmoid activation function applied 38 | target: A tensor of shape same with output 39 | Returns: 40 | Loss tensor according to arg reduction 41 | Raise: 42 | Exception if unexpected reduction 43 | """ 44 | 45 | def __init__(self, ignore_index=None, reduction='mean',**kwargs): 46 | super(BinaryDiceLoss, self).__init__() 47 | self.smooth = 1 # suggest set a large number when target area is large,like '10|100' 48 | self.ignore_index = ignore_index 49 | self.reduction = reduction 50 | self.batch_dice = False # treat a large map when True 51 | if 'batch_loss' in kwargs.keys(): 52 | self.batch_dice = kwargs['batch_loss'] 53 | 54 | def forward(self, output, target, use_sigmoid=True): 55 | assert output.shape[0] == target.shape[0], f"output & target batch size don't match {output.shape[0]} {target.shape[0]} " 56 | if use_sigmoid: 57 | output = torch.sigmoid(output) 58 | 59 | if self.ignore_index is not None: 60 | validmask = (target != self.ignore_index).float() 61 | output = output.mul(validmask) # can not use inplace for bp 62 | target = target.float().mul(validmask) 63 | 64 | dim0= output.shape[0] 65 | if self.batch_dice: 66 | dim0 = 1 67 | 68 | output = output.contiguous().view(dim0, -1) 69 | target = target.contiguous().view(dim0, -1).float() 70 | 71 | num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth 72 | den = torch.sum(output.abs() + target.abs(), dim=1) + self.smooth 73 | 74 | loss = 1 - (num / den) 75 | 76 | if self.reduction == 'mean': 77 | return loss.mean() 78 | elif self.reduction == 'sum': 79 | return loss.sum() 80 | elif self.reduction == 'none': 81 | return loss 82 | else: 83 | raise Exception('Unexpected reduction {}'.format(self.reduction)) 84 | 85 | 86 | class DiceLoss(nn.Module): 87 | """Dice loss, need one hot encode input 88 | Args: 89 | weight: An array of shape [num_classes,] 90 | ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient 91 | output: A tensor of shape [N, C, *] 92 | target: A tensor of same shape with output 93 | other args pass to BinaryDiceLoss 94 | Return: 95 | same as BinaryDiceLoss 96 | """ 97 | 98 | def __init__(self, weight=None, ignore_index=None, **kwargs): 99 | super(DiceLoss, self).__init__() 100 | self.kwargs = kwargs 101 | self.weight = weight 102 | if isinstance(ignore_index, (int, float)): 103 | self.ignore_index = [int(ignore_index)] 104 | elif ignore_index is None: 105 | self.ignore_index = [] 106 | elif isinstance(ignore_index, (list, 
tuple)): 107 | self.ignore_index = ignore_index 108 | else: 109 | raise TypeError("Expect 'int|float|list|tuple', while get '{}'".format(type(ignore_index))) 110 | 111 | def forward(self, output, target): 112 | assert output.shape == target.shape, 'output & target shape do not match' 113 | dice = BinaryDiceLoss(**self.kwargs) 114 | total_loss = 0 115 | output = F.softmax(output, dim=1) 116 | for i in range(target.shape[1]): 117 | if i not in self.ignore_index: 118 | dice_loss = dice(output[:, i], target[:, i], use_sigmoid=False) 119 | if self.weight is not None: 120 | assert self.weight.shape[0] == target.shape[1], \ 121 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) 122 | dice_loss *= self.weight[i] 123 | total_loss += (dice_loss) 124 | loss = total_loss / (target.size(1) - len(self.ignore_index)) 125 | return loss 126 | 127 | 128 | class WBCEWithLogitLoss(nn.Module): 129 | """ 130 | Weighted Binary Cross Entropy. 131 | `WBCE(p,t)=-β*t*log(p)-(1-t)*log(1-p)` 132 | To decrease the number of false negatives, set β>1. 133 | To decrease the number of false positives, set β<1. 134 | Args: 135 | @param weight: positive sample weight 136 | Shapes: 137 | output: A tensor of shape [N, 1,(d,), h, w] without sigmoid activation function applied 138 | target: A tensor of shape same with output 139 | """ 140 | 141 | def __init__(self, weight=1.0, ignore_index=None, reduction='mean'): 142 | super(WBCEWithLogitLoss, self).__init__() 143 | assert reduction in ['none', 'mean', 'sum'] 144 | self.ignore_index = ignore_index 145 | weight = float(weight) 146 | self.weight = weight 147 | self.reduction = reduction 148 | self.smooth = 0.01 149 | 150 | def forward(self, output, target): 151 | assert output.shape[0] == target.shape[0], "output & target batch size don't match" 152 | 153 | if self.ignore_index is not None: 154 | valid_mask = (target != self.ignore_index).float() 155 | output = output.mul(valid_mask) # can not use inplace for bp 156 | target = target.float().mul(valid_mask) 157 | 158 | batch_size = output.size(0) 159 | output = output.view(batch_size, -1) 160 | target = target.view(batch_size, -1) 161 | 162 | output = torch.sigmoid(output) 163 | # avoid `nan` loss 164 | eps = 1e-6 165 | output = torch.clamp(output, min=eps, max=1.0 - eps) 166 | # soft label 167 | target = torch.clamp(target, min=self.smooth, max=1.0 - self.smooth) 168 | 169 | # loss = self.bce(output, target) 170 | loss = -self.weight * target.mul(torch.log(output)) - ((1.0 - target).mul(torch.log(1.0 - output))) 171 | if self.reduction == 'mean': 172 | loss = torch.mean(loss) 173 | elif self.reduction == 'sum': 174 | loss = torch.sum(loss) 175 | elif self.reduction == 'none': 176 | loss = loss 177 | else: 178 | raise NotImplementedError 179 | return loss 180 | 181 | 182 | class WBCE_DiceLoss(nn.Module): 183 | def __init__(self, alpha=1.0, weight=1.0, ignore_index=None, reduction='mean'): 184 | """ 185 | combination of Weighted Binary Cross Entropy and Binary Dice Loss 186 | Args: 187 | @param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient 188 | @param reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum' 189 | @param alpha: weight between WBCE('Weighted Binary Cross Entropy') and binary dice, apply on WBCE 190 | Shapes: 191 | output: A tensor of shape [N, *] without sigmoid activation function applied 192 | target: A tensor of shape same with output 193 | """ 194 | super(WBCE_DiceLoss, self).__init__() 195 | assert
reduction in ['none', 'mean', 'sum'] 196 | assert 0 <= alpha <= 1, '`alpha` should in [0,1]' 197 | self.alpha = alpha 198 | self.ignore_index = ignore_index 199 | self.reduction = reduction 200 | self.dice = BinaryDiceLoss(ignore_index=ignore_index, reduction=reduction, general=True) 201 | self.wbce = WBCEWithLogitLoss(weight=weight, ignore_index=ignore_index, reduction=reduction) 202 | self.dice_loss = None 203 | self.wbce_loss = None 204 | 205 | def forward(self, output, target): 206 | self.dice_loss = self.dice(output, target) 207 | self.wbce_loss = self.wbce(output, target) 208 | loss = self.alpha * self.wbce_loss + self.dice_loss 209 | return loss 210 | 211 | 212 | 213 | def compute_edts_forPenalizedLoss(GT): 214 | """ 215 | GT.shape = (batch_size, x,y,z) 216 | only for binary segmentation 217 | """ 218 | res = np.zeros(GT.shape,dtype=np.float32) 219 | for i in range(GT.shape[0]): 220 | 221 | posmask = GT[i] 222 | negmask = ~posmask 223 | pos_edt = distance_transform_edt(posmask) 224 | pos_edt = (np.max(pos_edt) - pos_edt) * posmask 225 | neg_edt = distance_transform_edt(negmask) 226 | neg_edt = (np.max(neg_edt) - neg_edt) * negmask 227 | res[i] = pos_edt / np.max(pos_edt) + neg_edt / np.max(neg_edt) 228 | 229 | return res 230 | 231 | 232 | 233 | class DistBinaryDiceLoss(nn.Module): 234 | """ 235 | Distance map penalized Dice loss 236 | Motivated by: https://openreview.net/forum?id=B1eIcvS45V 237 | Distance Map Loss Penalty Term for Semantic Segmentation 238 | """ 239 | 240 | def __init__(self, smooth=1e-5): 241 | super(DistBinaryDiceLoss, self).__init__() 242 | self.smooth = smooth 243 | 244 | def forward(self, net_output, gt): 245 | """ 246 | net_output: (batch_size, 2, x,y,z) 247 | target: ground truth, shape: (batch_size, 1, x,y,z) 248 | """ 249 | net_output = torch.sigmoid(net_output) 250 | # one hot code for gt 251 | with torch.no_grad(): 252 | if len(net_output.shape) != len(gt.shape): 253 | gt = gt.view((gt.shape[0], 1, *gt.shape[1:])) 254 | 255 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 256 | # if this is the case then gt is probably already a one hot encoding 257 | y_onehot = gt 258 | else: 259 | gt = gt.long() 260 | y_onehot = torch.zeros(net_output.shape) 261 | if net_output.device.type == "cuda": 262 | y_onehot = y_onehot.cuda(net_output.device.index) 263 | y_onehot.scatter_(1, gt, 1) 264 | 265 | gt_temp = gt[:, 0, ...].type(torch.float32) 266 | with torch.no_grad(): 267 | dist = compute_edts_forPenalizedLoss(gt_temp.cpu().numpy() > 0.5) + 1.0 268 | # print('dist.shape: ', dist.shape) 269 | dist = torch.from_numpy(dist) 270 | 271 | if dist.device != net_output.device: 272 | dist = dist.to(net_output.device).type(torch.float32) 273 | 274 | tp = net_output * y_onehot 275 | tp = torch.sum(tp[:, 1, ...] 
* dist, (1, 2, 3)) 276 | 277 | dc = (2 * tp + self.smooth) / (torch.sum(net_output[:, 1, ...], (1, 2, 3)) + torch.sum(y_onehot[:, 1, ...], (1, 2, 3)) + self.smooth) 278 | 279 | dc = dc.mean() 280 | 281 | return -dc 282 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | from torchvision import transforms 4 | from tifffile import TiffFile 5 | from PIL import Image 6 | import albumentations as A 7 | import torch 8 | import numpy as np 9 | 10 | from skimage.util.shape import view_as_windows 11 | import scipy.io as scio 12 | from .io import make_dataset 13 | from ._aug import get_augmentation 14 | from .imop import get_ins_info,gaussian_radius,draw_gaussian 15 | 16 | import matplotlib.pyplot as plt 17 | 18 | class NucleiDataset(data.Dataset): 19 | def __init__(self,config,seed, is_train,output_stride=4): 20 | stain_norm=config['stainnorm'] 21 | data_root=config['dataroot'] 22 | img_size=256 23 | self.grid_size=img_size//output_stride 24 | 25 | if not config['image_norm_mean']: 26 | _mean=(0.5,0.5,0.5) 27 | _std=(0.5,0.5,0.5) 28 | else: 29 | _mean = tuple(config['image_norm_mean']) 30 | _std = tuple(config['image_norm_std']) 31 | self.transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(_mean, _std)]) 32 | print(f'Stain Norm Type: {stain_norm}, Trans mean std: {_mean}, {_std}.') 33 | 34 | self.phase='train' if is_train else 'test' 35 | 36 | self.img_dir_path = os.path.join(data_root,self.phase, 'Images') if stain_norm is None else os.path.join(data_root, self.phase, f'Images_{stain_norm}') 37 | self.img_paths = sorted(make_dataset(self.img_dir_path)) 38 | self.gt_dir_path = os.path.join(data_root,self.phase, 'Labels') 39 | self.gt_paths = sorted(make_dataset(self.gt_dir_path)) 40 | self.num_class=config['model']['num_classes'] 41 | self.use_class= config['model']['num_classes']>2 and 'CoNSeP' in config['dataroot'] 42 | print(f'use classes {self.use_class}') 43 | 44 | if 'cpm' in config['dataroot'].lower(): 45 | self.images,self.masks,self.labels=self.load_cpm17(self.img_paths,self.gt_paths) 46 | else: 47 | self.images,self.masks,self.labels=self.load_crop_data(self.img_paths,self.gt_paths) 48 | 49 | self._size =self.images.shape[0] 50 | self.setup_augmentor(seed) 51 | 52 | def load_cpm17(self, img_paths, gt_paths, patch_size=384, stride=128): 53 | out_imgs = [] 54 | out_masks = [] 55 | out_labels = [] 56 | for ip, mp in zip(img_paths, gt_paths): 57 | assert os.path.basename(ip)[:-4] == os.path.basename(mp)[:-4] 58 | if '.tif' in ip: 59 | with TiffFile(ip) as t: 60 | im=t.asarray() 61 | else: 62 | im = np.array(Image.open(ip).convert('RGB')) 63 | matfile=scio.loadmat(mp) 64 | mk=matfile['inst_map'].astype(np.int16) 65 | 66 | if self.use_class: 67 | lbl=matfile['inst_type'][:,0].astype(np.uint8) 68 | lbl[lbl == 4] = 3 69 | lbl[lbl == 5] = 4 70 | lbl[lbl == 6] = 4 71 | lbl[lbl == 7] = 4 72 | else: 73 | lbl=[1]*(np.max(mk)) 74 | hs=list(range(0,im.shape[0]-patch_size,patch_size))+[im.shape[0]-patch_size] 75 | ws=list(range(0,im.shape[1]-patch_size,patch_size))+[im.shape[1]-patch_size] 76 | ims=[] 77 | mks=[] 78 | for h in hs: 79 | for w in ws: 80 | ims.append(im[h:h+patch_size,w:w+patch_size]) 81 | mks.append(mk[h:h + patch_size, w:w + patch_size]) 82 | ims=np.stack(ims,0) 83 | mks=np.stack(mks,0) 84 | out_imgs.append(ims) 85 | out_masks.append(mks) 86 | for idx in range(mks.shape[0]): 87 | 
tmk=mks[idx] 88 | olabel = {} 89 | for ui in np.unique(tmk): 90 | if ui==0: 91 | continue 92 | olabel[ui]=lbl[ui-1] 93 | out_labels.append(olabel) 94 | out_imgs = np.concatenate(out_imgs) 95 | out_masks = np.concatenate(out_masks) 96 | assert len(out_imgs.shape) == 4 and out_imgs.dtype == np.uint8 97 | 98 | print(f'processed data with size {len(out_imgs)} & {len(out_labels)}') 99 | return out_imgs, out_masks, out_labels 100 | 101 | def load_crop_data(self, img_paths, gt_paths, patch_size=384, stride=128): 102 | out_imgs = [] 103 | out_masks = [] 104 | out_labels = [] 105 | resize = A.Resize(p=1, height=1024, width=1024) 106 | for ip, mp in zip(img_paths, gt_paths): 107 | assert os.path.basename(ip)[:-4] == os.path.basename(mp)[:-4] 108 | if '.tif' in ip: 109 | with TiffFile(ip) as t: 110 | im=t.asarray() 111 | else: 112 | im = np.array(Image.open(ip).convert('RGB')) 113 | matfile=scio.loadmat(mp) 114 | mk=matfile['inst_map'].astype(np.int16) 115 | if self.use_class: 116 | lbl=matfile['inst_type'][:,0].astype(np.uint8) 117 | lbl[lbl == 4] = 3 118 | lbl[lbl == 5] = 4 119 | lbl[lbl == 6] = 4 120 | lbl[lbl == 7] = 4 121 | else: 122 | lbl=[1]*(np.max(mk)) 123 | augmented = resize(image=im, mask=mk) 124 | im = augmented['image'] 125 | mk = augmented['mask'] 126 | ims = view_as_windows(im, (patch_size, patch_size, 3), (stride, stride, 3)).reshape((-1, patch_size, patch_size, 3)) 127 | mks = view_as_windows(mk, (patch_size, patch_size), (stride, stride)).reshape((-1, patch_size, patch_size)) 128 | out_imgs.append(ims) 129 | out_masks.append(mks) 130 | for idx in range(mks.shape[0]): 131 | tmk=mks[idx] 132 | olabel = {} 133 | for ui in np.unique(tmk): 134 | if ui==0: 135 | continue 136 | olabel[ui]=lbl[ui-1] 137 | out_labels.append(olabel) 138 | 139 | out_imgs = np.concatenate(out_imgs) 140 | out_masks = np.concatenate(out_masks) 141 | assert len(out_imgs.shape) == 4 and out_imgs.dtype == np.uint8 142 | 143 | print(f'processed data with size {len(out_imgs)} & {len(out_labels)}') 144 | return out_imgs, out_masks, out_labels 145 | 146 | def setup_augmentor(self, seed): 147 | self.shape_augs, self.input_augs = get_augmentation(self.phase, seed) 148 | 149 | def __getitem__(self, index): 150 | img = self.images[index] 151 | mask = self.masks[index] 152 | label_dic = self.labels[index] 153 | 154 | shape_augs = self.shape_augs.to_deterministic() 155 | img = shape_augs.augment_image(img) 156 | masks = shape_augs.augment_image(mask) 157 | 158 | input_augs = self.input_augs.to_deterministic() 159 | img = input_augs.augment_image(img) 160 | 161 | cate_labels=[] 162 | ins_labels=[] 163 | 164 | for i, ui in enumerate(np.unique(masks)): 165 | if ui ==0: 166 | assert i==ui 167 | continue 168 | tmp_mask=masks==ui 169 | label=label_dic[ui] 170 | ins_labels.append(((tmp_mask)*1).astype(np.int32)) 171 | cate_labels.append(label) 172 | 173 | if len(cate_labels)>0: 174 | cate_labels, ins_labels, ins_ind_labels= self.process_label(np.array(cate_labels), np.array(ins_labels)) 175 | cate_labels=torch.from_numpy(np.array(cate_labels)).float() 176 | ins_labels=torch.from_numpy(ins_labels) 177 | ins_ind_labels=torch.from_numpy(ins_ind_labels).bool() 178 | else: 179 | cate_labels=torch.from_numpy(np.zeros([self.num_class - 1,64,64])).float() 180 | ins_labels=None 181 | ins_ind_labels=None 182 | image=self.transform(img) 183 | output={'image': image, 'cate_labels':cate_labels, 'ins_labels':ins_labels,'ins_ind_labels':ins_ind_labels} 184 | return output 185 | 186 | def process_label(self, gt_labels_raw, gt_masks_raw, 
iou_threshold=0.3, tau=0.5): 187 | w,h=256,256 188 | 189 | cate_label = np.zeros([self.num_class - 1, self.grid_size, self.grid_size], dtype=np.float) 190 | ins_label = np.zeros([self.grid_size ** 2, w, h], dtype=np.int16) 191 | ins_ind_label = np.zeros([self.grid_size ** 2], dtype=np.bool) 192 | if gt_masks_raw is not None: 193 | gt_labels = gt_labels_raw 194 | gt_masks = gt_masks_raw 195 | for seg_mask, gt_label in zip(gt_masks, gt_labels): 196 | center_w, center_h, width, height = get_ins_info(seg_mask, method='bbox') 197 | radius = max(gaussian_radius((width, height), iou_threshold), 0) 198 | coord_h = int((center_h / h) / (1. / self.grid_size)) 199 | coord_w = int((center_w / w) / (1. / self.grid_size)) 200 | temp = draw_gaussian(cate_label[gt_label - 1], (coord_w, coord_h), (radius / 4)) 201 | non_zeros = (temp > tau).nonzero() 202 | label = non_zeros[0] * self.grid_size + non_zeros[1] # label = int(coord_h * grid_size + coord_w)# 203 | ins_label[label, :, :] = seg_mask 204 | ins_ind_label[label] = True 205 | ins_label=np.stack(ins_label[ins_ind_label],0) 206 | return cate_label,ins_label,ins_ind_label 207 | 208 | def __len__(self): 209 | return self._size 210 | 211 | class PannukeDataset(data.Dataset): 212 | """ 213 | img_path: original image 214 | masks: one-hot masks 215 | GT: tiff mask, one channel denote one instance 216 | """ 217 | def __init__(self, data_root, is_train,seed=888,fold=1,output_stride=4): 218 | self.grid_size=256//output_stride 219 | 220 | self.images,self.labels,self.masks= self.load_pannuke(data_root,fold) 221 | self.transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]) 222 | self.num_class=6 223 | self.A_size =self.images.shape[0] 224 | 225 | self.mode='train' if is_train else "test" 226 | self.setup_augmentor(seed) 227 | 228 | def setup_augmentor(self, seed): 229 | self.shape_augs, self.input_augs = get_augmentation(self.mode, seed) 230 | 231 | def load_pannuke(self, data_root,fold=1): 232 | out_labels = [] 233 | out_masks = [] 234 | out_imgs=np.load(os.path.join(data_root,f'images/fold{fold}/images.npy')).astype(np.uint8)#(2523, 256, 256, 3) 235 | masks=np.load(os.path.join(data_root,f'masks/fold{fold}/masks.npy')).astype(np.int16)#(2523, 256, 256, 6) 236 | for i in range(masks.shape[0]): 237 | tmask=masks[i] 238 | olabel={} 239 | omask=np.zeros((256,256),dtype=np.int16) 240 | for j in range(5): 241 | ids=np.unique(tmask[:,:,j]) 242 | if len(ids) ==1: 243 | continue 244 | else: 245 | for id in ids: 246 | if id==0:continue 247 | omask[tmask[:,:,j]==id]=id 248 | olabel[id]=j+1 249 | out_masks.append(omask) 250 | out_labels.append(olabel) 251 | out_masks = np.stack(out_masks,0) 252 | assert len(out_imgs.shape) == 4 and out_imgs.dtype == np.uint8 253 | assert len(out_masks.shape) == 3 and out_masks.dtype == np.int16, f'{out_masks.shape}, {out_masks.dtype}' 254 | assert out_masks.shape[0]==out_imgs.shape[0] and out_imgs.shape[0]==len(out_labels) 255 | print(f'processed data with size {len(out_imgs)}') 256 | return out_imgs, out_labels, out_masks 257 | 258 | def __getitem__(self, index): 259 | img = self.images[index] 260 | mask = self.masks[index] 261 | label_dic = self.labels[index] 262 | 263 | shape_augs = self.shape_augs.to_deterministic() 264 | img = shape_augs.augment_image(img) 265 | masks = shape_augs.augment_image(mask) 266 | 267 | input_augs = self.input_augs.to_deterministic() 268 | img = input_augs.augment_image(img) 269 | 270 | cate_labels=[] 271 | ins_labels=[] 272 | 273 | for i, ui in 
enumerate(np.unique(masks)): 274 | if ui ==0: 275 | assert i==ui 276 | continue 277 | tmp_mask=masks==ui 278 | label=label_dic[ui] 279 | ins_labels.append(((tmp_mask)*1).astype(np.int32)) 280 | cate_labels.append(label) 281 | 282 | if len(cate_labels)>0: 283 | cate_labels, ins_labels, ins_ind_labels= self.process_label(np.array(cate_labels), np.array(ins_labels)) 284 | cate_labels=torch.from_numpy(np.array(cate_labels)).float() 285 | ins_labels=torch.from_numpy(ins_labels) 286 | ins_ind_labels=torch.from_numpy(ins_ind_labels).bool() 287 | else: 288 | cate_labels=torch.from_numpy(np.zeros([self.num_class - 1,self.grid_size,self.grid_size])).float() 289 | ins_labels=None 290 | ins_ind_labels=None 291 | image=self.transform(img) 292 | output={'image': image, 'cate_labels':cate_labels, 'ins_labels':ins_labels,'ins_ind_labels':ins_ind_labels} 293 | return output 294 | 295 | def process_label(self, gt_labels_raw, gt_masks_raw, iou_threshold=0.3, tau=0.5): 296 | w,h=256,256 297 | cate_label = np.zeros([self.num_class - 1, self.grid_size, self.grid_size], dtype=np.float) 298 | ins_label = np.zeros([self.grid_size ** 2, w, h], dtype=np.int16) 299 | ins_ind_label = np.zeros([self.grid_size ** 2], dtype=np.bool) 300 | if gt_masks_raw is not None: 301 | gt_labels = gt_labels_raw 302 | gt_masks = gt_masks_raw 303 | for seg_mask, gt_label in zip(gt_masks, gt_labels): 304 | center_w, center_h, width, height = get_ins_info(seg_mask, method='bbox') 305 | radius = max(gaussian_radius((width, height), iou_threshold), 0) 306 | coord_h = int((center_h / h) / (1. / self.grid_size)) 307 | coord_w = int((center_w / w) / (1. / self.grid_size)) 308 | temp = draw_gaussian(cate_label[gt_label - 1], (coord_w, coord_h), (radius / 4)) 309 | non_zeros = (temp > tau).nonzero() 310 | label = non_zeros[0] * self.grid_size + non_zeros[1] # label = int(coord_h * grid_size + coord_w)# 311 | cate_label[gt_label - 1, coord_h, coord_w] = 1 312 | label = int(coord_h * self.grid_size + coord_w) # 313 | ins_label[label, :, :] = seg_mask 314 | ins_ind_label[label] = True 315 | ins_label=np.stack(ins_label[ins_ind_label],0) 316 | return cate_label, ins_label, ins_ind_label 317 | 318 | def __len__(self): 319 | return self.A_size 320 | --------------------------------------------------------------------------------
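A minimal usage sketch for the loss classes defined in losses/dice_loss.py above. It assumes the repository root is on PYTHONPATH; the tensor shapes and random inputs are illustrative placeholders only and the snippet is not part of the repository's training code.

import torch
from losses.dice_loss import BinaryDiceLoss, DiceLoss, WBCE_DiceLoss

torch.manual_seed(0)

# Raw logits: sigmoid is applied inside BinaryDiceLoss and WBCE_DiceLoss.
logits = torch.randn(2, 1, 64, 64)
target = (torch.rand(2, 1, 64, 64) > 0.5).float()

print('binary dice loss:', BinaryDiceLoss(reduction='mean')(logits, target).item())
print('wbce + dice loss:', WBCE_DiceLoss(alpha=1.0, weight=1.0)(logits, target).item())

# DiceLoss applies softmax over dim=1 and expects a one-hot target
# with the same shape as the per-class logits.
num_classes = 3
cls_logits = torch.randn(2, num_classes, 64, 64)
cls_index = torch.randint(0, num_classes, (2, 64, 64))
cls_onehot = torch.nn.functional.one_hot(cls_index, num_classes).permute(0, 3, 1, 2).float()
print('multi-class dice loss:', DiceLoss()(cls_logits, cls_onehot).item())

All three losses consume raw network outputs and apply the sigmoid (or softmax) internally, so no activation should be applied before calling them.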