├── img
│   ├── readme.md
│   ├── HED_unit.png
│   ├── BE_module.png
│   ├── Tail_part.png
│   └── Boundary_Enhancement_Semantic_Segmentation.pdf
├── src
│   ├── __init__.py
│   ├── inference.py
│   └── train.py
├── notebooks
│   ├── __init__.py
│   └── data_prep.ipynb
├── utils
│   ├── __init__.py
│   ├── log.py
│   ├── config.py
│   ├── core.py
│   ├── data.py
│   └── io.py
├── nets
│   ├── zoo
│   │   ├── hrnet_config.py
│   │   ├── hrnet.yml
│   │   ├── __init__.py
│   │   ├── unet.py
│   │   ├── ternaus.py
│   │   ├── unet_BE.py
│   │   ├── brrnet.py
│   │   ├── uspp.py
│   │   ├── denet.py
│   │   ├── resunet.py
│   │   ├── uspp_BE.py
│   │   ├── brrnet_BE.py
│   │   ├── ternaus_BE.py
│   │   ├── resunet_BE.py
│   │   └── enru.py
│   ├── __init__.py
│   ├── assembly_block.py
│   ├── callbacks.py
│   ├── losses.py
│   ├── model_io.py
│   ├── optimizers.py
│   ├── infer.py
│   ├── datagen.py
│   ├── torch_callbacks.py
│   └── _torch_losses.py
├── yml
│   ├── infer.yml
│   └── train.yml
├── README.md
└── LICENSE
/img/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/notebooks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import config, core, io, data
2 |
--------------------------------------------------------------------------------
/img/HED_unit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/HEAD/img/HED_unit.png
--------------------------------------------------------------------------------
/img/BE_module.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/HEAD/img/BE_module.png
--------------------------------------------------------------------------------
/img/Tail_part.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/HEAD/img/Tail_part.png
--------------------------------------------------------------------------------
/img/Boundary_Enhancement_Semantic_Segmentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/HEAD/img/Boundary_Enhancement_Semantic_Segmentation.pdf
--------------------------------------------------------------------------------
/nets/zoo/hrnet_config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import sys
3 | import os
4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
5 |
6 | def parse(path):
7 |
8 | with open(path, 'r') as f:
9 | config = yaml.safe_load(f)
10 | f.close()
11 | return config
12 |
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | weights_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)),
4 | 'weights')
5 |
6 | from . import callbacks, datagen, infer, losses, model_io
7 | from . import optimizers, train
8 |
9 | if not os.path.isdir(weights_dir):
10 | os.mkdir(weights_dir)
11 |
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ['CUDA_VISIBLE_DEVICES']='1'
4 | import sys
5 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
6 |
7 | import nets
8 | import utils
9 |
10 |
11 | config_path = '../yml/infer.yml'
12 | config = utils.config.parse(config_path)
13 | # print('Config:')
14 | # print(config)
15 |
16 | # make inference output dir
17 | # os.makedirs(os.path.dirname(config['inference']['output_dir']), exist_ok=True)
18 |
19 | inferer = nets.infer.Inferer(config)
20 | inferer()
21 | # inferer.Inferer()
22 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
5 |
6 | import nets
7 | import utils
8 | import time
9 | os.environ['CUDA_VISIBLE_DEVICES']='0,1'
10 |
11 | config_path = '../yml/train.yml'
12 | config = utils.config.parse(config_path)
13 |
14 | # make model output dir
15 |
16 | os.makedirs(os.path.dirname(config['training']['callbacks']['model_checkpoint']['filepath']), exist_ok=True)
17 | start_time = str(int(time.time()))
18 | config['start_time'] = start_time
19 | trainer = nets.train.Trainer(config=config)
20 | trainer.train()
21 |
--------------------------------------------------------------------------------
/utils/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def _get_logging_level(level_int):
5 | """Convert a logging level integer into a log level."""
6 | if isinstance(level_int, bool):
7 | level_int = int(level_int)
8 | if level_int < 0:
9 | return logging.CRITICAL + 1 # silence all possible outputs
10 | elif level_int == 0:
11 | return logging.WARNING
12 | elif level_int == 1:
13 | return logging.INFO
14 | elif level_int == 2:
15 | return logging.DEBUG
16 | elif level_int in [10, 20, 30, 40, 50]: # if user provides the logger int
17 | return level_int
18 | elif isinstance(level_int, int): # if it's an int but not one of the above
19 | return level_int
20 | else:
21 | raise ValueError(f"logging level set to {level_int}, "
22 | "but it must be an integer <= 2.")
23 |
--------------------------------------------------------------------------------
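A short usage sketch (not part of the repository) of the `_get_logging_level` helper above, mirroring how `utils/data.py` turns an integer `verbose` flag into a logger level; it assumes the repository root is on `sys.path`:

```python
# Sketch only: map the integer verbosity flags used in this repo to logging levels.
import logging
from utils.log import _get_logging_level

logger = logging.getLogger("demo")
logger.addHandler(logging.StreamHandler())

for verbose in (-1, 0, 1, 2):
    level = _get_logging_level(verbose)  # CRITICAL+1, WARNING, INFO, DEBUG
    logger.setLevel(level)
    logger.info("visible only when verbose >= 1 (current level: %s)", level)
```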
/nets/assembly_block.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from ._torch_losses import torch_losses
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Variable
7 | import torch
8 | import skimage
9 |
10 |
11 | def assembly_block(mask_64, mask):
12 |
13 |     # obtain boundary from the 1-channel mask
14 | arr_mask = mask.cpu().detach().numpy()
15 | mask_boundary_arr = skimage.segmentation.find_boundaries(arr_mask, mode='inner', background=0).astype(np.float32)
16 | mask_boundary = torch.from_numpy(mask_boundary_arr).cuda().float()
17 |
18 |
19 |     # recall the 64-channel feature map from before the final conv
20 | # mask_boundary_arr = skimage.segmentation.find_boundaries(mask_64, mode='inner', background=0).astype(np.float32)
21 | conv1 = nn.Conv2d(1,64,3,padding=1).cuda()
22 | conv2 = nn.Conv2d(128,64,3,padding=1).cuda()
23 | conv3 = nn.Conv2d(64,1,3,padding=1).cuda()
24 |
25 |     x = Variable(mask_boundary, requires_grad=True)
26 |     x = conv1(x)
27 | x = torch.cat([mask_64, x], dim=1)
28 | x = conv2(x)
29 | x = conv3(x)
30 |
31 | return mask_boundary, x
32 |
--------------------------------------------------------------------------------
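A minimal shape-check sketch for `assembly_block` with random tensors (illustrative only; it needs a CUDA device because the block moves its tensors and convolutions to the GPU, and the 256 × 256 size is arbitrary):

```python
# Sketch only: run assembly_block on dummy inputs and check the output shapes.
import torch
from nets.assembly_block import assembly_block

mask_64 = torch.randn(1, 64, 256, 256).cuda()             # 64-channel features before the final conv
mask = (torch.rand(1, 1, 256, 256) > 0.5).float().cuda()  # 1-channel predicted mask

boundary, refined = assembly_block(mask_64, mask)
print(boundary.shape, refined.shape)  # both torch.Size([1, 1, 256, 256])
```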
/nets/zoo/hrnet.yml:
--------------------------------------------------------------------------------
1 | # HRNET_32 :
2 | FINAL_CONV_KERNEL : 1
3 | STAGE1 :
4 | NUM_MODULES : 1
5 | NUM_BRANCHES : 1
6 | NUM_BLOCKS : [4]
7 | NUM_CHANNELS : [64]
8 | BLOCK : 'BOTTLENECK'
9 | FUSE_METHOD : 'SUM'
10 | STAGE2 :
11 | NUM_MODULES : 1
12 | NUM_BRANCHES : 2
13 | NUM_BLOCKS : [4,4]
14 | NUM_CHANNELS : [32,64]
15 | BLOCK : 'BASIC'
16 | FUSE_METHOD : 'SUM'
17 | STAGE3 :
18 | NUM_MODULES : 4
19 | NUM_BRANCHES : 3
20 | NUM_BLOCKS : [4,4,4]
21 | NUM_CHANNELS : [32,64,128]
22 | BLOCK : 'BASIC'
23 | FUSE_METHOD : 'SUM'
24 | STAGE4 :
25 | NUM_MODULES : 3
26 | NUM_BRANCHES : 4
27 | NUM_BLOCKS : [4,4,4,4]
28 | NUM_CHANNELS : [32,64,128,256]
29 | BLOCK : 'BASIC'
30 | FUSE_METHOD : 'SUM'
31 |
32 |
33 | # HRNET_32 :
34 | # FINAL_CONV_KERNEL : 1
35 | # STAGE1 :
36 | # NUM_MODULES : 1
37 | # NUM_BRANCHES : 1
38 | # NUM_BLOCKS : [4]
39 | # NUM_CHANNELS : [64]
40 | # BLOCK : 'BOTTLENECK'
41 | # FUSE_METHOD : 'SUM'
42 | # STAGE2 :
43 | # NUM_MODULES : 1
44 | # NUM_BRANCHES : 2
45 | # NUM_BLOCKS : [4,4]
46 | # NUM_CHANNELS : [32,64]
47 | # BLOCK : 'BASIC'
48 | # FUSE_METHOD : 'SUM'
49 | # STAGE3 :
50 | # NUM_MODULES : 4
51 | # NUM_BRANCHES : 3
52 | # NUM_BLOCKS : [4,4,4]
53 | # NUM_CHANNELS : [32,64,128]
54 | # BLOCK : 'BASIC'
55 | # FUSE_METHOD : 'SUM'
56 | # STAGE4 :
57 | # NUM_MODULES : 3
58 | # NUM_BRANCHES : 4
59 | # NUM_BLOCKS : [4,4,4,4]
60 | # NUM_CHANNELS : [32,64,128,256]
61 | # BLOCK : 'BASIC'
62 | # FUSE_METHOD : 'SUM'
63 |
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from nets import zoo
3 |
4 |
5 | def parse(path):
6 |
7 | with open(path, 'r') as f:
8 | config = yaml.safe_load(f)
9 | f.close()
10 |
11 | if not config['train'] and not config['infer']:
12 | raise ValueError('"train", "infer", or both must be true.')
13 | if config['train'] and config['train_csv_dir'] is None:
14 | raise ValueError('"train_csv_dir" must be provided if training.')
15 | if config['infer'] and config['inference_data_csv'] is None:
16 |         raise ValueError('"inference_data_csv" must be provided if "infer".')
17 |
18 | train_aoi = config['aoi']
19 |
20 | """ Custom AOI """
21 | if train_aoi == 2:
22 | aoi = 'AOI_2_Vegas'
23 | elif train_aoi == 3:
24 | aoi = 'AOI_3_Paris'
25 | elif train_aoi == 4:
26 | aoi = 'AOI_4_Shanghai'
27 | elif train_aoi == 5:
28 | aoi = 'AOI_5_Khartoum'
29 | elif train_aoi == 6:
30 | aoi = 'Urban3D'
31 | elif train_aoi == 7:
32 | aoi = 'WHU'
33 | elif train_aoi == 8:
34 | aoi = 'mass'
35 | elif train_aoi == 9:
36 | aoi = 'WHU_asia'
37 | config['get_aoi'] = aoi
38 |
39 | if config['training']['lr'] is not None:
40 | config['training']['lr'] = float(config['training']['lr'])
41 |
42 |
43 | if config['validation_augmentation'] is not None \
44 | and config['inference_augmentation'] is None:
45 | config['inference_augmentation'] = config['validation_augmentation']
46 |
47 | return config
48 |
--------------------------------------------------------------------------------
/yml/infer.yml:
--------------------------------------------------------------------------------
1 | model_name : unet_BE
2 | model_path: '../result/models_weight/'
3 | training_date : '1629037296'
4 | aoi : 6
5 | boundary : True
6 | inference:
7 | window_step_size_x:
8 | window_step_size_y:
9 | output_dir: '../result/infer/'
10 | weight_file : 'final.pth'
11 | train: false
12 | infer: true
13 | pretrained: false
14 | nn_framework: torch
15 | batch_size: 4
16 | data_specs:
17 | width: 512
18 | height: 512
19 | dtype:
20 | image_type: zscore
21 | rescale: false
22 | rescale_minima: auto
23 | rescale_maxima: auto
24 | channels: 3
25 | label_type: mask
26 | is_categorical: false
27 | mask_channels: 1
28 | val_holdout_frac: 0.1
29 | data_workers:
30 |
31 | training_data_csv:
32 | validation_data_csv:
33 | inference_data_csv: '../csvs/'
34 | training_augmentation:
35 | augmentations:
36 | p: 1.0
37 | shuffle: true
38 | validation_augmentation:
39 | augmentations:
40 | p: 1.0
41 | inference_augmentation:
42 | augmentations:
43 | p: 1.0
44 | training:
45 | epochs: 300
46 | steps_per_epoch:
47 | optimizer: Adam
48 | lr: 1e-4
49 | opt_args:
50 | loss:
51 | bcewithlogits:
52 | jaccard:
53 | loss_weights:
54 | bcewithlogits: 10
55 | jaccard: 2.5
56 | metrics:
57 | training: f1_score
58 | validation: f1_score
59 | checkpoint_frequency: 10
60 | callbacks:
61 | early_stopping:
62 | patience: 24
63 | model_checkpoint:
64 | filepath:
65 | monitor: val_loss
66 | lr_schedule:
67 | schedule_type: arbitrary
68 | schedule_dict:
69 | milestones:
70 | - 200
71 | gamma: 0.1
72 | model_dest_path:
73 | verbose: true
74 |
--------------------------------------------------------------------------------
/nets/callbacks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .torch_callbacks import torch_callback_dict
3 | import torch
4 |
5 |
6 | def get_callbacks(framework, config):
7 | callbacks = []
8 | if framework == 'torch':
9 | for callback, params in config['training']['callbacks'].items():
10 | if callback == 'lr_schedule':
11 | callbacks.append(get_lr_schedule(framework, config))
12 | else:
13 | callbacks.append(torch_callback_dict[callback](**params))
14 |
15 | return callbacks
16 |
17 |
18 | def get_lr_schedule(framework, config):
19 |
20 |
21 | schedule_type = config['training'][
22 | 'callbacks']['lr_schedule']['schedule_type']
23 | initial_lr = config['training']['lr']
24 | update_frequency = config['training']['callbacks']['lr_schedule'].get(
25 | 'update_frequency', 1)
26 | factor = config['training']['callbacks']['lr_schedule'].get(
27 | 'factor', 0)
28 | schedule_dict = config['training']['callbacks']['lr_schedule'].get(
29 | 'schedule_dict', None)
30 |     if framework == 'torch':
31 |         # just get the class itself to use; don't instantiate until the
32 |         # optimizer has been created.
33 |         if schedule_type == 'linear':
34 |             lr_scheduler = torch.optim.lr_scheduler.StepLR
35 |         elif schedule_type == 'exponential':
36 |             lr_scheduler = torch.optim.lr_scheduler.ExponentialLR
37 |         elif schedule_type == 'arbitrary':
38 |             lr_scheduler = torch.optim.lr_scheduler.MultiStepLR
39 |         elif schedule_type == 'cycle':
40 |             lr_scheduler = torch.optim.lr_scheduler.OneCycleLR
41 |         else:
42 |             raise ValueError(f'Unrecognized schedule_type: {schedule_type}')
43 | 
44 |     return lr_scheduler
53 |
54 |
--------------------------------------------------------------------------------
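`get_lr_schedule` returns a scheduler *class*, which is instantiated later once the optimizer exists. A small sketch of that second step (the optimizer, milestones, and gamma below are illustrative, not values taken from the repository):

```python
# Sketch only: instantiate the scheduler class that get_lr_schedule returns
# for schedule_type 'arbitrary' (torch.optim.lr_scheduler.MultiStepLR).
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

scheduler_cls = torch.optim.lr_scheduler.MultiStepLR
scheduler = scheduler_cls(optimizer, milestones=[200], gamma=0.1)  # values from schedule_dict

for epoch in range(3):
    optimizer.step()   # training steps would go here
    scheduler.step()   # decays the LR by gamma at each milestone epoch
```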
/nets/zoo/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .. import weights_dir
3 |
4 | from .unet import UNet
5 | from .unet_BE import UNet_BE
6 | from .resunet import ResUnetPlusPlus
7 | from .resunet_BE import ResUnetPlusPlus_BE
8 | from .ternaus import ternaus11
9 | from .ternaus_BE import ternaus_BE
10 | from .uspp import Uspp
11 | from .uspp_BE import Uspp_BE
12 | from .denet import DeNet
13 | from .brrnet import BRRNet
14 | from .brrnet_BE import BRRNet_BE
15 | from .enru import ENRUNet
16 | from .enru_BE import ENRUNet_BE
17 |
18 | model_dict = {
19 | 'unet' : {
20 | 'weight_path': None,
21 | 'weight_url': None,
22 | 'arch': UNet},
23 | 'enru' : {
24 | 'weight_path': None,
25 | 'weight_url': None,
26 | 'arch': ENRUNet},
27 | 'enru_BE' : {
28 | 'weight_path': None,
29 | 'weight_url': None,
30 | 'arch': ENRUNet_BE},
31 | 'brrnet' : {
32 | 'weight_path': None,
33 | 'weight_url': None,
34 | 'arch': BRRNet},
35 | 'brrnet_BE' : {
36 | 'weight_path': None,
37 | 'weight_url': None,
38 | 'arch': BRRNet_BE},
39 | 'denet' : {
40 | 'weight_path': None,
41 | 'weight_url': None,
42 | 'arch': DeNet},
43 | 'uspp' : {
44 | 'weight_path': None,
45 | 'weight_url': None,
46 | 'arch': Uspp},
47 | 'uspp_BE' : {
48 | 'weight_path': None,
49 | 'weight_url': None,
50 | 'arch': Uspp_BE},
51 | 'resunet_BE' : {
52 | 'weight_path':None,
53 | 'weight_url': None,
54 | 'arch': ResUnetPlusPlus_BE},
55 | 'resunet' : {
56 | 'weight_path': None,
57 | 'weight_url': None,
58 | 'arch': ResUnetPlusPlus},
59 | 'unet_BE' : {
60 | 'weight_path':None,
61 | 'weight_url': None,
62 | 'arch': UNet_BE},
63 | 'ternaus' : {
64 | 'weight_path':None,
65 | 'weight_url': None,
66 | 'arch': ternaus11},
67 | 'ternaus_BE' : {
68 | 'weight_path':None,
69 | 'weight_url': None,
70 | 'arch': ternaus_BE},
71 |     # The 'hrnetv2' / 'hrnetv2_BE' entries are disabled below: the corresponding
72 |     # architecture classes are not imported in this module, so leaving them
73 |     # active would raise a NameError when the zoo is imported.
74 |     # 'hrnetv2' : {'weight_path': None, 'weight_url': None, 'arch': hrnetv2},
75 |     # 'hrnetv2_BE' : {'weight_path': None, 'weight_url': None, 'arch': hrnetv2_BE},
76 | }
80 |
--------------------------------------------------------------------------------
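A small sketch of how an architecture is looked up in `model_dict` and constructed, mirroring what `nets/model_io.get_model` does when no pretrained weights are requested (assuming the zoo imports resolve in your environment):

```python
# Sketch only: pick an architecture class from the zoo dictionary and build it.
from nets.zoo import model_dict

arch = model_dict['unet_BE']['arch']           # the UNet_BE class object
model = arch(pretrained=False, mode='Train')   # same call signature get_model uses

# The repository's configs feed these models 3-channel 512x512 tiles.
print(type(model).__name__)
```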
/yml/train.yml:
--------------------------------------------------------------------------------
1 |
2 | # Choose a model, referring to 'nets/zoo/__init__.py'
3 |
4 | # model_name : ternaus
5 | model_name: unet_BE
6 |
7 | # aoi 2 3 4 5 6 7 8 9
8 | #Area Vegas Paris Shanghai Khartoum Urban3D WHU-HR Massachusetts WHU-LR
9 |
10 | aoi : 6
11 |
12 | # If you adopt the BE module in your model, set 'boundary' to True.
13 | boundary : True
14 | # boundary : False
15 |
16 | # Number of stages in the encoder. U-Net and ResUNet have 5 stages in their architectures, while TernausNet has 6.
17 | num_stage : 5
18 |
19 | # Pretrained model path
20 | model_path : ''
21 | # model_path: '../result/models_weight/{WEIGHT_DIR}/{PRETRAIN_FILE}.pth'
22 |
23 |
24 | train: true
25 | infer: false
26 | pretrained: False
27 | nn_framework: torch
28 | batch_size: 4
29 | data_specs:
30 | width: 512
31 | height: 512
32 | dtype:
33 | image_type: zscore
34 | rescale: false
35 | rescale_minima: auto
36 | rescale_maxima: auto
37 | channels: 3
38 | label_type: mask
39 | is_categorical: false
40 | mask_channels: 2
41 | val_holdout_frac: 0.175
42 | data_workers:
43 | num_classes : 1
44 |
45 | train_csv_dir : '../csvs/'
46 | validation_data_csv:
47 | inference_data_csv:
48 |
49 | # No augmentation by default.
50 | # If you want to add any, follow the description in 'nets/transform.py'.
51 | training_augmentation:
52 | augmentations:
53 | CenterCrop :
54 | height : 512
55 | width : 512
56 | p : 1.0
57 | p: 1.0
58 | shuffle: true
59 | validation_augmentation:
60 | augmentations:
61 | CenterCrop :
62 | height : 512
63 | width : 512
64 | p : 1.0
65 | p: 1.0
66 | inference_augmentation:
67 | augmentations:
68 | p: 1.0
69 |
70 | # A large epoch count is set because EarlyStopping is used.
71 | training:
72 | epochs: 10000
73 | steps_per_epoch:
74 | optimizer: adam
75 | lr: 1e-4
76 | opt_args:
77 |
78 |
79 | # The BE module uses focal + MS-SSIM + BCE losses.
80 | # If you don't use the BE module (boundary=False), 'loss_mask' and 'loss_boundary' are ignored.
81 | loss :
82 | focal:
83 | loss_weights :
84 | focal : 1
85 | loss_mask:
86 | msssim :
87 | loss_mask_weights:
88 | msssim : 1
89 | loss_boundary:
90 | bce :
91 | loss_boundary_weights:
92 | bce : 1
93 | metrics:
94 | training: p
95 | validation: f
96 | checkpoint_frequency: 10
97 | callbacks:
98 | early_stopping:
99 | patience: 15
100 | model_checkpoint:
101 | filepath: '../result/models_weight/'
102 | path_aoi :
103 | monitor: val_loss
104 | lr_schedule:
105 | schedule_type: arbitrary
106 | schedule_dict:
107 | verbose: true
108 |
109 | inference:
110 | window_step_size_x:
111 | window_step_size_y:
112 | output_dir:
113 |
--------------------------------------------------------------------------------
/nets/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from ._torch_losses import torch_losses
3 | from torch import nn
4 |
5 |
6 | def get_loss(framework, loss, loss_weights=None, custom_losses=None):
7 |
8 | # lots of exception handling here. TODO: Refactor.
9 |
10 | if not isinstance(loss, dict):
11 | raise TypeError('The loss description is formatted improperly.'
12 | ' See the docs for details.')
13 | if len(loss) > 1:
14 |
15 | # get the weights for each loss within the composite
16 | if loss_weights is None:
17 | # weight all losses equally
18 | weights = {k: 1 for k in loss.keys()}
19 | else:
20 | weights = loss_weights
21 |
22 | # check if sublosses dict and weights dict have the same keys
23 |         if sorted(loss.keys()) != sorted(weights.keys()):
24 | raise ValueError(
25 | 'The losses and weights must have the same name keys.')
26 |
27 | if framework in ['pytorch', 'torch']:
28 | return TorchCompositeLoss(loss, weights, custom_losses)
29 |
30 | else: # parse individual loss functions
31 | loss_name, loss_dict = list(loss.items())[0]
32 | return get_single_loss(framework, loss_name, loss_dict, custom_losses)
33 |
34 |
35 | def get_single_loss(framework, loss_name, params_dict, custom_losses=None):
36 |
37 | if framework in ['torch', 'pytorch']:
38 | if params_dict is None:
39 | if custom_losses is not None and loss_name in custom_losses:
40 | return custom_losses.get(loss_name)()
41 | else:
42 | return torch_losses.get(loss_name.lower())()
43 | else:
44 | if custom_losses is not None and loss_name in custom_losses:
45 | return custom_losses.get(loss_name)(**params_dict)
46 | else:
47 | return torch_losses.get(loss_name.lower())(**params_dict)
48 |
49 |
50 | class TorchCompositeLoss(nn.Module):
51 | """Composite loss function."""
52 |
53 | def __init__(self, loss_dict, weight_dict=None, custom_losses=None):
54 | """Create a composite loss function from a set of pytorch losses."""
55 | super().__init__()
56 | self.weights = weight_dict
57 | self.losses = {loss_name: get_single_loss('pytorch',
58 | loss_name,
59 | loss_params,
60 | custom_losses)
61 | for loss_name, loss_params in loss_dict.items()}
62 | self.values = {} # values from the individual loss functions
63 |
64 | def forward(self, outputs, targets):
65 | loss = 0
66 | for func_name, weight in self.weights.items():
67 | self.values[func_name] = self.losses[func_name](outputs, targets)
68 | loss += weight*self.values[func_name]
69 |
70 | return loss
71 |
--------------------------------------------------------------------------------
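A runnable sketch of the composite-loss path in `get_loss`, using two stand-in PyTorch losses supplied through `custom_losses` (the names `bce` and `l1` are illustrative and are not keys of the repository's `torch_losses` dictionary):

```python
# Sketch only: build a weighted composite of two losses and evaluate it.
import torch
from torch import nn
from nets.losses import get_loss

custom = {'bce': nn.BCELoss, 'l1': nn.L1Loss}   # hypothetical custom losses
loss_cfg = {'bce': None, 'l1': None}            # shaped like the config's loss section
weights = {'bce': 1.0, 'l1': 0.5}

criterion = get_loss('torch', loss_cfg, loss_weights=weights, custom_losses=custom)

preds = torch.sigmoid(torch.randn(4, 1, 32, 32))      # probabilities in (0, 1)
targets = (torch.rand(4, 1, 32, 32) > 0.5).float()
print(criterion(preds, targets))                      # weighted sum: 1.0*BCE + 0.5*L1
```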
/utils/core.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import skimage
5 |
6 |
7 | def _check_skimage_im_load(im):
8 | """Check if `im` is already loaded in; if not, load it in."""
9 | if isinstance(im, str):
10 | return skimage.io.imread(im)
11 | elif isinstance(im, np.ndarray):
12 | return im
13 | else:
14 | raise ValueError(
15 | "{} is not an accepted image format for scikit-image.".format(im))
16 |
17 |
18 | def _check_df_load(df):
19 | """Check if `df` is already loaded in, if not, load from file."""
20 | if isinstance(df, str):
21 | if df.lower().endswith('json'):
22 | return _check_gdf_load(df)
23 | else:
24 | return pd.read_csv(df)
25 | elif isinstance(df, pd.DataFrame):
26 | return df
27 | else:
28 | raise ValueError(f"{df} is not an accepted DataFrame format.")
29 |
30 |
31 |
32 | def get_data_paths(path, infer=False):
33 | """Get a pandas dataframe of images and labels from a csv.
34 |
35 | This file is designed to parse image:label reference CSVs (or just image)
36 | for inferencde) as defined in the documentation. Briefly, these should be
37 | CSVs containing two columns:
38 |
39 | ``'image'``: the path to images.
40 | ``'label'``: the path to the label file that corresponds to the image.
41 |
42 | Arguments
43 | ---------
44 | path : str
45 | Path to a .CSV-formatted reference file defining the location of
46 | training, validation, or inference data. See docs for details.
47 | infer : bool, optional
48 | If ``infer=True`` , the ``'label'`` column will not be returned (as it
49 | is unnecessary for inference), even if it is present.
50 |
51 | Returns
52 | -------
53 | df : :class:`pandas.DataFrame`
54 | A :class:`pandas.DataFrame` containing the relevant `image` and `label`
55 | information from the CSV at `path` (unless ``infer=True`` , in which
56 | case only the `image` column is returned.)
57 |
58 | """
59 | df = pd.read_csv(path)
60 | if infer:
61 | return df[['image']] # no labels in those files
62 | else:
63 | return df[['image', 'label']] # remove anything extraneous
64 |
65 |
66 | def get_files_recursively(path, traverse_subdirs=False, extension='.tif'):
67 | """Get files from subdirs of `path`, joining them to the dir."""
68 | if traverse_subdirs:
69 | walker = os.walk(path)
70 | path_list = []
71 | for step in walker:
72 | if not step[2]: # if there are no files in the current dir
73 | continue
74 | path_list += [os.path.join(step[0], fname)
75 | for fname in step[2] if
76 | fname.lower().endswith(extension)]
77 | return path_list
78 | else:
79 | return [os.path.join(path, f) for f in os.listdir(path)
80 | if f.endswith(extension)]
81 |
--------------------------------------------------------------------------------
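A tiny sketch of `get_data_paths` on a hand-made reference CSV (the paths are placeholders; the function only needs `image` and `label` columns):

```python
# Sketch only: write a minimal image:label CSV and read it back with get_data_paths.
import pandas as pd
from utils.core import get_data_paths

pd.DataFrame({
    'image': ['../data/Train/Urban3D_Train/RGB/tile_000.tif'],
    'label': ['../data/Train/Urban3D_Train/masks/tile_000.tif'],
}).to_csv('demo_df.csv', index=False)

print(get_data_paths('demo_df.csv'))               # image and label columns
print(get_data_paths('demo_df.csv', infer=True))   # image column only
```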
/nets/zoo/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def double_conv(in_channels, out_channels):
6 | return nn.Sequential(
7 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
8 | nn.ReLU(inplace=True),
9 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
10 | nn.ReLU(inplace=True)
11 | )
12 |
13 | class _up_deconv(nn.Module):
14 | def __init__(self, in_channels, out_channels):
15 | super(_up_deconv, self).__init__()
16 |
17 | self.deconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
18 | self.bn_i = nn.BatchNorm2d(num_features=out_channels)
19 | self.relu = nn.ReLU()
20 |
21 | def forward(self, x):
22 |
23 | out = self.bn_i(self.deconv(x))
24 | out = self.relu(out)
25 |
26 | return out
27 |
28 | class UNet(nn.Module):
29 |
30 | def __init__(self, n_class=1, pretrained=False, mode='Train'):
31 | super().__init__()
32 | self.mode=mode
33 | self.dconv_down1 = double_conv(3, 64)
34 | self.dconv_down2 = double_conv(64, 128)
35 | self.dconv_down3 = double_conv(128, 256)
36 | self.dconv_down4 = double_conv(256, 512)
37 | self.dconv_down5 = double_conv(512, 1024)
38 |
39 | self.maxpool = nn.MaxPool2d(2)
40 | # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
41 | self.upsample5 = _up_deconv(1024,512)
42 | self.upsample4 = _up_deconv(512,256)
43 | self.upsample3 = _up_deconv(256,128)
44 | self.upsample2 = _up_deconv(128,64)
45 | self.dconv_up4 = double_conv(512 + 512, 512)
46 | self.dconv_up3 = double_conv(256 + 256, 256)
47 | self.dconv_up2 = double_conv(128 + 128, 128)
48 | self.dconv_up1 = double_conv(64 + 64, 64)
49 |
50 | self.conv_last = nn.Conv2d(64, n_class, 1)
51 |
52 |
53 | def forward(self, x):
54 |
55 | conv1 = self.dconv_down1(x)
56 | x = self.maxpool(conv1)
57 |
58 | conv2 = self.dconv_down2(x)
59 | x = self.maxpool(conv2)
60 |
61 | conv3 = self.dconv_down3(x)
62 | x = self.maxpool(conv3)
63 |
64 | conv4 = self.dconv_down4(x)
65 | x = self.maxpool(conv4)
66 |
67 | x = self.dconv_down5(x)
68 | x = self.upsample5(x)
69 | x = torch.cat([x, conv4],1)
70 |
71 | x = self.dconv_up4(x)
72 | x = self.upsample4(x)
73 | x = torch.cat([x, conv3], dim=1)
74 |
75 | x = self.dconv_up3(x)
76 | x = self.upsample3(x)
77 | x = torch.cat([x, conv2], dim=1)
78 |
79 | x = self.dconv_up2(x)
80 | x = self.upsample2(x)
81 | x = torch.cat([x, conv1], dim=1)
82 |
83 | x = self.dconv_up1(x)
84 |
85 | out = self.conv_last(x)
86 |
87 | if self.mode == 'Train':
88 |             return torch.sigmoid(out)
89 | elif self.mode == 'Infer':
90 | return out
--------------------------------------------------------------------------------
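A quick shape check for the baseline `UNet` above (CPU is fine; the 512 × 512 input matches the `data_specs` used in the yml configs):

```python
# Sketch only: forward a dummy batch through UNet and confirm the output shape.
import torch
from nets.zoo.unet import UNet

model = UNet(n_class=1, mode='Train').eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))
print(out.shape)  # torch.Size([1, 1, 512, 512]); sigmoid probabilities in 'Train' mode
```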
/nets/model_io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from warnings import warn
4 | import requests
5 | import numpy as np
6 | from tqdm.auto import tqdm
7 | from nets import weights_dir
8 | from .zoo import model_dict
9 |
10 |
11 | def get_model(model_name, framework, mode='Train', model_path=None, pretrained=False,
12 | custom_model_dict=None, num_classes=1):
13 | """Load a model from a file based on its name."""
14 | if custom_model_dict is not None:
15 | md = custom_model_dict
16 | else:
17 | md = model_dict.get(model_name, None)
18 | if md is None: # if the model's not provided by solaris
19 | raise ValueError(f"{model_name} can't be found in solaris and no "
20 | "custom_model_dict was provided. Check your "
21 | "model_name in the config file and/or provide a "
22 | "custom_model_dict argument to Trainer(). ")
23 | if model_path is None or custom_model_dict is not None:
24 |
25 | model_path = md.get('weight_path')
26 | if num_classes == 1:
27 | model = md.get('arch')(pretrained=pretrained, mode=mode)
28 | else:
29 | model = md.get('arch')(num_classes=num_classes, pretrained=pretrained)
30 |
31 | if model is not None and pretrained:
32 | try:
33 | model = _load_model_weights(model, model_path, framework)
34 | except (OSError, FileNotFoundError):
35 | warn(f'The model weights file {model_path} was not found.'
36 | ' Attempting to download from the SpaceNet repository.')
37 | weight_path = _download_weights(md)
38 | model = _load_model_weights(model, weight_path, framework)
39 |
40 | return model
41 |
42 |
43 | def _load_model_weights(model, path, framework):
44 | """Backend for loading the model."""
45 |
46 | if framework.lower() in ['torch', 'pytorch']:
47 | # pytorch already throws the right error on failed load, so no need
48 | # to fix exception
49 | if torch.cuda.is_available():
50 | try:
51 | loaded = torch.load(path)
52 | except FileNotFoundError:
53 | # first, check to see if the weights are in the default sol dir
54 | default_path = os.path.join(weights_dir,
55 | os.path.split(path)[1])
56 |                 loaded = torch.load(default_path)
57 | else:
58 | try:
59 | loaded = torch.load(path, map_location='cpu')
60 | except FileNotFoundError:
61 | default_path = os.path.join(weights_dir,
62 | os.path.split(path)[1])
63 |                 loaded = torch.load(default_path, map_location='cpu')
64 |
65 | if isinstance(loaded, torch.nn.Module): # if it's a full model already
66 | model.load_state_dict(loaded.state_dict())
67 | else:
68 | model.load_state_dict(loaded)
69 |
70 | return model
71 |
72 |
73 | def reset_weights(model, framework):
74 |
75 | if framework == 'torch':
76 | reinit_model = model.apply(_reset_torch_weights)
77 |
78 | return reinit_model
79 |
80 |
81 | def _reset_torch_weights(torch_layer):
82 | if isinstance(torch_layer, torch.nn.Conv2d) or \
83 | isinstance(torch_layer, torch.nn.Linear):
84 | torch_layer.reset_parameters()
85 |
86 |
87 | def _download_weights(model_dict):
88 | """Download pretrained weights for a model."""
89 |     weight_url = model_dict.get('weight_url', None)
90 |     if weight_url is None:
91 |         raise KeyError("Can't find the weights file.")
92 |     else:
93 |         weight_dest_path = model_dict.get('weight_path') or os.path.join(
94 |             weights_dir, weight_url.split('/')[-1])
95 |         r = requests.get(weight_url, stream=True)
96 | if r.status_code != 200:
97 | raise ValueError('The file could not be downloaded. Check the URL'
98 | ' and network connections.')
99 | total_size = int(r.headers.get('content-length', 0))
100 | block_size = 1024
101 | with open(weight_dest_path, 'wb') as f:
102 | for chunk in tqdm(r.iter_content(block_size),
103 |                               total=np.ceil(total_size/block_size),
104 | unit='KB', unit_scale=False):
105 | if chunk:
106 | f.write(chunk)
107 |
108 | return weight_dest_path
109 |
--------------------------------------------------------------------------------
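A sketch of `get_model` for a freshly initialized (non-pretrained) model; with `pretrained=False` no weight file is read, so no `model_path` is needed (assuming the zoo imports resolve in your environment):

```python
# Sketch only: build a model by name without loading pretrained weights.
from nets.model_io import get_model

model = get_model('unet', framework='torch', mode='Train', pretrained=False)
print(sum(p.numel() for p in model.parameters()), 'parameters')
```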
/README.md:
--------------------------------------------------------------------------------
1 | # Boundary Enhancement Semantic Segmentation for Building Extraction from Remote Sensed Image
2 | ## Introduction
3 | This repository contains implementations of binary semantic segmentation, especially building extraction from satellite images. [Link](https://ieeexplore.ieee.org/document/9527893) [pdf](./img/Boundary_Enhancement_Semantic_Segmentation.pdf)
4 | Furthermore, the boundary enhancement method (BE module) is included in ```/nets/zoo/```.
5 |
6 | 
7 | 
8 | 
9 |
10 | ## Requirements
11 |
12 | ```
13 | Python >= 3.7.0
14 |
15 | PyTorch >= 1.9.0
16 | 
17 | scikit-image >= 0.18.2
18 | 
19 | CUDA >= 10.1
20 | ```
21 |
22 | ## Data prep
23 |
24 | ### Urban3D Dataset example
25 |
26 | - The experiments were conducted on images cropped to 512 × 512 pixels, split into 2,912 training and 672 test images, respectively.
27 | - The original dataset can be downloaded from [Urban3D](https://github.com/topcoderinc/Urban3d).
28 |
29 | - The data should be arranged as follows:
30 |
31 | ```
32 | |-- Test
33 | | `-- Urban3D_Test
34 | | |-- RGB
35 | | `-- masks
36 | `-- Train
37 | `-- Urban3D_Train
38 | |-- RGB
39 | `-- masks
40 | ```
41 |
42 | - Open ```/notebooks/data_prep.ipynb``` and make dataframes for the train and test sets.
43 |   ```Urban3D_Train_df.csv``` and ```Urban3D_Test_df.csv``` will be created in ```/csvs/```.
44 |
45 | ## Train
46 |
47 | - Check and set hyperparameters in ```/yml/train.yml```.
48 | - Choose a model, referring to ```/nets/zoo/__init__.py```.
49 | - Choose the area of interest. ```6``` is the default for the Urban3D dataset.
50 | - Set ```num_stage``` to the number of stages in the backbone architecture.
51 | - Set the training hyperparameters: epochs, optimizer, lr, loss functions.
52 | - If you want to train a *Boundary Enhancement* model, set ```boundary``` to ```True```.
53 | - Run ```/src/train.py```.
54 | - The ```/result``` and ```/result/models_weight``` directories will be created automatically.
55 | - Model weights will be saved in ```/result/models_weight/{DATASET_NAME}_{MODEL_NAME}_{TRAINING_ID}```. ```TRAINING_ID``` is the UNIX time at which training started.
56 |
57 | ## Inference
58 |
59 | - Check and set parameters in ```/yml/infer.yml```.
60 | 
61 | - ```model_name``` and ```aoi``` should be the same as those in ```train.yml```.
62 | 
63 | - If the model was trained with the *Boundary Enhancement* module, set ```boundary``` to ```True```.
64 | 
65 | - Set ```training_date``` to the ```TRAINING_ID``` from training.
66 | 
67 | 
68 | 
69 | - Run ```/src/inference.py```.
70 | 
71 | - Inferred images will be saved in ```/result/infer/```.
72 |
73 | ## Evaluation
74 |
75 | - Open ```/notebooks/get_mask_eval.ipynb```.
76 | - Check ```aois``` and ```training_date```. ```training_date``` is the ```TRAINING_ID``` from the training procedure.
77 | - Running all cells will create mask images from the inferred images.
78 | - Evaluation results comparing the ground truth and the predicted masks will be displayed, and the results will be saved in ```/result/eval_result/```.
79 |
80 |
81 |
82 | ## Implemented model and dataset
83 |
84 | ### Model
85 |
86 | - U-Net
87 | - ResUNet++
88 | - TernausNet
89 | - BRR-Net
90 | - USPP
91 | - DE-Net
92 |
93 | ### Dataset
94 |
95 | - DeepGlobe Dataset (Vegas, Paris, Shanghai, Khartoum)
96 | - Urban3D Dataset
97 | - WHU Dataset (aerial and satellite)
98 | - Massachusetts Dataset
99 |
100 | ## File tree
101 |
102 | ```
103 | |-- data
104 | | |-- Test
105 | | `-- Train
106 | |-- nets
107 | | |-- __init__.py
108 | | |-- _torch_losses.py
109 | | |-- assembly_block.py
110 | | |-- callbacks.py
111 | | |-- datagen.py
112 | | |-- infer.py
113 | | |-- losses.py
114 | | |-- model_io.py
115 | | |-- optimizers.py
116 | | |-- torch_callbacks.py
117 | | |-- train.py
118 | | |-- transform.py
119 | | |-- weights
120 | | `-- zoo
121 | |-- notebooks
122 | | |-- __init__.py
123 | | |-- data_prep.ipynb
124 | | `-- get_mask_eval.ipynb
125 | |-- result
126 | | |-- infer
127 | | |-- infer_masks
128 | | `-- models_weight
129 | |-- src
130 | | |-- __init__.py
131 | | |-- inference.py
132 | | `-- train.py
133 | |-- utils
134 | | |-- __init__.py
135 | | |-- config.py
136 | | |-- core.py
137 | | |-- data.py
138 | | |-- io.py
139 | | `-- log.py
140 | `-- yml
141 | |-- infer.yml
142 | `-- train.yml
143 | ```
144 |
145 | ## Contribution
146 |
147 | This code is a modified and simplified version of [Solaris](https://github.com/CosmiQ/solaris), adapted for my own research.
148 |
149 |
--------------------------------------------------------------------------------
/nets/optimizers.py:
--------------------------------------------------------------------------------
1 | """Wrappers for training optimizers."""
2 | import math
3 | import torch
4 |
5 | def get_optimizer(framework, config):
6 |
7 | if config['training']['optimizer'] is None:
8 | raise ValueError('An optimizer must be specified in the config '
9 | 'file.')
10 |
11 | if framework in ['torch', 'pytorch']:
12 | return torch_optimizers.get(config['training']['optimizer'].lower())
13 |
14 | class TorchAdamW(torch.optim.Optimizer):
15 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
16 | weight_decay=1e-2, amsgrad=False):
17 | if not 0.0 <= lr:
18 | raise ValueError("Invalid learning rate: {}".format(lr))
19 | if not 0.0 <= eps:
20 | raise ValueError("Invalid epsilon value: {}".format(eps))
21 | if not 0.0 <= betas[0] < 1.0:
22 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
23 | if not 0.0 <= betas[1] < 1.0:
24 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
25 | defaults = dict(lr=lr, betas=betas, eps=eps,
26 | weight_decay=weight_decay, amsgrad=amsgrad)
27 | super(TorchAdamW, self).__init__(params, defaults)
28 |
29 | def __setstate__(self, state):
30 | super(TorchAdamW, self).__setstate__(state)
31 | for group in self.param_groups:
32 | group.setdefault('amsgrad', False)
33 |
34 | def step(self, closure=None):
35 | """Performs a single optimization step.
36 | Arguments:
37 | closure (callable, optional): A closure that reevaluates the model
38 | and returns the loss.
39 | """
40 | loss = None
41 | if closure is not None:
42 | loss = closure()
43 |
44 | for group in self.param_groups:
45 | for p in group['params']:
46 | if p.grad is None:
47 | continue
48 |
49 | # Perform stepweight decay
50 | p.data.mul_(1 - group['lr'] * group['weight_decay'])
51 |
52 | # Perform optimization step
53 | grad = p.grad.data
54 | if grad.is_sparse:
55 |                     raise RuntimeError('Adam does not support sparse '
56 |                                        'gradients, please consider SparseAdam'
57 |                                        ' instead')
58 | amsgrad = group['amsgrad']
59 |
60 | state = self.state[p]
61 |
62 | # State initialization
63 | if len(state) == 0:
64 | state['step'] = 0
65 | # Exponential moving average of gradient values
66 | state['exp_avg'] = torch.zeros_like(p.data)
67 | # Exponential moving average of squared gradient values
68 | state['exp_avg_sq'] = torch.zeros_like(p.data)
69 | if amsgrad:
70 | # Maintains max of all exp. moving avg. of sq. grad. values
71 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
72 |
73 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
74 | if amsgrad:
75 | max_exp_avg_sq = state['max_exp_avg_sq']
76 | beta1, beta2 = group['betas']
77 |
78 | state['step'] += 1
79 |
80 | # Decay the first and second moment running average coefficient
81 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
82 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
83 | if amsgrad:
84 | # Maintains the maximum of all 2nd moment running avg. till now
85 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
86 | # Use the max. for normalizing running avg. of gradient
87 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
88 | else:
89 | denom = exp_avg_sq.sqrt().add_(group['eps'])
90 |
91 | bias_correction1 = 1 - beta1 ** state['step']
92 | bias_correction2 = 1 - beta2 ** state['step']
93 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
94 |
95 | p.data.addcdiv_(-step_size, exp_avg, denom)
96 |
97 | return loss
98 |
99 |
100 | torch_optimizers = {
101 | 'adadelta': torch.optim.Adadelta,
102 | 'adam': torch.optim.Adam,
103 | 'adamw': TorchAdamW,
104 | 'sparseadam': torch.optim.SparseAdam,
105 | 'adamax': torch.optim.Adamax,
106 | 'asgd': torch.optim.ASGD,
107 | 'rmsprop': torch.optim.RMSprop,
108 | 'sgd': torch.optim.SGD,
109 | }
110 |
111 |
--------------------------------------------------------------------------------
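A sketch of how `get_optimizer` is used together with the config; the config fragment below is illustrative ('adamw' would map to the custom `TorchAdamW` class instead of `torch.optim.Adam`):

```python
# Sketch only: look up an optimizer class from the config and take one step with it.
import torch
from nets.optimizers import get_optimizer

config = {'training': {'optimizer': 'adam', 'lr': 1e-4}}
model = torch.nn.Linear(16, 1)

opt_cls = get_optimizer('torch', config)             # returns torch.optim.Adam here
optimizer = opt_cls(model.parameters(), lr=config['training']['lr'])

optimizer.zero_grad()
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
optimizer.step()                                      # one update
```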
/utils/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from .log import _get_logging_level
4 | from .core import get_files_recursively
5 | import logging
6 |
7 |
8 | def make_dataset_csv(im_dir, im_ext='tif', label_dir=None, label_ext='json',
9 | output_path='dataset.csv', stage='train', match_re=None,
10 | recursive=False, ignore_mismatch=None, verbose=0):
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(_get_logging_level(int(verbose)))
14 | logger.debug('Checking arguments.')
15 |
16 | if stage != 'infer' and label_dir is None:
17 | raise ValueError("label_dir must be provided if stage is not infer.")
18 | logger.info('Matching images to labels.')
19 | logger.debug('Getting image file paths.')
20 | im_fnames = get_files_recursively(im_dir, traverse_subdirs=recursive,
21 | extension=im_ext)
22 | logger.debug(f"Got {len(im_fnames)} image file paths.")
23 | temp_im_df = pd.DataFrame({'image_path': im_fnames})
24 |
25 | if stage != 'infer':
26 | logger.debug('Preparing training or validation set.')
27 | logger.debug('Getting label file paths.')
28 | label_fnames = get_files_recursively(label_dir,
29 | traverse_subdirs=recursive,
30 | extension=label_ext)
31 | logger.debug(f"Got {len(label_fnames)} label file paths.")
32 | if len(im_fnames) != len(label_fnames):
33 |             logger.warning('The number of images and label files is not equal.')
34 |
35 | logger.debug("Matching image files to label files.")
36 | logger.debug("Extracting image filename substrings for matching.")
37 | temp_label_df = pd.DataFrame({'label_path': label_fnames})
38 | temp_im_df['image_fname'] = temp_im_df['image_path'].apply(
39 | lambda x: os.path.split(x)[1])
40 | temp_label_df['label_fname'] = temp_label_df['label_path'].apply(
41 | lambda x: os.path.split(x)[1])
42 | if match_re:
43 | logger.debug('match_re is True, extracting regex matches')
44 | im_match_strs = temp_im_df['image_fname'].str.extract(match_re)
45 | label_match_strs = temp_label_df['label_fname'].str.extract(
46 | match_re)
47 | if len(im_match_strs.columns) > 1 or \
48 | len(label_match_strs.columns) > 1:
49 | raise ValueError('Multiple regex matches occurred within '
50 | 'individual filenames.')
51 | else:
52 | temp_im_df['match_str'] = im_match_strs
53 | temp_label_df['match_str'] = label_match_strs
54 | else:
55 | logger.debug('match_re is False, will match by fname without ext')
56 | temp_im_df['match_str'] = temp_im_df['image_fname'].apply(
57 | lambda x: os.path.splitext(x)[0])
58 | temp_label_df['match_str'] = temp_label_df['label_fname'].apply(
59 | lambda x: os.path.splitext(x)[0])
60 |
61 | logger.debug('Aligning label and image dataframes by'
62 | ' match_str.')
63 | temp_join_df = pd.merge(temp_im_df, temp_label_df, on='match_str',
64 | how='inner')
65 | logger.debug(f'Length of joined dataframe: {len(temp_join_df)}')
66 | if len(temp_join_df) < len(temp_im_df) and \
67 | ignore_mismatch is None:
68 | raise ValueError('There is not a perfect 1:1 match of images '
69 | 'to label files. To allow this behavior, see '
70 | 'the make_dataset_csv() ignore_mismatch '
71 | 'argument.')
72 | elif len(temp_join_df) > len(temp_im_df) and ignore_mismatch is None:
73 | raise ValueError('There are multiple label files matching at '
74 | 'least one image file.')
75 | elif len(temp_join_df) > len(temp_im_df) and ignore_mismatch == 'skip':
76 | logger.info('ignore_mismatch="skip", so dropping any images with '
77 | f'duplicates. Original images: {len(temp_im_df)}')
78 | dup_rows = temp_join_df.duplicated(subset='match_str', keep=False)
79 | temp_join_df = temp_join_df.loc[~dup_rows, :]
80 | logger.info('Remaining images after dropping duplicates: '
81 | f'{len(temp_join_df)}')
82 | logger.debug('Dropping extra columns from output dataframe.')
83 | output_df = temp_join_df[['image_path', 'label_path']].rename(
84 | columns={'image_path': 'image', 'label_path': 'label'})
85 |
86 | elif stage == 'infer':
87 | logger.debug('Preparing inference dataset dataframe.')
88 | output_df = temp_im_df.rename(columns={'image_path': 'image'})
89 |
90 | logger.debug(f'Saving output dataframe to {output_path} .')
91 | output_df.to_csv(output_path, index=False)
92 |
93 | return output_df
94 |
--------------------------------------------------------------------------------
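A sketch of `make_dataset_csv` for the Urban3D layout described in the README (directory paths are placeholders; note that image and mask filenames must share the same stem unless a `match_re` pattern is supplied):

```python
# Sketch only: build a training CSV by matching image tiles to mask tiles by filename.
from utils.data import make_dataset_csv

df = make_dataset_csv(
    im_dir='../data/Train/Urban3D_Train/RGB',
    im_ext='tif',
    label_dir='../data/Train/Urban3D_Train/masks',
    label_ext='tif',
    output_path='../csvs/Urban3D_Train_df.csv',
    stage='train',
    verbose=1,
)
print(df.head())
```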
/notebooks/data_prep.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 7,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Dataset location (edit as needed)\n",
10 | "import os\n",
11 | "import pandas as pd\n",
12 | "root_dir = '../'"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 8,
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "name": "stdout",
22 | "output_type": "stream",
23 | "text": [
24 | "3.7.10 (default, Feb 26 2021, 18:47:35) \n",
25 | "[GCC 7.3.0]\n",
26 | "Python 3.7.10\n"
27 | ]
28 | }
29 | ],
30 | "source": [
31 | "import sys\n",
32 | "print(sys.version)\n",
33 | "!python --version"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 12,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "data": {
43 | "text/html": [
44 | "
\n",
45 | "\n",
58 | "
\n",
59 | " \n",
60 | " \n",
61 | " | \n",
62 | " image | \n",
63 | " label | \n",
64 | "
\n",
65 | " \n",
66 | " \n",
67 | " \n",
68 | " | 0 | \n",
69 | " ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... | \n",
70 | " ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... | \n",
71 | "
\n",
72 | " \n",
73 | " | 1 | \n",
74 | " ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... | \n",
75 | " ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... | \n",
76 | "
\n",
77 | " \n",
78 | " | 2 | \n",
79 | " ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... | \n",
80 | " ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... | \n",
81 | "
\n",
82 | " \n",
83 | " | 3 | \n",
84 | " ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... | \n",
85 | " ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... | \n",
86 | "
\n",
87 | " \n",
88 | " | 4 | \n",
89 | " ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... | \n",
90 | " ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... | \n",
91 | "
\n",
92 | " \n",
93 | "
\n",
94 | "
"
95 | ],
96 | "text/plain": [
97 | " image \\\n",
98 | "0 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n",
99 | "1 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n",
100 | "2 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n",
101 | "3 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n",
102 | "4 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n",
103 | "\n",
104 | " label \n",
105 | "0 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... \n",
106 | "1 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... \n",
107 | "2 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... \n",
108 | "3 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... \n",
109 | "4 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... "
110 | ]
111 | },
112 | "metadata": {},
113 | "output_type": "display_data"
114 | },
115 | {
116 | "name": "stdout",
117 | "output_type": "stream",
118 | "text": [
119 | "output csv: ../csvs/Urban3D_Train_df.csv\n"
120 | ]
121 | }
122 | ],
123 | "source": [
124 | "# Make dataframe csvs for train/test\n",
125 | "\n",
126 | "out_dir = os.path.join(root_dir, 'csvs/')\n",
127 | "os.makedirs(out_dir, exist_ok=True)\n",
128 | "# data_dir = 'Test/Urban3D_Test'\n",
129 | "data_dir = 'Train/Urban3D_Train'\n",
130 | "\n",
131 | "\n",
132 | "d = os.path.join(root_dir, 'data', data_dir)\n",
133 | "subdirs = sorted([f for f in os.listdir(d)]) \n",
134 | "outpath = os.path.join(out_dir, data_dir.split('/')[1] + '_df.csv')\n",
135 | "im_list, mask_list = [], []\n",
136 | "\n",
137 | "\n",
138 | "im_files = [os.path.join( d,'RGB', f.split('.')[0] + '.tif')\n",
139 | "for f in sorted(os.listdir(os.path.join(d,'RGB' )))]\n",
140 | "\n",
141 | "mask_files = [os.path.join(d, 'masks', f.split('.')[0] + '.tif')\n",
142 | "for f in sorted(os.listdir(os.path.join(d, 'masks')))]\n",
143 | "\n",
144 | "\n",
145 | "im_list.extend(im_files)\n",
146 | "mask_list.extend(mask_files)\n",
147 | "\n",
148 | "\n",
149 | "df = pd.DataFrame({'image': im_list, 'label': mask_list})\n",
150 | "display(df.head())\n",
151 | "\n",
152 | "df.to_csv(outpath, index=False)\n",
153 | "\n",
154 | "print(\"output csv:\", outpath)"
155 | ]
156 | }
157 | ],
158 | "metadata": {
159 | "kernelspec": {
160 | "display_name": "Python 3 (ipykernel)",
161 | "language": "python",
162 | "name": "python3"
163 | },
164 | "language_info": {
165 | "codemirror_mode": {
166 | "name": "ipython",
167 | "version": 3
168 | },
169 | "file_extension": ".py",
170 | "mimetype": "text/x-python",
171 | "name": "python",
172 | "nbconvert_exporter": "python",
173 | "pygments_lexer": "ipython3",
174 | "version": "3.7.10"
175 | }
176 | },
177 | "nbformat": 4,
178 | "nbformat_minor": 4
179 | }
180 |
--------------------------------------------------------------------------------
/nets/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import skimage.io
5 | import numpy as np
6 | from warnings import warn
7 | from .model_io import get_model
8 | from .transform import process_aug_dict
9 | from .datagen import InferenceTiler as InferenceTiler
10 | from utils.core import get_data_paths
11 | import torch.nn.functional as F
12 |
13 | class Inferer(object):
14 |     """Object for running inference with trained segmentation models using PyTorch."""
15 |
16 | def __init__(self, config, custom_model_dict=None):
17 | self.config = config
18 | self.batch_size = self.config['batch_size']
19 | self.framework = self.config['nn_framework']
20 | self.model_name = self.config['model_name']
21 | self.aoi = self.config["get_aoi"]
22 | self.date = self.config["training_date"]
23 | self.boundary = self.config["boundary"]
24 | self.weight_file = self.config['weight_file']
25 |
26 | # check if the model was trained as part of the same pipeline; if so,
27 | # use the output from that. If not, use the pre-trained model directly.
28 | if self.config['train']:
29 | warn('Because the configuration specifies both training and '
30 | 'inference, solaris is switching the model weights path '
31 | 'to the training output path.')
32 | self.model_path = self.config['training']['model_dest_path']
33 | if custom_model_dict is not None:
34 | custom_model_dict['weight_path'] = self.config[
35 | 'training']['model_dest_path']
36 | else:
37 |
38 | if len(self.model_name.split('_'))==2:
39 | self.model_path = self.config.get('model_path', None) + self.aoi + '_' +self.model_name.split('_')[0]+ '_' + self.model_name.split('_')[1]+ '_'+ self.date + '/' + self.weight_file
40 | else :
41 | self.model_path = self.config.get('model_path', None) + self.aoi + '_' +self.model_name.split('_')[0]+ '_' + self.date + '/' + self.weight_file
42 | self.infer_mode = self.config['infer']
43 | if self.infer_mode :
44 | self.mode = 'Infer'
45 |
46 | self.model = get_model(self.model_name, self.framework, self.mode,
47 | self.model_path, pretrained=True, custom_model_dict=custom_model_dict)
48 | self.window_step_x = self.config['inference'].get('window_step_size_x',
49 | None)
50 | self.window_step_y = self.config['inference'].get('window_step_size_y',
51 | None)
52 | if self.window_step_x is None:
53 | self.window_step_x = self.config['data_specs']['width']
54 | if self.window_step_y is None:
55 | self.window_step_y = self.config['data_specs']['height']
56 | self.stitching_method = self.config['inference'].get(
57 | 'stitching_method', 'average')
58 | self.output_dir = self.config['inference']['output_dir'] + self.aoi + '_' + self.date + '/'
59 |
60 | if not os.path.isdir(self.output_dir):
61 | os.makedirs(self.output_dir)
62 |
63 | if self.framework in ['torch', 'pytorch']:
64 | self.gpu_available = torch.cuda.is_available()
65 | if self.gpu_available:
66 | self.gpu_count = torch.cuda.device_count()
67 | else:
68 | self.gpu_count = 0
69 | def __call__(self, infer_df=None):
70 |
71 | with torch.no_grad():
72 | print(self.model_path)
73 | if infer_df is None:
74 | infer_df = get_infer_df(self.config)
75 |
76 | inf_tiler = InferenceTiler(
77 | self.framework,
78 | width=self.config['data_specs']['width'],
79 | height=self.config['data_specs']['height'],
80 | x_step=self.window_step_x,
81 | y_step=self.window_step_y,
82 | augmentations=process_aug_dict(
83 | self.config['inference_augmentation']))
84 | for idx, im_path in enumerate(infer_df['image']):
85 | leng=len(infer_df['image'])
86 | print(idx,'/',leng, ' (%0.2f%%)' % float(100*idx/leng))
87 |
88 | inf_input, idx_refs, (
89 | src_im_height, src_im_width) = inf_tiler(im_path)
90 |
91 | if self.framework in ['torch', 'pytorch']:
92 |
93 | with torch.no_grad():
94 | self.model.eval()
95 |
96 | if torch.cuda.is_available():
97 | device = torch.device('cuda')
98 | self.model = self.model.cuda()
99 | else:
100 | device = torch.device('cpu')
101 |
102 | inf_input = torch.from_numpy(inf_input).float().to(device)
103 |
104 | # add additional input data, if applicable
105 | if self.config['data_specs'].get('additional_inputs',
106 | None) is not None:
107 | inf_input = [inf_input]
108 | for i in self.config['data_specs']['additional_inputs']:
109 | inf_input.append(
110 | infer_df[i].iloc[idx].to(device))
111 |
112 |
113 |
114 | subarr_preds = self.model(inf_input)
115 |
116 |
117 | subarr_preds = subarr_preds.cpu().data.numpy()
118 | subarr_preds = subarr_preds[:, :, :src_im_height,:src_im_width]
119 |
120 |
121 |
122 | skimage.io.imsave(os.path.join(self.output_dir,os.path.split(im_path)[1]), subarr_preds)
123 |
124 |
125 | def get_infer_df(config):
126 |
127 | infer_df = get_data_paths(config['inference_data_csv']+config['get_aoi']+'_Test_df.csv' , infer=True)
128 | return infer_df
129 |
--------------------------------------------------------------------------------
/nets/zoo/ternaus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torchvision import models
6 |
7 |
8 | def conv3x3(in_: int, out: int) -> nn.Module:
9 | return nn.Conv2d(in_, out, 3, padding=1)
10 |
11 |
12 | class ConvRelu(nn.Module):
13 | def __init__(self, in_: int, out: int) -> None:
14 | super().__init__()
15 | self.conv = conv3x3(in_, out)
16 | self.activation = nn.ReLU(inplace=True)
17 |
18 | def forward(self, x: torch.Tensor) -> torch.Tensor:
19 | x = self.conv(x)
20 | x = self.activation(x)
21 | return x
22 |
23 |
24 | class DecoderBlock(nn.Module):
25 | def __init__(
26 | self, in_channels: int, middle_channels: int, out_channels: int
27 | ) -> None:
28 | super().__init__()
29 |
30 | self.block = nn.Sequential(
31 | ConvRelu(in_channels, middle_channels),
32 | nn.ConvTranspose2d(
33 | middle_channels,
34 | out_channels,
35 | kernel_size=3,
36 | stride=2,
37 | padding=1,
38 | output_padding=1,
39 | ),
40 | nn.ReLU(inplace=True),
41 | )
42 |
43 | def forward(self, x: torch.Tensor) -> torch.Tensor:
44 | return self.block(x)
45 |
46 |
47 | class ternaus11(nn.Module):
48 | def __init__(self, num_filters: int = 32, pretrained: bool = False,mode='Train') -> None:
49 | """
50 | Args:
51 | num_filters:
52 | pretrained:
53 | False - no pre-trained network is used
54 | True - encoder is pre-trained with VGG11
55 | """
56 | super().__init__()
57 | self.pool = nn.MaxPool2d(2, 2)
58 | self.mode=mode
59 | self.encoder = models.vgg11(pretrained=pretrained).features
60 |
61 | self.relu = self.encoder[1]
62 | self.conv1 = self.encoder[0]
63 | self.conv2 = self.encoder[3]
64 | self.conv3s = self.encoder[6]
65 | self.conv3 = self.encoder[8]
66 | self.conv4s = self.encoder[11]
67 | self.conv4 = self.encoder[13]
68 | self.conv5s = self.encoder[16]
69 | self.conv5 = self.encoder[18]
70 |
71 | self.center = DecoderBlock(
72 | num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8
73 | )
74 | self.dec5 = DecoderBlock(
75 | num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8
76 | )
77 | self.dec4 = DecoderBlock(
78 | num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4
79 | )
80 | self.dec3 = DecoderBlock(
81 | num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2
82 | )
83 | self.dec2 = DecoderBlock(
84 | num_filters * (4 + 2), num_filters * 2 * 2, num_filters
85 | )
86 | self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)
87 |
88 | self.final = nn.Conv2d(num_filters, 1, kernel_size=1)
89 |
90 | def forward(self, x: torch.Tensor) -> torch.Tensor:
91 | conv1 = self.relu(self.conv1(x))
92 | conv2 = self.relu(self.conv2(self.pool(conv1)))
93 | conv3s = self.relu(self.conv3s(self.pool(conv2)))
94 | conv3 = self.relu(self.conv3(conv3s))
95 | conv4s = self.relu(self.conv4s(self.pool(conv3)))
96 | conv4 = self.relu(self.conv4(conv4s))
97 | conv5s = self.relu(self.conv5s(self.pool(conv4)))
98 | conv5 = self.relu(self.conv5(conv5s))
99 |
100 | center = self.center(self.pool(conv5))
101 |
102 | dec5 = self.dec5(torch.cat([center, conv5], 1))
103 | dec4 = self.dec4(torch.cat([dec5, conv4], 1))
104 | dec3 = self.dec3(torch.cat([dec4, conv3], 1))
105 | dec2 = self.dec2(torch.cat([dec3, conv2], 1))
106 | dec1 = self.dec1(torch.cat([dec2, conv1], 1))
107 | out = self.final(dec1)
108 | if self.mode == 'Train':
109 | return F.sigmoid(out)
110 | elif self.mode == 'Infer':
111 | return out
112 |
113 | class Interpolate(nn.Module):
114 | def __init__(
115 | self,
116 | size: int = None,
117 | scale_factor: int = None,
118 | mode: str = "nearest",
119 | align_corners: bool = False,
120 | ):
121 | super().__init__()
122 | self.interp = nn.functional.interpolate
123 | self.size = size
124 | self.mode = mode
125 | self.scale_factor = scale_factor
126 | self.align_corners = align_corners
127 |
128 | def forward(self, x: torch.Tensor) -> torch.Tensor:
129 | x = self.interp(
130 | x,
131 | size=self.size,
132 | scale_factor=self.scale_factor,
133 | mode=self.mode,
134 | align_corners=self.align_corners,
135 | )
136 | return x
137 |
138 |
139 | class DecoderBlockV2(nn.Module):
140 | def __init__(
141 | self,
142 | in_channels: int,
143 | middle_channels: int,
144 | out_channels: int,
145 | is_deconv: bool = True,
146 | ):
147 | super().__init__()
148 | self.in_channels = in_channels
149 |
150 | if is_deconv:
151 | """
152 |             Parameters for the deconvolution were chosen to avoid checkerboard
153 |             artifacts, following https://distill.pub/2016/deconv-checkerboard/
154 | """
155 |
156 | self.block = nn.Sequential(
157 | ConvRelu(in_channels, middle_channels),
158 | nn.ConvTranspose2d(
159 | middle_channels, out_channels, kernel_size=4, stride=2, padding=1
160 | ),
161 | nn.ReLU(inplace=True),
162 | )
163 | else:
164 | self.block = nn.Sequential(
165 | Interpolate(scale_factor=2, mode="bilinear"),
166 | ConvRelu(in_channels, middle_channels),
167 | ConvRelu(middle_channels, out_channels),
168 | )
169 |
170 | def forward(self, x: torch.Tensor) -> torch.Tensor:
171 | return self.block(x)
--------------------------------------------------------------------------------
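
A minimal smoke-test sketch for ternaus11 (not shipped with the repo), assuming the repository root is on PYTHONPATH; the 256x256 input is illustrative, and any size divisible by 32 (five pooling stages) behaves the same way.

import torch
from nets.zoo.ternaus import ternaus11   # assumption: run from the repository root

model = ternaus11(num_filters=32, pretrained=False, mode='Train').eval()
x = torch.randn(1, 3, 256, 256)          # dummy RGB tile
with torch.no_grad():
    probs = model(x)                     # 'Train' mode applies a sigmoid to the single-channel logits
print(probs.shape)                       # expected: torch.Size([1, 1, 256, 256])
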
/nets/zoo/unet_BE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import skimage
4 | import numpy as np
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 |
8 |
9 | def double_conv(in_channels, out_channels):
10 | return nn.Sequential(
11 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
14 | nn.ReLU(inplace=True)
15 | )
16 |
17 | class _up_deconv(nn.Module):
18 | def __init__(self, in_channels, out_channels):
19 | super(_up_deconv, self).__init__()
20 |
21 | self.deconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
22 |
23 | self.relu = nn.ReLU()
24 |
25 | def forward(self, x):
26 |
27 | out = self.deconv(x)
28 | out = self.relu(out)
29 |
30 | return out
31 |
32 | class UNet_BE(nn.Module):
33 |
34 | def __init__(self, n_class=1, pretrained=False, mode='Train'):
35 | super().__init__()
36 | self.mode = mode
37 | n_channels = 32
38 |
39 |
40 | self.dconv_down1 = double_conv(3, 64)
41 | self.dconv_down2 = double_conv(64, 128)
42 | self.dconv_down3 = double_conv(128, 256)
43 | self.dconv_down4 = double_conv(256, 512)
44 | self.dconv_down5 = double_conv(512, 1024)
45 | self.maxpool = nn.MaxPool2d(2)
46 | self.upsample5 = _up_deconv(1024,512)
47 | self.upsample4 = _up_deconv(512,256)
48 | self.upsample3 = _up_deconv(256,128)
49 | self.upsample2 = _up_deconv(128,64)
50 | self.dconv_up4 = double_conv(512 + 512, 512)
51 | self.dconv_up3 = double_conv(256 + 256, 256)
52 | self.dconv_up2 = double_conv(128 + 128, 128)
53 | self.dconv_up1 = double_conv(64 + 64, 64)
54 | # HED Block
55 | self.dsn1 = nn.Conv2d(64, 1, 1)
56 | self.dsn2 = nn.Conv2d(128, 1, 1)
57 | self.dsn3 = nn.Conv2d(256, 1, 1)
58 | self.dsn4 = nn.Conv2d(512, 1, 1)
59 | self.dsn5 = nn.Conv2d(1024, 1, 1)
60 |
61 | #boundary enhancement part
62 | self.fuse = nn.Sequential(nn.Conv2d(5, 64, 1),nn.ReLU(inplace=True))
63 | self.SE_mimic = nn.Sequential(
64 | nn.Linear(64, 64, bias=False),
65 | nn.ReLU(inplace=True),
66 | nn.Linear(64, 5, bias=False),
67 | nn.Sigmoid()
68 | )
69 | self.final_boundary = nn.Conv2d(5,2,1)
70 | self.final_conv = nn.Sequential(
71 | nn.Conv2d(128,64,3, padding=1),
72 | nn.ReLU(inplace=True)
73 | )
74 | self.final_mask = nn.Conv2d(64,2,1)
75 | self.relu = nn.ReLU()
76 | self.out = nn.Conv2d(64,1,1)
77 |
78 |
79 | def forward(self, x):
80 | h = x.size(2)
81 | w = x.size(3)
82 |
83 |
84 | conv1 = self.dconv_down1(x)
85 | x = self.maxpool(conv1)
86 | conv2 = self.dconv_down2(x)
87 | x = self.maxpool(conv2)
88 | conv3 = self.dconv_down3(x)
89 | x = self.maxpool(conv3)
90 | conv4 = self.dconv_down4(x)
91 | x = self.maxpool(conv4)
92 | conv5 = self.dconv_down5(x)
93 | x = self.upsample5(conv5)
94 | x = torch.cat([x, conv4],1)
95 | x = self.dconv_up4(x)
96 | x = self.upsample4(x)
97 | x = torch.cat([x, conv3], dim=1)
98 | x = self.dconv_up3(x)
99 | x = self.upsample3(x)
100 | x = torch.cat([x, conv2], dim=1)
101 | x = self.dconv_up2(x)
102 | x = self.upsample2(x)
103 | x = torch.cat([x, conv1], dim=1)
104 | x = self.dconv_up1(x)
105 | # out = F.sigmoid(self.out(x))
106 |
107 |
108 | ## side output
109 | d1 = self.dsn1(conv1)
110 | d2 = F.upsample_bilinear(self.dsn2(conv2), size=(h,w))
111 | d3 = F.upsample_bilinear(self.dsn3(conv3), size=(h,w))
112 | d4 = F.upsample_bilinear(self.dsn4(conv4), size=(h,w))
113 | d5 = F.upsample_bilinear(self.dsn5(conv5), size=(h,w))
114 |
115 | d1_out = F.sigmoid(d1)
116 | d2_out = F.sigmoid(d2)
117 | d3_out = F.sigmoid(d3)
118 | d4_out = F.sigmoid(d4)
119 | d5_out = F.sigmoid(d5)
120 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out), 1)
121 |
122 | fuse_box = self.fuse(concat)
123 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1))
124 | GAP = GAP.view(-1, 64)
125 | se_like = self.SE_mimic(GAP)
126 | se_like = torch.unsqueeze(se_like, 2)
127 | se_like = torch.unsqueeze(se_like, 3)
128 |
129 | feat_se = concat * se_like.expand_as(concat)
130 | boundary = self.final_boundary(feat_se)
131 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1)
132 | bd_sftmax = F.softmax(boundary, dim=1)
133 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1)
134 |
135 | feat_concat = torch.cat( [x, fuse_box], 1)
136 | feat_concat_conv = self.final_conv(feat_concat)
137 | mask = self.final_mask(feat_concat_conv)
138 | mask_sftmax = F.softmax(mask,dim=1)
139 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1)
140 |
141 | if self.mode == 'Train':
142 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1)
143 | elif self.mode == 'Infer':
144 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1)
145 |
146 |
147 | mask_out = torch.unsqueeze(mask[:,1,:,:],1)
148 | relu = self.relu(mask_out)
149 | scalar = relu.cpu().detach().numpy()
150 | if np.sum(scalar) == 0:
151 | average = 0
152 | else :
153 | average = scalar[np.nonzero(scalar)].mean()
154 | mask_out = mask_out-relu + (average*scalefactor)
155 |
156 | if self.mode == 'Train':
157 | mask_out = F.sigmoid(mask_out)
158 | boundary_out = F.sigmoid(boundary_out)
159 |
160 | return d1_out, d2_out, d3_out, d4_out, d5_out, boundary_out, mask_out
161 | elif self.mode =='Infer':
162 | return mask_out
163 |
164 |
165 |
166 | #
167 |
--------------------------------------------------------------------------------
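
A sketch of the BE head's outputs in both modes, assuming the repository root is on PYTHONPATH and an input whose sides are divisible by 16 (four max-pool stages): 'Train' mode returns the five HED side outputs plus the boundary and mask heads, while 'Infer' mode returns only the boundary-enhanced mask.

import torch
from nets.zoo.unet_BE import UNet_BE     # assumption: run from the repository root

model = UNet_BE(mode='Train').eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    d1, d2, d3, d4, d5, boundary, mask = model(x)   # five HED side outputs + boundary head + mask head
# every head is a single-channel map at the input resolution
assert mask.shape == boundary.shape == (1, 1, 256, 256)

# at inference time only the boundary-enhanced mask is returned
model_infer = UNet_BE(mode='Infer').eval()
with torch.no_grad():
    mask_only = model_infer(x)
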
/nets/zoo/brrnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def conv_block(in_channels, out_channels):
6 | return nn.Sequential(
7 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
8 | nn.BatchNorm2d(num_features=out_channels),
9 | nn.ReLU(inplace=True),
10 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
11 | nn.BatchNorm2d(num_features=out_channels),
12 | nn.ReLU(inplace=True)
13 | )
14 |
15 |
16 | def up_transpose(in_channels, out_channels):
17 | return nn.Sequential(
18 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
19 | )
20 | class center_block(nn.Module):
21 | def __init__(self, in_channels, out_channels):
22 | super(center_block, self).__init__()
23 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1,dilation=1)
24 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=2,dilation=2)
25 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=4,dilation=4)
26 | self.conv4 = nn.Conv2d(out_channels, out_channels, 3, padding=8,dilation=8)
27 | self.conv5 = nn.Conv2d(out_channels, out_channels, 3, padding=16,dilation=16)
28 | self.conv6 = nn.Conv2d(out_channels, out_channels, 3, padding=32,dilation=32)
29 |
30 | self.bn_1 = nn.BatchNorm2d(num_features=out_channels)
31 | self.bn_2 = nn.BatchNorm2d(num_features=out_channels)
32 | self.bn_3 = nn.BatchNorm2d(num_features=out_channels)
33 | self.bn_4 = nn.BatchNorm2d(num_features=out_channels)
34 | self.bn_5 = nn.BatchNorm2d(num_features=out_channels)
35 | self.bn_6 = nn.BatchNorm2d(num_features=out_channels)
36 | self.relu = nn.ReLU()
37 |
38 |
39 |
40 |     def forward(self, x):  # NOTE: the rrm module and the center block are currently mixed together
41 |
42 |
43 | x1 = self.relu(self.bn_1(self.conv1(x)))
44 |
45 | x2 = self.relu(self.bn_2(self.conv2(x1)))
46 |
47 | x3 = self.relu(self.bn_3(self.conv3(x2)))
48 |
49 | x4 = self.relu(self.bn_4(self.conv4(x3)))
50 |
51 | x5 = self.relu(self.bn_5(self.conv5(x4)))
52 |
53 | x6 = self.relu(self.bn_6(self.conv6(x5)))
54 |
55 |
56 | x = x1+x2+x3+x4+x5+x6
57 |
58 | return x
59 |
60 | class rrm_module(nn.Module):
61 | def __init__(self, in_channels, out_channels):
62 | super(rrm_module,self).__init__()
63 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1,dilation=1)
64 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=2,dilation=2)
65 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=4,dilation=4)
66 | self.conv4 = nn.Conv2d(out_channels, out_channels, 3, padding=8,dilation=8)
67 | self.conv5 = nn.Conv2d(out_channels, out_channels, 3, padding=16,dilation=16)
68 | self.conv6 = nn.Conv2d(out_channels, out_channels, 3, padding=32,dilation=32)
69 |
70 | self.bn_1 = nn.BatchNorm2d(num_features=out_channels)
71 | self.bn_2 = nn.BatchNorm2d(num_features=out_channels)
72 | self.bn_3 = nn.BatchNorm2d(num_features=out_channels)
73 | self.bn_4 = nn.BatchNorm2d(num_features=out_channels)
74 | self.bn_5 = nn.BatchNorm2d(num_features=out_channels)
75 | self.bn_6 = nn.BatchNorm2d(num_features=out_channels)
76 | self.relu = nn.ReLU()
77 |
78 | self.out = nn.Conv2d(out_channels, 1, 3, padding=1,dilation=1)
79 |
80 | def forward(self,x):
81 | residual = x
82 | x1 = self.relu(self.bn_1(self.conv1(x)))
83 |
84 | x2 = self.relu(self.bn_2(self.conv2(x1)))
85 | x3 = self.relu(self.bn_3(self.conv3(x2)))
86 | x4 = self.relu(self.bn_4(self.conv4(x3)))
87 | x5 = self.relu(self.bn_5(self.conv5(x4)))
88 | x6 = self.relu(self.bn_6(self.conv6(x5)))
89 | x = x1+x2+x3+x4+x5+x6
90 | x = self.out(x)
91 | x = residual + x
92 |
93 | return x
94 |
95 | class decoder_block(nn.Module):
96 | def __init__(self, in_channels, out_channels):
97 | super(decoder_block,self).__init__()
98 | self.bn_i = nn.BatchNorm2d(num_features=in_channels)
99 | self.relu = nn.ReLU()
100 | self.conv = conv_block(in_channels, out_channels)
101 | def forward(self, x):
102 |
103 | out = self.bn_i(x)
104 | out = self.relu(out)
105 | out = self.conv(out)
106 | return out
107 |
108 | class BRRNet(nn.Module):
109 |
110 | def __init__(self, n_class=1, pretrained=False,mode='Train'):
111 | super().__init__()
112 | self.mode=mode
113 | self.dconv_down1 = conv_block(3, 64)
114 | self.dconv_down2 = conv_block(64, 128)
115 | self.dconv_down3 = conv_block(128, 256)
116 |
117 | self.maxpool = nn.MaxPool2d(2,2)
118 | self.center = center_block(256,512)
119 | self.deconv3 = up_transpose(512,256)
120 | self.deconv2 = up_transpose(256,128)
121 | self.deconv1 = up_transpose(128,64)
122 |
123 | self.decoder_3 = decoder_block(512, 256)
124 | self.decoder_2 = decoder_block(256, 128)
125 | self.decoder_1 = decoder_block(128, 64)
126 | self.output_1 = nn.Conv2d(64,n_class, 1)
127 | self.rrm = rrm_module(1,64)
128 | def forward(self, x):
129 |
130 | conv1 = self.dconv_down1(x)
131 | # print(conv1.shape)
132 | x = self.maxpool(conv1)
133 | # print(x.shape)
134 | conv2 = self.dconv_down2(x)
135 | x = self.maxpool(conv2)
136 |
137 | conv3 = self.dconv_down3(x)
138 | x = self.maxpool(conv3)
139 |
140 | x = self.center(x)
141 |
142 | x = self.deconv3(x) # 512 256
143 | x = torch.cat([conv3,x],1) # 256 + 256
144 |
145 | x = self.decoder_3(x) # 512 256
146 |
147 | x = self.deconv2(x)
148 | x = torch.cat([conv2,x],1)
149 | x = self.decoder_2(x)
150 |
151 | x = self.deconv1(x)
152 | x = torch.cat([conv1,x],1)
153 | x = self.decoder_1(x)
154 |
155 | x = self.output_1(x)
156 | out = self.rrm(x)
157 | if self.mode == 'Train':
158 | return F.sigmoid(out)
159 | elif self.mode == 'Infer':
160 | return out
--------------------------------------------------------------------------------
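
A quick forward-pass sketch for BRRNet, assuming the repository root is on PYTHONPATH; input sides should be divisible by 8 (three max-pool stages), and the residual refinement module (rrm) produces a single-channel map that 'Train' mode passes through a sigmoid.

import torch
from nets.zoo.brrnet import BRRNet       # assumption: run from the repository root

model = BRRNet(n_class=1, mode='Train').eval()
x = torch.randn(1, 3, 256, 256)          # sides divisible by 8
with torch.no_grad():
    probs = model(x)                     # prediction module output refined by the rrm residual
print(probs.shape)                       # expected: torch.Size([1, 1, 256, 256])
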
/nets/zoo/uspp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class _stage_block(nn.Module):
7 | def __init__(self, channel_var):
8 | super(_stage_block, self).__init__()
9 |
10 | channel_in, channel_out = channel_var
11 |
12 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=3, stride=1, padding=1)
13 | self.bn = nn.BatchNorm2d(channel_out)
14 | self.relu = nn.ReLU()
15 |
16 | def forward(self, x):
17 | out = self.bn( self.conv(x) )
18 | out = self.relu(out)
19 | return out
20 |
21 |
22 | class _upss_block(nn.Module):
23 | def __init__(self, channel_in):
24 | super(_upss_block, self).__init__()
25 | self.conv1 = nn.Sequential(
26 | nn.MaxPool2d(1),
27 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=1, stride=1, padding=0),
28 | )
29 | self.conv2 = nn.Sequential(
30 | nn.MaxPool2d(2),
31 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=2, stride=1, padding=1),
32 | )
33 | self.conv3 = nn.Sequential(
34 | nn.MaxPool2d(3),
35 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=3, stride=1, padding=1),
36 | )
37 | self.conv4 = nn.Sequential(
38 | nn.MaxPool2d(6),
39 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=4, stride=1, padding=2),
40 | )
41 |
42 | def forward(self, x):
43 | residual = x
44 |
45 | h, w = x.size(2), x.size(3)
46 |
47 | out1 = self.conv1(x)
48 | out1 = F.upsample(input=out1, size=(h, w), mode='bilinear')
49 | out2 = self.conv2(x)
50 | out2 = F.upsample(input=out2, size=(h, w), mode='bilinear')
51 | out3 = self.conv3(x)
52 | out3 = F.upsample(input=out3, size=(h, w), mode='bilinear')
53 | out4 = self.conv4(x)
54 | out4 = F.upsample(input=out4, size=(h, w), mode='bilinear')
55 |
56 | out = torch.cat([out1, out2, out3, out4, residual], 1)
57 | return out
58 |
59 |
60 | class _down(nn.Module):
61 | def __init__(self, channel_in):
62 | super(_down, self).__init__()
63 | self.maxpool = nn.MaxPool2d(2)
64 |
65 | def forward(self, x):
66 | out = self.maxpool(x)
67 | return out
68 |
69 |
70 | class _up(nn.Module):
71 | def __init__(self, channel_in):
72 | super(_up, self).__init__()
73 |
74 | #self.relu = nn.PReLU()
75 | #self.subpixel = nn.PixelShuffle(2)
76 | self.subpixel = nn.ConvTranspose2d(in_channels=channel_in, out_channels=int(channel_in/2.), kernel_size=2, stride=2)
77 | #self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_in, kernel_size=1, stride=1, padding=0)
78 |
79 | def forward(self, x):
80 | #out = self.relu(self.conv(x))
81 | #out = self.subpixel(out)
82 | out = self.subpixel(x)
83 | return out
84 |
85 |
86 | class Uspp(nn.Module):
87 | def __init__(self, pretrained=False,mode='Train'):
88 | super(Uspp, self).__init__()
89 | self.mode=mode
90 | self.DCR_block11 = self.make_layer(_stage_block, [ 3, 64])
91 | self.DCR_block12 = self.make_layer(_stage_block, [ 64, 64])
92 | self.down1 = self.make_layer(_down, 64)
93 | self.DCR_block21 = self.make_layer(_stage_block, [ 64,128])
94 | self.DCR_block22 = self.make_layer(_stage_block, [128,128])
95 | self.down2 = self.make_layer(_down, 128)
96 | self.DCR_block31 = self.make_layer(_stage_block, [128,256])
97 | self.DCR_block32 = self.make_layer(_stage_block, [256,256])
98 | self.down3 = self.make_layer(_down, 256)
99 | self.DCR_block41 = self.make_layer(_stage_block, [256,512])
100 | self.DCR_block42 = self.make_layer(_stage_block, [512,512])
101 | self.down4 = self.make_layer(_down, 512)
102 |
103 | self.uspp = self.make_layer(_upss_block, 512)
104 |
105 | self.up4 = self.make_layer(_up, 1024)
106 | self.DCR_block43 = self.make_layer(_stage_block,[1024,512])
107 | self.DCR_block44 = self.make_layer(_stage_block, [512,512])
108 | self.up3 = self.make_layer(_up, 512)
109 | self.DCR_block33 = self.make_layer(_stage_block, [512,256])
110 | self.DCR_block34 = self.make_layer(_stage_block, [256,256])
111 | self.up2 = self.make_layer(_up, 256)
112 | self.DCR_block23 = self.make_layer(_stage_block, [256,128])
113 | self.DCR_block24 = self.make_layer(_stage_block, [128,128])
114 | self.up1 = self.make_layer(_up, 128)
115 | self.DCR_block13 = self.make_layer(_stage_block, [128, 64])
116 | self.DCR_block14 = self.make_layer(_stage_block, [ 64, 1])
117 |
118 | def make_layer(self, block, channel_in):
119 | layers = []
120 | layers.append(block(channel_in))
121 | return nn.Sequential(*layers)
122 |
123 | def forward(self, x):
124 | residual = x
125 |
126 | out = self.DCR_block11(x)
127 | conc1= self.DCR_block12(out)
128 | out = self.down1(conc1)
129 |
130 | out = self.DCR_block21(out)
131 | conc2= self.DCR_block22(out)
132 | out = self.down2(conc2)
133 |
134 | out = self.DCR_block31(out)
135 | conc3= self.DCR_block32(out)
136 | out = self.down3(conc3)
137 |
138 | out = self.DCR_block41(out)
139 | conc4= self.DCR_block42(out)
140 | out = self.down4(conc4)
141 |
142 | # bridge part
143 | out = self.uspp(out)
144 |
145 | out = self.up4(out)
146 | out = torch.cat([conc4, out], 1)
147 | out = self.DCR_block43(out)
148 | out = self.DCR_block44(out)
149 |
150 | out = self.up3(out)
151 | out = torch.cat([conc3, out], 1)
152 | out = self.DCR_block33(out)
153 | out = self.DCR_block34(out)
154 |
155 | out = self.up2(out)
156 | out = torch.cat([conc2, out], 1)
157 | out = self.DCR_block23(out)
158 | out = self.DCR_block24(out)
159 |
160 | out = self.up1(out)
161 | out = torch.cat([conc1, out], 1)
162 | out = self.DCR_block13(out)
163 | out = self.DCR_block14(out)
164 | if self.mode == 'Train':
165 | return F.sigmoid(out)
166 | elif self.mode == 'Infer':
167 | return out
--------------------------------------------------------------------------------
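
In the bridge, each of the four pyramid branches emits channel_in/4 feature maps and the input itself is concatenated back in, so the 512-channel bottleneck comes out with 1024 channels; that is why up4 is sized for 1024 inputs. A standalone check of that block, assuming the repository root is on PYTHONPATH:

import torch
from nets.zoo.uspp import _upss_block    # assumption: run from the repository root

bridge = _upss_block(512)
feat = torch.randn(1, 512, 16, 16)       # e.g. the bottleneck of a 256x256 input after four 2x downsamplings
with torch.no_grad():
    out = bridge(feat)
print(out.shape)                         # expected: torch.Size([1, 1024, 16, 16]) = 4 * 128 pyramid branches + 512 residual
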
/nets/zoo/denet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import skimage
4 | import numpy as np
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 |
8 | class _downsampling(nn.Module):
9 | def __init__(self, channel_in):
10 | super(_downsampling, self).__init__()
11 | #channel_in, channel_out = channel_var
12 |
13 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_in, kernel_size=3, stride=2, padding=1)
14 | self.maxpool = nn.MaxPool2d(2)
15 |
16 | self.bn = nn.BatchNorm2d(2*channel_in)
17 | self.relu = nn.ReLU()
18 |
19 | def forward(self, x):
20 | out1= self.conv(x)
21 | out2= self.maxpool(x)
22 |
23 | out = torch.cat([out1, out2], 1)
24 | out = self.relu(self.bn(out))
25 | return out
26 |
27 |
28 | class _linear_residual(nn.Module):
29 | def __init__(self, channel_in):
30 | super(_linear_residual, self).__init__()
31 | #channel_in, channel_out = channel_var
32 |
33 | self.conv1 = nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=1, stride=1, padding=0)
34 | self.bn1 = nn.BatchNorm2d(int(channel_in/4.))
35 | self.relu1= nn.ELU(alpha=1.673)
36 |
37 | self.conv2 = nn.Conv2d(in_channels=int(channel_in/4.), out_channels=int(channel_in/4.), kernel_size=3, stride=1, padding=1)
38 | self.bn2 = nn.BatchNorm2d(int(channel_in/4.))
39 | self.relu2= nn.ELU(alpha=1.673)
40 |
41 | self.conv3 = nn.Conv2d(in_channels=int(channel_in/4.), out_channels=channel_in, kernel_size=1, stride=1, padding=0)
42 |
43 | def forward(self, x):
44 | residual = x
45 | _lambda = 1.051
46 |
47 | out = self.bn1(self.conv1(x))
48 | out = self.relu1(out) * _lambda
49 |
50 | out = self.bn2(self.conv2(out))
51 | out = self.relu2(out) * _lambda
52 |
53 | out = self.conv3(out)
54 |
55 | out = torch.add(out, residual)
56 | return out
57 |
58 | class _encoding_block(nn.Module):
59 | def __init__(self, channel_in):
60 | super(_encoding_block, self).__init__()
61 |
62 | self.block_1 = nn.Sequential(
63 | _linear_residual(channel_in=channel_in),
64 | _linear_residual(channel_in=channel_in),
65 | _linear_residual(channel_in=channel_in),
66 | _linear_residual(channel_in=channel_in),
67 | _linear_residual(channel_in=channel_in),
68 | _linear_residual(channel_in=channel_in),
69 | )
70 |
71 | def forward(self, x):
72 | return self.block_1(x)
73 |
74 |
75 | class _compressing_module(nn.Module):
76 | def __init__(self, channel_in):
77 | super(_compressing_module, self).__init__()
78 |
79 | self.conv1 = nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=1, stride=1, padding=0)
80 | self.bn1 = nn.BatchNorm2d(int(channel_in/4.))
81 | self.relu1= nn.ReLU()
82 |
83 | self.conv2 = nn.Conv2d(in_channels=int(channel_in/4.), out_channels=int(channel_in/4.), kernel_size=3, stride=1, padding=1)
84 | self.bn2 = nn.BatchNorm2d(int(channel_in/4.))
85 | self.relu2= nn.ReLU()
86 |
87 | self.conv3 = nn.Conv2d(in_channels=int(channel_in/4.), out_channels=channel_in, kernel_size=1, stride=1, padding=0)
88 |
89 | def forward(self, x):
90 | residual = x
91 |
92 | out = self.bn1(self.conv1(x))
93 | out = self.relu1(out)
94 |
95 | out = self.bn2(self.conv2(out))
96 | out = self.relu2(out)
97 |
98 | out = self.conv3(out)
99 | return out
100 |
101 |
102 | class _duc(nn.Module):
103 | def __init__(self):
104 | super(_duc, self).__init__()
105 |
106 | self.subpixel = nn.PixelShuffle(8)
107 |
108 | def forward(self, x):
109 | #out = self.relu(self.conv(x))
110 | #out = self.subpixel(out)
111 | out = self.subpixel(x)
112 | return out
113 |
114 |
115 | class DeNet(nn.Module):
116 | def __init__(self, pretrained=False,mode='Train'):
117 | super(DeNet, self).__init__()
118 | self.mode=mode
119 | self.conv_i = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1, stride=1, padding=0)
120 | #self.relu1 = nn.PReLU()
121 | self.relu1 = nn.ReLU()
122 | #self.relu1 = nn.LeakyReLU(0.1)
123 | self.DS_block_1 = self.make_layer(_downsampling, 64)
124 | self.EC_block_1 = self.make_layer(_encoding_block, 128)
125 |
126 | self.DS_block_2 = self.make_layer(_downsampling, 128)
127 | self.EC_block_2 = self.make_layer(_encoding_block, 256)
128 |
129 | self.DS_block_3 = self.make_layer(_downsampling, 256)
130 | self.EC_block_3 = self.make_layer(_encoding_block, 512)
131 |
132 | self.CP_block_41= self.make_layer(_compressing_module, 512)
133 | self.EC_block_42= self.make_layer(_encoding_block, 512)
134 | self.CP_block_43= self.make_layer(_compressing_module, 512)
135 | self.EC_block_44= self.make_layer(_encoding_block, 512)
136 | self.CP_block_45= self.make_layer(_compressing_module, 512)
137 | self.EC_block_46= self.make_layer(_encoding_block, 512)
138 | self.CP_block_47= self.make_layer(_compressing_module, 512)
139 |
140 | self.conv_f = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=1, stride=1, padding=0)
141 | #self.relu2 = nn.PReLU()
142 | self.relu2 = nn.ReLU()
143 | #self.relu2 = nn.LeakyReLU(0.1)
144 |
145 | self.dcu = _duc()
146 |
147 | def make_layer(self, block, channel_in):
148 | layers = []
149 | layers.append(block(channel_in))
150 | return nn.Sequential(*layers)
151 |
152 | def forward(self, x):
153 | residual = x
154 |
155 | out = self.relu1(self.conv_i(x))
156 | out = self.DS_block_1(out)
157 | out = self.EC_block_1(out)
158 |
159 | out = self.DS_block_2(out)
160 | out = self.EC_block_2(out)
161 |
162 | out = self.DS_block_3(out)
163 | out = self.EC_block_3(out)
164 |
165 | out = self.CP_block_41(out)
166 | out = self.EC_block_42(out)
167 | out = self.CP_block_43(out)
168 | out = self.EC_block_44(out)
169 | out = self.CP_block_45(out)
170 | out = self.EC_block_46(out)
171 | out = self.CP_block_47(out)
172 |
173 | out = self.relu2(self.conv_f(out))
174 | out = self.dcu(out)
175 | if self.mode == 'Train':
176 | return F.sigmoid(out)
177 | elif self.mode == 'Infer':
178 | return out
--------------------------------------------------------------------------------
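
A forward-pass sketch for DeNet under the same import assumption; the three downsampling blocks reduce resolution by 8x and the final PixelShuffle(8) restores it, collapsing the 64 feature channels to a single output channel.

import torch
from nets.zoo.denet import DeNet         # assumption: run from the repository root

model = DeNet(mode='Train').eval()
x = torch.randn(1, 3, 256, 256)          # sides divisible by 8 (three downsampling blocks)
with torch.no_grad():
    probs = model(x)                     # 'Train' mode applies a sigmoid
print(probs.shape)                       # expected: torch.Size([1, 1, 256, 256])
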
/nets/datagen.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import rasterio
4 | from torch.utils.data import Dataset, DataLoader
5 | from .transform import _check_augs, process_aug_dict
6 | from utils.core import _check_df_load
7 | from utils.io import imread, _check_channel_order
8 | import skimage
9 | import skimage.segmentation
10 | def make_data_generator(framework, config, df, stage='train'):
11 |
12 | if framework.lower() not in ['pytorch', 'torch']:
13 | raise ValueError('{} is not an accepted value for `framework`'.format(
14 | framework))
15 |
16 | # make sure the df is loaded
17 | df = _check_df_load(df)
18 |
19 | if stage == 'train':
20 | augs = config['training_augmentation']
21 | shuffle = config['training_augmentation']['shuffle']
22 | elif stage == 'validate':
23 | augs = config['validation_augmentation']
24 | shuffle = False
25 | try:
26 | num_classes = config['data_specs']['num_classes']
27 | except KeyError:
28 | num_classes = 1
29 |
30 |
31 | if framework in ['torch', 'pytorch']:
32 | dataset = TorchDataset(
33 | df,
34 | augs=augs,
35 | batch_size=config['batch_size'],
36 | label_type=config['data_specs']['label_type'],
37 | is_categorical=config['data_specs']['is_categorical'],
38 | num_classes=num_classes,
39 | dtype=config['data_specs']['dtype'])
40 | # set up workers for DataLoader for pytorch
41 | data_workers = config['data_specs'].get('data_workers')
42 | if data_workers == 1 or data_workers is None:
43 | data_workers = 0 # for DataLoader to run in main process
44 | data_gen = DataLoader(
45 | dataset,
46 | batch_size=config['batch_size'],
47 |             shuffle=shuffle,  # use the stage-specific value (False for the validation stage)
48 | num_workers=data_workers,
49 | drop_last=True)
50 |
51 | return data_gen
52 |
53 |
54 |
55 | class TorchDataset(Dataset):
56 |
57 | def __init__(self, df, augs, batch_size, label_type='mask',
58 | is_categorical=False, num_classes=1, dtype=None):
59 |
60 | super().__init__()
61 |
62 | self.df = df
63 | self.batch_size = batch_size
64 | self.n_batches = int(np.floor(len(self.df)/self.batch_size))
65 | self.aug = _check_augs(augs)
66 | self.is_categorical = is_categorical
67 | self.num_classes = num_classes
68 |
69 | if dtype is None:
70 | self.dtype = np.float32 # default
71 | # if it's a string, get the appropriate object
72 | elif isinstance(dtype, str):
73 | try:
74 | self.dtype = getattr(np, dtype)
75 | except AttributeError:
76 | raise ValueError(
77 | 'The data type {} is not supported'.format(dtype))
78 | # lastly, check if it's already defined in the right format for use
79 | elif issubclass(dtype, np.number) or isinstance(dtype, np.dtype):
80 | self.dtype = dtype
81 |
82 | def __len__(self):
83 | return len(self.df)
84 |
85 | def __getitem__(self, idx):
86 | """Get one image, mask pair"""
87 | # Generate indexes of the batch
88 | image = imread(self.df['image'].iloc[idx])
89 | mask = imread(self.df['label'].iloc[idx])
90 |         boundary = mask  # placeholder; the boundary target is recomputed from the augmented mask below
91 |
92 | if not self.is_categorical:
93 | mask[mask != 0] = 1
94 | if len(mask.shape) == 2:
95 | mask = mask[:, :, np.newaxis]
96 | if len(image.shape) == 2:
97 | image = image[:, :, np.newaxis]
98 |
99 | if len(boundary.shape) == 2:
100 | boundary = boundary[:, :, np.newaxis]
101 |
102 | sample = {'image': image, 'mask': mask, 'boundary' : boundary}
103 |
104 | if self.aug:
105 | sample = self.aug(**sample)
106 |
107 |
108 |
109 | sample['image'] = _check_channel_order(sample['image'],
110 | 'torch').astype(self.dtype)
111 | sample['mask'] = _check_channel_order(sample['mask'],
112 | 'torch').astype(np.float32)
113 |
114 | sample['boundary'] = _check_channel_order(skimage.segmentation.find_boundaries(sample['mask'], mode='inner', background=0),
115 | 'torch').astype(np.float32)
116 |
117 | return sample
118 |
119 |
120 | class InferenceTiler(object):
121 |
122 |
123 | def __init__(self, framework, width, height, x_step=None, y_step=None,
124 | augmentations=None):
125 |
126 | self.framework = framework
127 | self.width = width
128 | self.height = height
129 | if x_step is None:
130 | self.x_step = self.width
131 | else:
132 | self.x_step = x_step
133 | if y_step is None:
134 | self.y_step = self.height
135 | else:
136 | self.y_step = y_step
137 | self.aug = _check_augs(augmentations)
138 |
139 | def __call__(self, im):
140 |
141 | # read in the image if it's a path
142 | if isinstance(im, str):
143 | im = imread(im)
144 |
145 | # determine how many samples will be generated with the sliding window
146 | src_im_height = im.shape[0]
147 | src_im_width = im.shape[1]
148 |
149 |
150 |
151 | y_steps = int(1+np.ceil((src_im_height-self.height)/self.y_step))
152 | x_steps = int(1+np.ceil((src_im_width-self.width)/self.x_step))
153 | if len(im.shape) == 2: # if there's no channel axis
154 | im = im[:, :, np.newaxis] # create one - will be needed for model
155 | top_left_corner_idxs = []
156 | output_arr = []
157 | for y in range(y_steps):
158 | if self.y_step*y + self.height > im.shape[0]:
159 | y_min = im.shape[0] - self.height
160 | else:
161 | y_min = self.y_step*y
162 |
163 | for x in range(x_steps):
164 | if self.x_step*x + self.width > im.shape[1]:
165 | x_min = im.shape[1] - self.width
166 | else:
167 | x_min = self.x_step*x
168 |
169 | subarr = im[y_min:y_min + self.height,
170 | x_min:x_min + self.width,
171 | :]
172 | if self.aug is not None:
173 | subarr = self.aug(image=subarr)['image']
174 | output_arr.append(subarr)
175 | top_left_corner_idxs.append((y_min, x_min))
176 | output_arr = np.stack(output_arr).astype(np.float32)
177 | if self.framework in ['torch', 'pytorch']:
178 | output_arr = np.moveaxis(output_arr, 3, 1)
179 |
180 |
181 | return output_arr, top_left_corner_idxs, (src_im_height, src_im_width)
182 |
--------------------------------------------------------------------------------
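
To make InferenceTiler's edge handling concrete, here is a standalone sketch (pure NumPy, no repo imports) of the same window-origin computation used in __call__, for a hypothetical 650x650 image tiled with 512-pixel windows and a 256-pixel step; origins are clamped so the last window ends exactly at the image border.

import numpy as np

height = width = 512
y_step = x_step = 256
src_h = src_w = 650                      # hypothetical source image size

y_steps = int(1 + np.ceil((src_h - height) / y_step))
x_steps = int(1 + np.ceil((src_w - width) / x_step))

corners = []
for y in range(y_steps):
    y_min = min(y * y_step, src_h - height)   # clamp so the window stays inside the image
    for x in range(x_steps):
        x_min = min(x * x_step, src_w - width)
        corners.append((y_min, x_min))

print(corners)                           # [(0, 0), (0, 138), (138, 0), (138, 138)]
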
/nets/zoo/resunet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResUnetPlusPlus(nn.Module):
7 | def __init__(self, filters=[32, 64, 128, 256, 512], pretrained=False,mode='Train'):
8 | super(ResUnetPlusPlus, self).__init__()
9 | self.mode=mode
10 | self.input_layer = nn.Sequential(
11 | nn.Conv2d(3, filters[0], kernel_size=3, padding=1),
12 | nn.BatchNorm2d(filters[0]),
13 | nn.ReLU(),
14 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
15 | )
16 | self.input_skip = nn.Sequential(
17 | nn.Conv2d(3, filters[0], kernel_size=3, padding=1)
18 | )
19 |
20 | self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])
21 |
22 | self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)
23 |
24 | self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])
25 |
26 | self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)
27 |
28 | self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])
29 |
30 | self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)
31 |
32 | self.aspp_bridge = ASPP(filters[3], filters[4])
33 |
34 | self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
35 | self.upsample1 = Upsample_(2)
36 | self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)
37 |
38 | self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
39 | self.upsample2 = Upsample_(2)
40 | self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)
41 |
42 | self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
43 | self.upsample3 = Upsample_(2)
44 | self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)
45 |
46 | self.aspp_out = ASPP(filters[1], filters[0])
47 |
48 | self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1))
49 |
50 | def forward(self, x):
51 | x1 = self.input_layer(x) + self.input_skip(x)
52 |
53 | x2 = self.squeeze_excite1(x1)
54 | x2 = self.residual_conv1(x2)
55 |
56 | x3 = self.squeeze_excite2(x2)
57 | x3 = self.residual_conv2(x3)
58 |
59 | x4 = self.squeeze_excite3(x3)
60 | x4 = self.residual_conv3(x4)
61 |
62 | x5 = self.aspp_bridge(x4)
63 |
64 | x6 = self.attn1(x3, x5)
65 | x6 = self.upsample1(x6)
66 | x6 = torch.cat([x6, x3], dim=1)
67 | x6 = self.up_residual_conv1(x6)
68 |
69 | x7 = self.attn2(x2, x6)
70 | x7 = self.upsample2(x7)
71 | x7 = torch.cat([x7, x2], dim=1)
72 | x7 = self.up_residual_conv2(x7)
73 |
74 | x8 = self.attn3(x1, x7)
75 | x8 = self.upsample3(x8)
76 | x8 = torch.cat([x8, x1], dim=1)
77 | x8 = self.up_residual_conv3(x8)
78 |
79 | x9 = self.aspp_out(x8)
80 | out = self.output_layer(x9)
81 | if self.mode == 'Train':
82 | return F.sigmoid(out)
83 | elif self.mode == 'Infer':
84 | return out
85 | class ResidualConv(nn.Module):
86 | def __init__(self, input_dim, output_dim, stride, padding):
87 | super(ResidualConv, self).__init__()
88 |
89 | self.conv_block = nn.Sequential(
90 | nn.BatchNorm2d(input_dim),
91 | nn.ReLU(),
92 | nn.Conv2d(
93 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
94 | ),
95 | nn.BatchNorm2d(output_dim),
96 | nn.ReLU(),
97 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
98 | )
99 | self.conv_skip = nn.Sequential(
100 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
101 | nn.BatchNorm2d(output_dim),
102 | )
103 |
104 | def forward(self, x):
105 |
106 | return self.conv_block(x) + self.conv_skip(x)
107 |
108 |
109 | class Upsample(nn.Module):
110 | def __init__(self, input_dim, output_dim, kernel, stride):
111 | super(Upsample, self).__init__()
112 |
113 | self.upsample = nn.ConvTranspose2d(
114 | input_dim, output_dim, kernel_size=kernel, stride=stride
115 | )
116 |
117 | def forward(self, x):
118 | return self.upsample(x)
119 |
120 |
121 | class Squeeze_Excite_Block(nn.Module):
122 | def __init__(self, channel, reduction=16):
123 | super(Squeeze_Excite_Block, self).__init__()
124 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
125 | self.fc = nn.Sequential(
126 | nn.Linear(channel, channel // reduction, bias=False),
127 | nn.ReLU(inplace=True),
128 | nn.Linear(channel // reduction, channel, bias=False),
129 | nn.Sigmoid(),
130 | )
131 |
132 | def forward(self, x):
133 | b, c, _, _ = x.size()
134 | y = self.avg_pool(x).view(b, c)
135 | y = self.fc(y).view(b, c, 1, 1)
136 | return x * y.expand_as(x)
137 |
138 |
139 | class ASPP(nn.Module):
140 | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
141 | super(ASPP, self).__init__()
142 |
143 | self.aspp_block1 = nn.Sequential(
144 | nn.Conv2d(
145 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
146 | ),
147 | nn.ReLU(inplace=True),
148 | nn.BatchNorm2d(out_dims),
149 | )
150 | self.aspp_block2 = nn.Sequential(
151 | nn.Conv2d(
152 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
153 | ),
154 | nn.ReLU(inplace=True),
155 | nn.BatchNorm2d(out_dims),
156 | )
157 | self.aspp_block3 = nn.Sequential(
158 | nn.Conv2d(
159 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
160 | ),
161 | nn.ReLU(inplace=True),
162 | nn.BatchNorm2d(out_dims),
163 | )
164 |
165 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
166 | self._init_weights()
167 |
168 | def forward(self, x):
169 | x1 = self.aspp_block1(x)
170 | x2 = self.aspp_block2(x)
171 | x3 = self.aspp_block3(x)
172 | out = torch.cat([x1, x2, x3], dim=1)
173 | return self.output(out)
174 |
175 | def _init_weights(self):
176 | for m in self.modules():
177 | if isinstance(m, nn.Conv2d):
178 | nn.init.kaiming_normal_(m.weight)
179 | elif isinstance(m, nn.BatchNorm2d):
180 | m.weight.data.fill_(1)
181 | m.bias.data.zero_()
182 |
183 |
184 | class Upsample_(nn.Module):
185 | def __init__(self, scale=2):
186 | super(Upsample_, self).__init__()
187 |
188 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
189 |
190 | def forward(self, x):
191 | return self.upsample(x)
192 |
193 |
194 | class AttentionBlock(nn.Module):
195 | def __init__(self, input_encoder, input_decoder, output_dim):
196 | super(AttentionBlock, self).__init__()
197 |
198 | self.conv_encoder = nn.Sequential(
199 | nn.BatchNorm2d(input_encoder),
200 | nn.ReLU(),
201 | nn.Conv2d(input_encoder, output_dim, 3, padding=1),
202 | nn.MaxPool2d(2, 2),
203 | )
204 |
205 | self.conv_decoder = nn.Sequential(
206 | nn.BatchNorm2d(input_decoder),
207 | nn.ReLU(),
208 | nn.Conv2d(input_decoder, output_dim, 3, padding=1),
209 | )
210 |
211 | self.conv_attn = nn.Sequential(
212 | nn.BatchNorm2d(output_dim),
213 | nn.ReLU(),
214 | nn.Conv2d(output_dim, 1, 1),
215 | )
216 |
217 | def forward(self, x1, x2):
218 | out = self.conv_encoder(x1) + self.conv_decoder(x2)
219 | out = self.conv_attn(out)
220 | return out * x2
--------------------------------------------------------------------------------
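
A forward-pass sketch for ResUnetPlusPlus, assuming the repository root is on PYTHONPATH; the filter widths below are deliberately smaller than the defaults to keep the check fast, and 'Infer' mode returns raw logits without a sigmoid.

import torch
from nets.zoo.resunet import ResUnetPlusPlus   # assumption: run from the repository root

model = ResUnetPlusPlus(filters=[16, 32, 64, 128, 256], mode='Infer').eval()
x = torch.randn(1, 3, 256, 256)          # sides divisible by 8 (three stride-2 residual stages)
with torch.no_grad():
    logits = model(x)
print(logits.shape)                      # expected: torch.Size([1, 1, 256, 256])
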
/nets/zoo/uspp_BE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import skimage
4 | import numpy as np
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 |
8 |
9 |
10 | class _stage_block(nn.Module):
11 | def __init__(self, channel_var):
12 | super(_stage_block, self).__init__()
13 |
14 | channel_in, channel_out = channel_var
15 |
16 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=3, stride=1, padding=1)
17 | self.bn = nn.BatchNorm2d(channel_out)
18 | self.relu = nn.ReLU()
19 |
20 | def forward(self, x):
21 | out = self.bn( self.conv(x) )
22 | out = self.relu(out)
23 | return out
24 |
25 |
26 | class _upss_block(nn.Module):
27 | def __init__(self, channel_in):
28 | super(_upss_block, self).__init__()
29 | self.conv1 = nn.Sequential(
30 | nn.MaxPool2d(1),
31 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=1, stride=1, padding=0),
32 | )
33 | self.conv2 = nn.Sequential(
34 | nn.MaxPool2d(2),
35 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=2, stride=1, padding=1),
36 | )
37 | self.conv3 = nn.Sequential(
38 | nn.MaxPool2d(3),
39 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=3, stride=1, padding=1),
40 | )
41 | self.conv4 = nn.Sequential(
42 | nn.MaxPool2d(6),
43 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=4, stride=1, padding=2),
44 | )
45 |
46 | def forward(self, x):
47 | residual = x
48 |
49 | h, w = x.size(2), x.size(3)
50 |
51 | out1 = self.conv1(x)
52 | out1 = F.upsample(input=out1, size=(h, w), mode='bilinear')
53 | out2 = self.conv2(x)
54 | out2 = F.upsample(input=out2, size=(h, w), mode='bilinear')
55 | out3 = self.conv3(x)
56 | out3 = F.upsample(input=out3, size=(h, w), mode='bilinear')
57 | out4 = self.conv4(x)
58 | out4 = F.upsample(input=out4, size=(h, w), mode='bilinear')
59 |
60 | out = torch.cat([out1, out2, out3, out4, residual], 1)
61 | return out
62 |
63 |
64 | class _down(nn.Module):
65 | def __init__(self, channel_in):
66 | super(_down, self).__init__()
67 | self.maxpool = nn.MaxPool2d(2)
68 |
69 | def forward(self, x):
70 | out = self.maxpool(x)
71 | return out
72 |
73 |
74 | class _up(nn.Module):
75 | def __init__(self, channel_in):
76 | super(_up, self).__init__()
77 |
78 | #self.relu = nn.PReLU()
79 | #self.subpixel = nn.PixelShuffle(2)
80 | self.subpixel = nn.ConvTranspose2d(in_channels=channel_in, out_channels=int(channel_in/2.), kernel_size=2, stride=2)
81 | #self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_in, kernel_size=1, stride=1, padding=0)
82 |
83 | def forward(self, x):
84 | #out = self.relu(self.conv(x))
85 | #out = self.subpixel(out)
86 | out = self.subpixel(x)
87 | return out
88 |
89 |
90 | class Uspp_BE(nn.Module):
91 | def __init__(self, pretrained=False, mode= 'Train'):
92 | super(Uspp_BE, self).__init__()
93 | self.mode = mode
94 | self.DCR_block11 = self.make_layer(_stage_block, [ 3, 64])
95 | self.DCR_block12 = self.make_layer(_stage_block, [ 64, 64])
96 | self.down1 = self.make_layer(_down, 64)
97 | self.DCR_block21 = self.make_layer(_stage_block, [ 64,128])
98 | self.DCR_block22 = self.make_layer(_stage_block, [128,128])
99 | self.down2 = self.make_layer(_down, 128)
100 | self.DCR_block31 = self.make_layer(_stage_block, [128,256])
101 | self.DCR_block32 = self.make_layer(_stage_block, [256,256])
102 | self.down3 = self.make_layer(_down, 256)
103 | self.DCR_block41 = self.make_layer(_stage_block, [256,512])
104 | self.DCR_block42 = self.make_layer(_stage_block, [512,512])
105 | self.down4 = self.make_layer(_down, 512)
106 |
107 | self.uspp = self.make_layer(_upss_block, 512)
108 |
109 | self.up4 = self.make_layer(_up, 1024)
110 | self.DCR_block43 = self.make_layer(_stage_block,[1024,512])
111 | self.DCR_block44 = self.make_layer(_stage_block, [512,512])
112 | self.up3 = self.make_layer(_up, 512)
113 | self.DCR_block33 = self.make_layer(_stage_block, [512,256])
114 | self.DCR_block34 = self.make_layer(_stage_block, [256,256])
115 | self.up2 = self.make_layer(_up, 256)
116 | self.DCR_block23 = self.make_layer(_stage_block, [256,128])
117 | self.DCR_block24 = self.make_layer(_stage_block, [128,128])
118 | self.up1 = self.make_layer(_up, 128)
119 | self.DCR_block13 = self.make_layer(_stage_block, [128, 64])
120 | # self.DCR_block14 = self.make_layer(_stage_block, [ 64, 1])
121 | self.DCR_block14 = self.make_layer(_stage_block, [ 64, 64])
122 | # HED Block
123 | self.dsn1 = nn.Conv2d(64, 1, 1)
124 | self.dsn2 = nn.Conv2d(128, 1, 1)
125 | self.dsn3 = nn.Conv2d(256, 1, 1)
126 | self.dsn4 = nn.Conv2d(512, 1, 1)
127 | self.dsn5 = nn.Conv2d(1024, 1, 1)
128 |
129 | #boundary enhancement part
130 | self.fuse = nn.Sequential(nn.Conv2d(5, 64, 1),nn.ReLU(inplace=True))
131 |
132 | self.SE_mimic = nn.Sequential(
133 | nn.Linear(64, 64, bias=False),
134 | nn.ReLU(inplace=True),
135 | nn.Linear(64, 5, bias=False),
136 | nn.Sigmoid()
137 | )
138 | self.final_boundary = nn.Conv2d(5,2,1)
139 |
140 | self.final_conv = nn.Sequential(
141 | nn.Conv2d(128,64,3, padding=1),
142 | nn.ReLU(inplace=True)
143 | )
144 | self.final_mask = nn.Conv2d(64,2,1)
145 |
146 |
147 |
148 | self.relu = nn.ReLU()
149 | self.out = nn.Conv2d(64,1,1)
150 |
151 |
152 | def make_layer(self, block, channel_in):
153 | layers = []
154 | layers.append(block(channel_in))
155 | return nn.Sequential(*layers)
156 |
157 | def forward(self, x):
158 | residual = x
159 | h = x.size(2)
160 | w = x.size(3)
161 |
162 | out = self.DCR_block11(x)
163 | conc1= self.DCR_block12(out)
164 | out = self.down1(conc1)
165 |
166 | out = self.DCR_block21(out)
167 | conc2= self.DCR_block22(out)
168 | out = self.down2(conc2)
169 |
170 | out = self.DCR_block31(out)
171 | conc3= self.DCR_block32(out)
172 | out = self.down3(conc3)
173 |
174 | out = self.DCR_block41(out)
175 | conc4= self.DCR_block42(out)
176 | out = self.down4(conc4)
177 |
178 | # bridge part
179 | conc5 = self.uspp(out)
180 |
181 | out = self.up4(conc5)
182 | out = torch.cat([conc4, out], 1)
183 | out = self.DCR_block43(out)
184 | out = self.DCR_block44(out)
185 |
186 | out = self.up3(out)
187 | out = torch.cat([conc3, out], 1)
188 | out = self.DCR_block33(out)
189 | out = self.DCR_block34(out)
190 |
191 | out = self.up2(out)
192 | out = torch.cat([conc2, out], 1)
193 | out = self.DCR_block23(out)
194 | out = self.DCR_block24(out)
195 |
196 | out = self.up1(out)
197 | out = torch.cat([conc1, out], 1)
198 | out = self.DCR_block13(out)
199 | out = self.DCR_block14(out)
200 |
201 | d1 = self.dsn1(conc1)
202 | d2 = F.upsample_bilinear(self.dsn2(conc2), size=(h,w))
203 | d3 = F.upsample_bilinear(self.dsn3(conc3), size=(h,w))
204 | d4 = F.upsample_bilinear(self.dsn4(conc4), size=(h,w))
205 | d5 = F.upsample_bilinear(self.dsn5(conc5), size=(h,w))
206 | d1_out = F.sigmoid(d1)
207 | d2_out = F.sigmoid(d2)
208 | d3_out = F.sigmoid(d3)
209 | d4_out = F.sigmoid(d4)
210 | d5_out = F.sigmoid(d5)
211 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out), 1)
212 |
213 | fuse_box = self.fuse(concat)
214 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1))
215 | GAP = GAP.view(-1, 64)
216 | se_like = self.SE_mimic(GAP)
217 | se_like = torch.unsqueeze(se_like, 2)
218 | se_like = torch.unsqueeze(se_like, 3)
219 |
220 | feat_se = concat * se_like.expand_as(concat)
221 | boundary = self.final_boundary(feat_se)
222 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1)
223 | bd_sftmax = F.softmax(boundary, dim=1)
224 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1)
225 |
226 | feat_concat = torch.cat( [out, fuse_box], 1)
227 | feat_concat_conv = self.final_conv(feat_concat)
228 | mask = self.final_mask(feat_concat_conv)
229 | mask_sftmax = F.softmax(mask,dim=1)
230 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1)
231 |
232 | if self.mode == 'Train':
233 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1)
234 | elif self.mode == 'Infer':
235 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1)
236 |
237 |
238 | mask_out = torch.unsqueeze(mask[:,1,:,:],1)
239 | relu = self.relu(mask_out)
240 | scalar = relu.cpu().detach().numpy()
241 | if np.sum(scalar) == 0:
242 | average = 0
243 | else :
244 | average = scalar[np.nonzero(scalar)].mean()
245 | mask_out = mask_out-relu + (average*scalefactor)
246 |
247 | if self.mode == 'Train':
248 | mask_out = F.sigmoid(mask_out)
249 | boundary_out = F.sigmoid(boundary_out)
250 |
251 | return d1_out, d2_out, d3_out, d4_out, d5_out, boundary_out, mask_out
252 | elif self.mode =='Infer':
253 | return mask_out
254 |
255 |
--------------------------------------------------------------------------------
/nets/zoo/brrnet_BE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import skimage
4 | import numpy as np
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 |
8 | def conv_block(in_channels, out_channels):
9 | return nn.Sequential(
10 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
11 | nn.BatchNorm2d(num_features=out_channels),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
14 | nn.BatchNorm2d(num_features=out_channels),
15 | nn.ReLU(inplace=True)
16 | )
17 |
18 |
19 | def up_transpose(in_channels, out_channels):
20 | return nn.Sequential(
21 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
22 | )
23 | class center_block(nn.Module):
24 | def __init__(self, in_channels, out_channels):
25 | super(center_block, self).__init__()
26 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1,dilation=1)
27 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=2,dilation=2)
28 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=4,dilation=4)
29 | self.conv4 = nn.Conv2d(out_channels, out_channels, 3, padding=8,dilation=8)
30 | self.conv5 = nn.Conv2d(out_channels, out_channels, 3, padding=16,dilation=16)
31 | self.conv6 = nn.Conv2d(out_channels, out_channels, 3, padding=32,dilation=32)
32 |
33 | self.bn_1 = nn.BatchNorm2d(num_features=out_channels)
34 | self.bn_2 = nn.BatchNorm2d(num_features=out_channels)
35 | self.bn_3 = nn.BatchNorm2d(num_features=out_channels)
36 | self.bn_4 = nn.BatchNorm2d(num_features=out_channels)
37 | self.bn_5 = nn.BatchNorm2d(num_features=out_channels)
38 | self.bn_6 = nn.BatchNorm2d(num_features=out_channels)
39 | self.relu = nn.ReLU()
40 |
41 |
42 |
43 |     def forward(self, x):  # NOTE: the rrm module and the center block are currently mixed together
44 |
45 |
46 | x1 = self.relu(self.bn_1(self.conv1(x)))
47 |
48 | x2 = self.relu(self.bn_2(self.conv2(x1)))
49 |
50 | x3 = self.relu(self.bn_3(self.conv3(x2)))
51 |
52 | x4 = self.relu(self.bn_4(self.conv4(x3)))
53 |
54 | x5 = self.relu(self.bn_5(self.conv5(x4)))
55 |
56 | x6 = self.relu(self.bn_6(self.conv6(x5)))
57 |
58 |
59 | x = x1+x2+x3+x4+x5+x6
60 |
61 | return x
62 |
63 | class rrm_module(nn.Module):
64 | def __init__(self, in_channels, out_channels):
65 | super(rrm_module,self).__init__()
66 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1,dilation=1)
67 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=2,dilation=2)
68 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=4,dilation=4)
69 | self.conv4 = nn.Conv2d(out_channels, out_channels, 3, padding=8,dilation=8)
70 | self.conv5 = nn.Conv2d(out_channels, out_channels, 3, padding=16,dilation=16)
71 | self.conv6 = nn.Conv2d(out_channels, out_channels, 3, padding=32,dilation=32)
72 |
73 | self.bn_1 = nn.BatchNorm2d(num_features=out_channels)
74 | self.bn_2 = nn.BatchNorm2d(num_features=out_channels)
75 | self.bn_3 = nn.BatchNorm2d(num_features=out_channels)
76 | self.bn_4 = nn.BatchNorm2d(num_features=out_channels)
77 | self.bn_5 = nn.BatchNorm2d(num_features=out_channels)
78 | self.bn_6 = nn.BatchNorm2d(num_features=out_channels)
79 | self.relu = nn.ReLU()
80 |
81 | # self.out = nn.Conv2d(out_channels, 1, 3, padding=1,dilation=1)
82 | # BE mode
83 | self.out = nn.Conv2d(out_channels, 64, 3, padding=1,dilation=1)
84 |
85 | def forward(self,x):
86 | residual = x
87 | x1 = self.relu(self.bn_1(self.conv1(x)))
88 |
89 | x2 = self.relu(self.bn_2(self.conv2(x1)))
90 | x3 = self.relu(self.bn_3(self.conv3(x2)))
91 | x4 = self.relu(self.bn_4(self.conv4(x3)))
92 | x5 = self.relu(self.bn_5(self.conv5(x4)))
93 | x6 = self.relu(self.bn_6(self.conv6(x5)))
94 | x = x1+x2+x3+x4+x5+x6
95 | x = self.out(x)
96 | x = residual + x
97 | output = x
98 | # output = F.sigmoid(x)
99 | return output
100 |
101 | class decoder_block(nn.Module):
102 | def __init__(self, in_channels, out_channels):
103 | super(decoder_block,self).__init__()
104 | self.bn_i = nn.BatchNorm2d(num_features=in_channels)
105 | self.relu = nn.ReLU()
106 | self.conv = conv_block(in_channels, out_channels)
107 | def forward(self, x):
108 |
109 | out = self.bn_i(x)
110 | out = self.relu(out)
111 | out = self.conv(out)
112 | return out
113 |
114 | class BRRNet_BE(nn.Module):
115 |
116 | def __init__(self, n_class=1, pretrained=False, mode= 'Train'):
117 | super().__init__()
118 | self.mode = mode
119 | self.dconv_down1 = conv_block(3, 64)
120 | self.dconv_down2 = conv_block(64, 128)
121 | self.dconv_down3 = conv_block(128, 256)
122 |
123 | self.maxpool = nn.MaxPool2d(2,2)
124 | self.center = center_block(256,512)
125 | self.deconv3 = up_transpose(512,256)
126 | self.deconv2 = up_transpose(256,128)
127 | self.deconv1 = up_transpose(128,64)
128 |
129 | self.decoder_3 = decoder_block(512, 256)
130 | self.decoder_2 = decoder_block(256, 128)
131 | self.decoder_1 = decoder_block(128, 64)
132 | # self.output_1 = nn.Conv2d(64,n_class, 1)
133 | # self.rrm = rrm_module(1,64)
134 | # BE mode
135 | self.output_1 = nn.Conv2d(64,64, 1)
136 | self.rrm = rrm_module(64,64)
137 |
138 | # HED Block
139 | self.dsn1 = nn.Conv2d(64, 1, 1)
140 | self.dsn2 = nn.Conv2d(128, 1, 1)
141 | self.dsn3 = nn.Conv2d(256, 1, 1)
142 | self.dsn4 = nn.Conv2d(512, 1, 1)
143 |
144 |
145 | #boundary enhancement part
146 | self.fuse = nn.Sequential(nn.Conv2d(4, 64, 1),nn.ReLU(inplace=True))
147 |
148 | self.SE_mimic = nn.Sequential(
149 | nn.Linear(64, 64, bias=False),
150 | nn.ReLU(inplace=True),
151 | nn.Linear(64, 4, bias=False),
152 | nn.Sigmoid()
153 | )
154 | self.final_boundary = nn.Conv2d(4,2,1)
155 |
156 | self.final_conv = nn.Sequential(
157 | nn.Conv2d(128,64,3, padding=1),
158 | nn.ReLU(inplace=True)
159 | )
160 | self.final_mask = nn.Conv2d(64,2,1)
161 |
162 |
163 |
164 | self.relu = nn.ReLU()
165 | self.out = nn.Conv2d(64,1,1)
166 |
167 |
168 |
169 |
170 |
171 | def forward(self, x):
172 | h = x.size(2)
173 | w = x.size(3)
174 | conv1 = self.dconv_down1(x)
175 | # print(conv1.shape)
176 | x = self.maxpool(conv1)
177 | # print(x.shape)
178 | conv2 = self.dconv_down2(x)
179 | x = self.maxpool(conv2)
180 |
181 | conv3 = self.dconv_down3(x)
182 | x = self.maxpool(conv3)
183 |
184 | conv4 = self.center(x)
185 |
186 | x = self.deconv3(conv4) # 512 256
187 | x = torch.cat([conv3,x],1) # 256 + 256
188 |
189 | x = self.decoder_3(x) # 512 256
190 |
191 | x = self.deconv2(x)
192 | x = torch.cat([conv2,x],1)
193 | x = self.decoder_2(x)
194 |
195 | x = self.deconv1(x)
196 | x = torch.cat([conv1,x],1)
197 | x = self.decoder_1(x)
198 |
199 | x = self.output_1(x)
200 | out = self.rrm(x)
201 |
202 |
203 | d1 = self.dsn1(conv1)
204 | d2 = F.upsample_bilinear(self.dsn2(conv2), size=(h,w))
205 | d3 = F.upsample_bilinear(self.dsn3(conv3), size=(h,w))
206 | d4 = F.upsample_bilinear(self.dsn4(conv4), size=(h,w))
207 |
208 | d1_out = F.sigmoid(d1)
209 | d2_out = F.sigmoid(d2)
210 | d3_out = F.sigmoid(d3)
211 | d4_out = F.sigmoid(d4)
212 |
213 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out), 1)
214 |
215 | fuse_box = self.fuse(concat)
216 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1))
217 | GAP = GAP.view(-1, 64)
218 | se_like = self.SE_mimic(GAP)
219 | se_like = torch.unsqueeze(se_like, 2)
220 | se_like = torch.unsqueeze(se_like, 3)
221 |
222 | feat_se = concat * se_like.expand_as(concat)
223 | boundary = self.final_boundary(feat_se)
224 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1)
225 | bd_sftmax = F.softmax(boundary, dim=1)
226 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1)
227 |
228 | feat_concat = torch.cat( [out, fuse_box], 1)
229 | feat_concat_conv = self.final_conv(feat_concat)
230 | mask = self.final_mask(feat_concat_conv)
231 | mask_sftmax = F.softmax(mask,dim=1)
232 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1)
233 |
234 | if self.mode == 'Train':
235 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1)
236 | elif self.mode == 'Infer':
237 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1)
238 |
239 |
240 | mask_out = torch.unsqueeze(mask[:,1,:,:],1)
241 | relu = self.relu(mask_out)
242 | scalar = relu.cpu().detach().numpy()
243 | if np.sum(scalar) == 0:
244 | average = 0
245 | else :
246 | average = scalar[np.nonzero(scalar)].mean()
247 | mask_out = mask_out-relu + (average*scalefactor)
248 |
249 | if self.mode == 'Train':
250 | mask_out = F.sigmoid(mask_out)
251 | boundary_out = F.sigmoid(boundary_out)
252 |
253 | return d1_out, d2_out, d3_out, d4_out, boundary_out, mask_out
254 | elif self.mode =='Infer':
255 | return mask_out
256 |
257 |
--------------------------------------------------------------------------------
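
Unlike UNet_BE, this BRRNet variant taps only four side outputs (there is no dsn5), so 'Train' mode returns six tensors. A sketch under the same repository-root import assumption:

import torch
from nets.zoo.brrnet_BE import BRRNet_BE   # assumption: run from the repository root

model = BRRNet_BE(mode='Train').eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    d1, d2, d3, d4, boundary, mask = model(x)   # four HED side outputs + boundary head + mask head
assert mask.shape == (1, 1, 256, 256)
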
/nets/torch_callbacks.py:
--------------------------------------------------------------------------------
1 | """PyTorch Callbacks."""
2 |
3 | import os
4 | import numpy as np
5 |
6 | import torch
7 | import time
8 |
9 | now = time.localtime()
10 |
11 | class TorchEarlyStopping(object):
12 | """Tracks if model training should stop based on rate of improvement.
13 |
14 | Arguments
15 | ---------
16 | patience : int, optional
17 | The number of epochs to wait before stopping the model if the metric
18 | didn't improve. Defaults to 5.
19 | threshold : float, optional
20 | The minimum metric improvement required to count as "improvement".
21 | Defaults to ``0.0`` (any improvement satisfies the requirement).
22 | verbose : bool, optional
23 | Verbose text output. Defaults to off (``False``). _NOTE_ : This
24 | currently does nothing.
25 | """
26 |
27 | def __init__(self, patience=5, threshold=0.0, verbose=False):
28 | self.patience = patience
29 | self.threshold = threshold
30 | self.counter = 0
31 | self.best = None
32 | self.stop = False
33 |
34 | def __call__(self, metric_score):
35 |
36 | if self.best is None:
37 | self.best = metric_score
38 | self.counter = 0
39 | else:
40 | if self.best - self.threshold < metric_score:
41 | self.counter += 1
42 | else:
43 | self.best = metric_score
44 | self.counter = 0
45 |
46 | if self.counter >= self.patience:
47 | self.stop = True
48 |
49 |
50 | class TorchTerminateOnNaN(object):
51 | """Sets a stop condition if the model loss achieves an NaN or inf value.
52 |
53 | Arguments
54 | ---------
55 | patience : int, optional
56 | The number of epochs that must display an NaN loss value before
57 | stopping. Defaults to ``1``.
58 | verbose : bool, optional
59 | Verbose text output. Defaults to off (``False``). _NOTE_ : This
60 | currently does nothing.
61 | """
62 |
63 | def __init__(self, patience=1, verbose=False):
64 | self.patience = patience
65 | self.counter = 0
66 | self.stop = False
67 |
68 | def __call__(self, loss):
69 | if np.isnan(loss) or np.isinf(loss):
70 | self.counter += 1
71 | if self.counter >= self.patience:
72 | self.stop = True
73 | else:
74 | self.counter = 0
75 |
76 |
77 |
78 | class TorchModelCheckpoint(object):
79 | """Save the model at specific points using Keras checkpointing args.
80 |
81 | Arguments
82 | ---------
83 | filepath : str, optional
84 | Path to save the model file to. The end of the path (before the
85 | file extension) will have ``'_[epoch]'`` added to it to ID specific
86 | checkpoints.
87 | monitor : str, optional
88 | The loss value to monitor. Options are
89 | ``['loss', 'val_loss', 'periodic']`` or a metric from the keys in
90 | :const:`solaris.nets.metrics.metric_dict` . Defaults to ``'loss'`` . If
91 | ``'periodic'``, it saves every n epochs (see `period` below).
92 | verbose : bool, optional
93 | Verbose text output. Defaults to ``False`` .
94 | save_best_only : bool, optional
95 | Save only the model with the best value? Defaults to no (``False`` ).
96 | mode : str, optional
97 | One of ``['auto', 'min', 'max']``. Is a better value higher or lower?
98 | Defaults to ``'auto'`` in which case it tries to infer it (if
99 | ``monitor='loss'`` or ``monitor='val_loss'`` , it assumes ``'min'`` ,
100 | if it's a metric it assumes ``'max'`` .) If ``'min'``, it assumes lower
101 | values are better; if ``'max'`` , it assumes higher values are better.
102 | period : int, optional
103 | If using ``monitor='periodic'`` , this saves models every `period`
104 | epochs. Otherwise, it sets the minimum number of epochs between
105 | checkpoints.
106 | """
107 |
108 | def __init__(self, filepath='', path_aoi='',monitor='loss', verbose=False,
109 | save_best_only=False, mode='auto', period=1,
110 | weights_only=True):
111 |
112 | self.filepath = filepath
113 | self.monitor = monitor
114 | self.aoi = path_aoi
115 | if self.monitor not in ['loss', 'val_loss', 'periodic']:
116 | self.monitor = metric_dict[self.monitor]
117 | self.verbose = verbose
118 | self.save_best_only = save_best_only
119 | self.period = period
120 | self.weights_only = weights_only
121 | self.mode = mode
122 | if self.mode == 'auto':
123 | if self.monitor in ['loss', 'val_loss']:
124 | self.mode = 'min'
125 | else:
126 | self.mode = 'max'
127 |
128 | self.epoch = 0
129 | self.last_epoch = 0
130 | self.last_saved_value = None
131 |
132 | def __call__(self, model, file_path, loss_value=None, y_true=None, y_pred=None):
133 | """Run a round of model checkpointing for an epoch.
134 |
135 | Arguments
136 | ---------
137 | model : model object
138 | The model to be saved during checkpoints. Must be a PyTorch model.
139 | loss_value : numeric, optional
140 | The numeric output of the loss function. Only required if using
141 | ``monitor='loss'`` or ``monitor='val_loss'`` .
142 | y_true : :class:`np.array` , optional
143 | The labels for the validation data. Only required if using
144 | a metric as the monitored value.
145 | y_pred : :class:`np.array` , optional
146 | The predicted values from the model. Only required if using
147 | a metric as the monitored value.
148 | """
149 |
150 | self.epoch += 1
151 | if self.monitor == 'periodic': # update based on period
152 | if self.last_epoch + self.period <= self.epoch:
153 | # self.last_saved_value = loss_value if loss_value else 0
154 | self.save(model, file_path,self.weights_only)
155 | self.last_epoch = self.epoch
156 |
157 |
158 | elif self.monitor in ['loss', 'val_loss']:
159 | if self.last_saved_value is None:
160 | self.last_saved_value = loss_value
161 | if self.last_epoch + self.period <= self.epoch:
162 | self.save(model,file_path, self.weights_only)
163 | self.last_epoch = self.epoch
164 | if self.last_epoch + self.period <= self.epoch:
165 | if self.check_is_best_value(loss_value):
166 | self.last_saved_value = loss_value
167 | self.save(model, file_path,self.weights_only)
168 | self.last_epoch = self.epoch
169 |
170 | else:
171 | if self.last_saved_value is None:
172 | self.last_saved_value = self.monitor(y_true, y_pred)
173 | if self.last_epoch + self.period <= self.epoch:
174 | self.save(model,file_path, self.weights_only)
175 | self.last_epoch = self.epoch
176 | if self.last_epoch + self.period <= self.epoch:
177 | metric_value = self.monitor(y_true, y_pred)
178 | if self.check_is_best_value(metric_value):
179 | self.last_saved_value = metric_value
180 | self.save(model, file_path, self.weights_only)
181 | self.last_epoch = self.epoch
182 |
183 | def check_is_best_value(self, value):
184 | """Check if `value` is better than the best stored value."""
185 | if self.mode == 'min' and self.last_saved_value > value:
186 | return True
187 | elif self.mode == 'max' and self.last_saved_value < value:
188 | return True
189 | else:
190 | return False
191 |
192 | def save(self, model, file_path, weights_only):
193 | """Save the model.
194 |
195 | Arguments
196 | ---------
197 | model : :class:`torch.nn.Module`
198 | A PyTorch model instance to save.
199 | weights_only : bool, optional
200 |             If ``True`` (the default), only the model weights (the
201 |             ``state_dict``) are saved; if ``False``, the entire model is saved.
202 |             The entire model must be saved to resume training without re-defining
203 |             the model architecture, optimizer, and loss function.
204 | """
205 | # print("saved time : %04d/%02d/%02d %02d:%02d:%02d"% (now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec))
206 |
207 | # print(self.aoi)
208 | # print("aoi")
209 | lossvalue=np.round(self.last_saved_value,3)
210 | # save_path = self.filepath + self.aoi + '_' + str(now.tm_mon) + str(now.tm_mday)+str(now.tm_hour)+str(now.tm_min)+ '/'
211 | save_path = file_path
212 | save_name = save_path + 'best'+ '_epoch{}_{}'.format(self.epoch, str(lossvalue))
213 | #save_name = save_path + 'best'+ '_epoch_{}_{}'.format( str(self.epoch).zfill(3), str(lossvalue))
214 |
215 | # save_name = self.filepath + self.aoi + '_' + str(now.tm_mon) + str(now.tm_mday)+ '/' + 'best'+ '_epoch{}_{}'.format(
216 | # self.epoch, np.round(self.last_saved_value, 1))
217 | # save_name = os.path.splitext(self.filepath)[0] + self.aoi + + '_epoch{}_{}'.format(
218 | # self.epoch, np.round(self.last_saved_value, 3))
219 |
220 | save_name = save_name + '.pth'
221 | print("saved path : ", save_path)
222 | print()
223 | print()
224 | if not os.path.exists(save_path) :
225 | os.makedirs(save_path)
226 | else :
227 | pass
228 | # os.makedirs(save_path)
229 | if isinstance(model, torch.nn.DataParallel):
230 | to_save = model.module
231 | else:
232 | to_save = model
233 | if weights_only:
234 | # os.makedirs(save_path)
235 | # torch.save(save_path, save_name)
236 | torch.save(to_save.state_dict(), save_name)
237 |
238 | else:
239 | torch.save(to_save, save_name)
240 |
241 |
242 | torch_callback_dict = {
243 | "early_stopping": TorchEarlyStopping,
244 | "model_checkpoint": TorchModelCheckpoint,
245 | "terminate_on_nan": TorchTerminateOnNaN,
246 | }
247 |
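A sketch of how these callbacks might be wired into a training loop. The model and the per-epoch loss values below are stand-ins; real code would compute val_loss from a validation pass. Note that TorchModelCheckpoint concatenates file_path with the file name, so a trailing slash is expected.

    import torch

    model = torch.nn.Linear(4, 2)                      # stand-in model
    early_stop = TorchEarlyStopping(patience=2)
    nan_guard = TorchTerminateOnNaN()
    checkpoint = TorchModelCheckpoint(monitor='val_loss', mode='min', period=1)

    for epoch, val_loss in enumerate([0.90, 0.70, 0.72, 0.71, 0.73], start=1):
        checkpoint(model, file_path='checkpoints/', loss_value=val_loss)  # saves state_dict when val_loss improves
        nan_guard(val_loss)
        early_stop(val_loss)
        if early_stop.stop or nan_guard.stop:
            break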
--------------------------------------------------------------------------------
/nets/zoo/ternaus_BE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torchvision import models
6 | import numpy as np
7 |
8 | def conv3x3(in_: int, out: int) -> nn.Module:
9 | return nn.Conv2d(in_, out, 3, padding=1)
10 |
11 |
12 | class ConvRelu(nn.Module):
13 | def __init__(self, in_: int, out: int) -> None:
14 | super().__init__()
15 | self.conv = conv3x3(in_, out)
16 | self.activation = nn.ReLU(inplace=True)
17 |
18 | def forward(self, x: torch.Tensor) -> torch.Tensor:
19 | x = self.conv(x)
20 | x = self.activation(x)
21 | return x
22 |
23 |
24 | class DecoderBlock(nn.Module):
25 | def __init__(
26 | self, in_channels: int, middle_channels: int, out_channels: int
27 | ) -> None:
28 | super().__init__()
29 |
30 | self.block = nn.Sequential(
31 | ConvRelu(in_channels, middle_channels),
32 | nn.ConvTranspose2d(
33 | middle_channels,
34 | out_channels,
35 | kernel_size=3,
36 | stride=2,
37 | padding=1,
38 | output_padding=1,
39 | ),
40 | nn.ReLU(inplace=True),
41 | )
42 |
43 | def forward(self, x: torch.Tensor) -> torch.Tensor:
44 | return self.block(x)
45 |
46 |
47 | class ternaus_BE(nn.Module):
48 |     def __init__(self, num_filters: int = 32, pretrained: bool = False, mode: str = 'Train') -> None:
49 |         """
50 |         Args:
51 |             num_filters: base number of channels used by the decoder blocks
52 |             pretrained: if True, the VGG11 encoder uses pre-trained weights;
53 |                 if False, no pre-trained network is used
54 |             mode: 'Train' returns side, boundary and mask outputs; 'Infer' returns the mask only
55 |         """
56 |         super().__init__()
57 |         self.mode = mode
58 |         self.pool = nn.MaxPool2d(2, 2)
59 | self.encoder = models.vgg11(pretrained=pretrained).features
60 |
61 | self.relu = self.encoder[1]
62 | self.conv1 = self.encoder[0]
63 | self.conv2 = self.encoder[3]
64 | self.conv3s = self.encoder[6]
65 | self.conv3 = self.encoder[8]
66 | self.conv4s = self.encoder[11]
67 | self.conv4 = self.encoder[13]
68 | self.conv5s = self.encoder[16]
69 | self.conv5 = self.encoder[18]
70 | self.conv6 = ConvRelu(num_filters * 8 * 2, num_filters * 8 * 2)
71 | self.decoder6 = nn.ConvTranspose2d(num_filters * 8 * 2,
72 | num_filters * 8,
73 | kernel_size=3,
74 | stride=2,
75 | padding=1,
76 | output_padding=1,)
77 |
78 | self.center = DecoderBlock(
79 | num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8
80 | )
81 | self.dec5 = DecoderBlock(
82 | num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8
83 | )
84 | self.dec4 = DecoderBlock(
85 | num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4
86 | )
87 | self.dec3 = DecoderBlock(
88 | num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2
89 | )
90 | self.dec2 = DecoderBlock(
91 | num_filters * (4 + 2), num_filters * 2 * 2, num_filters
92 | )
93 | self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)
94 |
95 | self.final = nn.Conv2d(num_filters, 1, kernel_size=1)
96 |
97 | # HED Block
98 | self.dsn1 = nn.Conv2d(num_filters*2, 1, 1)
99 | self.dsn2 = nn.Conv2d(num_filters*4, 1, 1)
100 | self.dsn3 = nn.Conv2d(num_filters*8, 1, 1)
101 | self.dsn4 = nn.Conv2d(num_filters*16, 1, 1)
102 | self.dsn5 = nn.Conv2d(num_filters*16, 1, 1)
103 | self.dsn6 = nn.Conv2d(num_filters*8, 1, 1)
104 | self.fuse = nn.Sequential(nn.Conv2d(6, 32, 1),nn.ReLU(inplace=True))
105 | # self.fuse = nn.Conv2d(5, 64, 1)
106 |
107 | self.SE_mimic = nn.Sequential(
108 | nn.Linear(32, 32, bias=False),
109 | nn.ReLU(inplace=True),
110 | nn.Linear(32, 6, bias=False),
111 | nn.Sigmoid()
112 | )
113 | self.final_boundary = nn.Conv2d(6,2,1)
114 |
115 | self.final_conv = nn.Sequential(
116 | nn.Conv2d(64,64,3, padding=1),
117 | nn.ReLU(inplace=True)
118 | )
119 | self.final_mask = nn.Conv2d(64,2,1)
120 |
121 |
122 |
123 | self.relu = nn.ReLU()
124 | self.out = nn.Conv2d(64,1,1)
125 |
126 |
127 | def forward(self, x: torch.Tensor) -> torch.Tensor:
128 | h = x.size(2)
129 | w = x.size(3)
130 | conv1 = self.relu(self.conv1(x))
131 | conv1p = self.pool(conv1)
132 | conv2 = self.relu(self.conv2(conv1p))
133 | conv2p = self.pool(conv2)
134 | conv3s = self.relu(self.conv3s(conv2p))
135 | conv3 = self.relu(self.conv3(conv3s))
136 | conv3p = self.pool(conv3)
137 | conv4s = self.relu(self.conv4s(conv3p))
138 | conv4 = self.relu(self.conv4(conv4s))
139 | conv4p = self.pool(conv4)
140 | conv5s = self.relu(self.conv5s(conv4p))
141 | conv5 = self.relu(self.conv5(conv5s))
142 | conv5p = self.pool(conv5)
143 |
144 | # center = self.center(conv5p)
145 | conv6s = self.conv6(conv5p)
146 | conv6 = self.relu(self.decoder6(conv6s))
147 | center = conv6
148 | dec5 = self.dec5(torch.cat([center, conv5], 1))
149 | dec4 = self.dec4(torch.cat([dec5, conv4], 1))
150 | dec3 = self.dec3(torch.cat([dec4, conv3], 1))
151 | dec2 = self.dec2(torch.cat([dec3, conv2], 1))
152 | dec1 = self.dec1(torch.cat([dec2, conv1], 1))
153 | xx = dec1
154 | # xx = self.final(dec1)
155 | # out = F.sigmoid(out)
156 |
157 | ## side output
158 | d1 = self.dsn1(conv1)
159 | d2 = F.upsample_bilinear(self.dsn2(conv2), size=(h,w))
160 | d3 = F.upsample_bilinear(self.dsn3(conv3), size=(h,w))
161 | d4 = F.upsample_bilinear(self.dsn4(conv4), size=(h,w))
162 | d5 = F.upsample_bilinear(self.dsn5(conv5), size=(h,w))
163 |
164 | d6 = F.upsample_bilinear(self.dsn6(conv6), size=(h,w))
165 | #
166 | ###########sigmoid ver
167 | d1_out = F.sigmoid(d1)
168 | d2_out = F.sigmoid(d2)
169 | d3_out = F.sigmoid(d3)
170 | d4_out = F.sigmoid(d4)
171 | d5_out = F.sigmoid(d5)
172 | d6_out = F.sigmoid(d6)
173 |
174 | # concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out), 1)
175 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out,d6_out ), 1)
176 |
177 | fuse_box = self.fuse(concat)
178 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1))
179 | GAP = GAP.view(-1, 32)
180 | se_like = self.SE_mimic(GAP)
181 | se_like = torch.unsqueeze(se_like, 2)
182 | se_like = torch.unsqueeze(se_like, 3)
183 |
184 | feat_se = concat * se_like.expand_as(concat)
185 | boundary = self.final_boundary(feat_se)
186 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1)
187 | bd_sftmax = F.softmax(boundary, dim=1)
188 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1)
189 |
190 | feat_concat = torch.cat( [xx, fuse_box], 1)
191 | feat_concat_conv = self.final_conv(feat_concat)
192 | mask = self.final_mask(feat_concat_conv)
193 | mask_sftmax = F.softmax(mask,dim=1)
194 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1)
195 |
196 | if self.mode == 'Train':
197 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1)
198 | elif self.mode == 'Infer':
199 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1)
200 |
201 |
202 | mask_out = torch.unsqueeze(mask[:,1,:,:],1)
203 | relu = self.relu(mask_out)
204 | scalar = relu.cpu().detach().numpy()
205 | if np.sum(scalar) == 0:
206 | average = 0
207 | else :
208 | average = scalar[np.nonzero(scalar)].mean()
209 | mask_out = mask_out-relu + (average*scalefactor)
210 |
211 | if self.mode == 'Train':
212 | mask_out = F.sigmoid(mask_out)
213 | boundary_out = F.sigmoid(boundary_out)
214 |
215 | return d1_out, d2_out, d3_out, d4_out, d5_out, d6_out, boundary_out, mask_out
216 | elif self.mode =='Infer':
217 | return mask_out
218 |
219 |
220 |
221 |
222 | class Interpolate(nn.Module):
223 | def __init__(
224 | self,
225 | size: int = None,
226 | scale_factor: int = None,
227 | mode: str = "nearest",
228 | align_corners: bool = False,
229 | ):
230 | super().__init__()
231 | self.interp = nn.functional.interpolate
232 | self.size = size
233 | self.mode = mode
234 | self.scale_factor = scale_factor
235 | self.align_corners = align_corners
236 |
237 | def forward(self, x: torch.Tensor) -> torch.Tensor:
238 | x = self.interp(
239 | x,
240 | size=self.size,
241 | scale_factor=self.scale_factor,
242 | mode=self.mode,
243 | align_corners=self.align_corners,
244 | )
245 | return x
246 |
247 |
248 | class DecoderBlockV2(nn.Module):
249 | def __init__(
250 | self,
251 | in_channels: int,
252 | middle_channels: int,
253 | out_channels: int,
254 | is_deconv: bool = True,
255 | ):
256 | super().__init__()
257 | self.in_channels = in_channels
258 |
259 | if is_deconv:
260 | """
261 |             Parameters for the deconvolution were chosen to avoid checkerboard
262 |             artifacts, following https://distill.pub/2016/deconv-checkerboard/
263 | """
264 |
265 | self.block = nn.Sequential(
266 | ConvRelu(in_channels, middle_channels),
267 | nn.ConvTranspose2d(
268 | middle_channels, out_channels, kernel_size=4, stride=2, padding=1
269 | ),
270 | nn.ReLU(inplace=True),
271 | )
272 | else:
273 | self.block = nn.Sequential(
274 | Interpolate(scale_factor=2, mode="bilinear"),
275 | ConvRelu(in_channels, middle_channels),
276 | ConvRelu(middle_channels, out_channels),
277 | )
278 |
279 | def forward(self, x: torch.Tensor) -> torch.Tensor:
280 | return self.block(x)
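A quick shape check for the boundary-enhanced TernausNet defined above (random weights, pretrained=False); the 256x256 input is chosen so the five pooling stages divide evenly. In 'Train' mode the module returns six side outputs plus the boundary and mask maps.

    import torch

    model = ternaus_BE(num_filters=32, pretrained=False, mode='Train')
    x = torch.randn(1, 3, 256, 256)
    outputs = model(x)
    print(len(outputs))        # 8: d1..d6, boundary_out, mask_out
    print(outputs[-1].shape)   # torch.Size([1, 1, 256, 256])  (mask output)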
--------------------------------------------------------------------------------
/nets/zoo/resunet_BE.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class ResUnetPlusPlus_BE(nn.Module):
7 | def __init__(self, filters=[32, 64, 128, 256, 512], pretrained=False, mode = 'Train'):
8 | super(ResUnetPlusPlus_BE, self).__init__()
9 | self.mode = mode
10 | self.input_layer = nn.Sequential(
11 | nn.Conv2d(3, filters[0], kernel_size=3, padding=1),
12 | nn.BatchNorm2d(filters[0]),
13 | nn.ReLU(),
14 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
15 | )
16 | self.input_skip = nn.Sequential(
17 | nn.Conv2d(3, filters[0], kernel_size=3, padding=1)
18 | )
19 |
20 | self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])
21 |
22 | self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)
23 |
24 | self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])
25 |
26 | self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)
27 |
28 | self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])
29 |
30 | self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)
31 |
32 | self.aspp_bridge = ASPP(filters[3], filters[4])
33 |
34 | self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
35 | self.upsample1 = Upsample_(2)
36 | self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)
37 |
38 | self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
39 | self.upsample2 = Upsample_(2)
40 | self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)
41 |
42 | self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
43 | self.upsample3 = Upsample_(2)
44 | self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)
45 |
46 | self.aspp_out = ASPP(filters[1], filters[0])
47 |
48 | self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1))
49 |
50 | # HED Block
51 | self.dsn1 = nn.Conv2d(filters[0], 1, 1)
52 | self.dsn2 = nn.Conv2d(filters[1], 1, 1)
53 | self.dsn3 = nn.Conv2d(filters[2], 1, 1)
54 | self.dsn4 = nn.Conv2d(filters[3], 1, 1)
55 | self.dsn5 = nn.Conv2d(filters[4], 1, 1)
56 | self.fuse = nn.Sequential(nn.Conv2d(5, 32, 1),nn.ReLU(inplace=True))
57 | # self.fuse = nn.Conv2d(5, 64, 1)
58 |
59 | self.SE_mimic = nn.Sequential(
60 | nn.Linear(32, 32, bias=False),
61 | nn.ReLU(inplace=True),
62 | nn.Linear(32, 5, bias=False),
63 | nn.Sigmoid()
64 | )
65 | self.final_boundary = nn.Conv2d(5,2,1)
66 |
67 | self.final_conv = nn.Sequential(
68 | nn.Conv2d(64,64,3, padding=1),
69 | nn.ReLU(inplace=True)
70 | )
71 | self.final_mask = nn.Conv2d(64,2,1)
72 |
73 |
74 |
75 | self.relu = nn.ReLU()
76 | self.out = nn.Conv2d(64,1,1)
77 |
78 |
79 |
80 | def forward(self, x):
81 | h = x.size(2)
82 | w = x.size(3)
83 | x1 = self.input_layer(x) + self.input_skip(x)
84 |
85 | x2 = self.squeeze_excite1(x1)
86 | x2 = self.residual_conv1(x2)
87 |
88 | x3 = self.squeeze_excite2(x2)
89 | x3 = self.residual_conv2(x3)
90 |
91 | x4 = self.squeeze_excite3(x3)
92 | x4 = self.residual_conv3(x4)
93 |
94 | x5 = self.aspp_bridge(x4)
95 |
96 | x6 = self.attn1(x3, x5)
97 | x6 = self.upsample1(x6)
98 | x6 = torch.cat([x6, x3], dim=1)
99 | x6 = self.up_residual_conv1(x6)
100 |
101 | x7 = self.attn2(x2, x6)
102 | x7 = self.upsample2(x7)
103 | x7 = torch.cat([x7, x2], dim=1)
104 | x7 = self.up_residual_conv2(x7)
105 |
106 | x8 = self.attn3(x1, x7)
107 | x8 = self.upsample3(x8)
108 | x8 = torch.cat([x8, x1], dim=1)
109 | x8 = self.up_residual_conv3(x8)
110 |
111 | xx = self.aspp_out(x8)
112 | # out = self.output_layer(x9)
113 | # out = F.sigmoid(out)
114 |
115 | ## side output
116 | d1 = self.dsn1(x1)
117 | d2 = F.upsample_bilinear(self.dsn2(x2), size=(h,w))
118 | d3 = F.upsample_bilinear(self.dsn3(x3), size=(h,w))
119 | d4 = F.upsample_bilinear(self.dsn4(x4), size=(h,w))
120 | d5 = F.upsample_bilinear(self.dsn5(x5), size=(h,w))
121 | #
122 | ###########sigmoid ver
123 | d1_out = F.sigmoid(d1)
124 | d2_out = F.sigmoid(d2)
125 | d3_out = F.sigmoid(d3)
126 | d4_out = F.sigmoid(d4)
127 | d5_out = F.sigmoid(d5)
128 |
129 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out), 1)
130 |
131 | fuse_box = self.fuse(concat)
132 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1))
133 | GAP = GAP.view(-1, 32)
134 | se_like = self.SE_mimic(GAP)
135 | se_like = torch.unsqueeze(se_like, 2)
136 | se_like = torch.unsqueeze(se_like, 3)
137 |
138 | feat_se = concat * se_like.expand_as(concat)
139 | boundary = self.final_boundary(feat_se)
140 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1)
141 | bd_sftmax = F.softmax(boundary, dim=1)
142 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1)
143 |
144 | feat_concat = torch.cat( [xx, fuse_box], 1)
145 | feat_concat_conv = self.final_conv(feat_concat)
146 | mask = self.final_mask(feat_concat_conv)
147 | mask_sftmax = F.softmax(mask,dim=1)
148 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1)
149 |
150 | if self.mode == 'Train':
151 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1)
152 | elif self.mode == 'Infer':
153 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1)
154 |
155 |
156 | mask_out = torch.unsqueeze(mask[:,1,:,:],1)
157 | relu = self.relu(mask_out)
158 | scalar = relu.cpu().detach().numpy()
159 | if np.sum(scalar) == 0:
160 | average = 0
161 | else :
162 | average = scalar[np.nonzero(scalar)].mean()
163 | mask_out = mask_out-relu + (average*scalefactor)
164 |
165 | if self.mode == 'Train':
166 | mask_out = F.sigmoid(mask_out)
167 | boundary_out = F.sigmoid(boundary_out)
168 |
169 | return d1_out, d2_out, d3_out, d4_out, d5_out, boundary_out, mask_out
170 | elif self.mode =='Infer':
171 | return mask_out
172 |
173 |
174 |
175 | class ResidualConv(nn.Module):
176 | def __init__(self, input_dim, output_dim, stride, padding):
177 | super(ResidualConv, self).__init__()
178 |
179 | self.conv_block = nn.Sequential(
180 | nn.BatchNorm2d(input_dim),
181 | nn.ReLU(),
182 | nn.Conv2d(
183 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
184 | ),
185 | nn.BatchNorm2d(output_dim),
186 | nn.ReLU(),
187 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
188 | )
189 | self.conv_skip = nn.Sequential(
190 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
191 | nn.BatchNorm2d(output_dim),
192 | )
193 |
194 | def forward(self, x):
195 |
196 | return self.conv_block(x) + self.conv_skip(x)
197 |
198 |
199 | class Upsample(nn.Module):
200 | def __init__(self, input_dim, output_dim, kernel, stride):
201 | super(Upsample, self).__init__()
202 |
203 | self.upsample = nn.ConvTranspose2d(
204 | input_dim, output_dim, kernel_size=kernel, stride=stride
205 | )
206 |
207 | def forward(self, x):
208 | return self.upsample(x)
209 |
210 |
211 | class Squeeze_Excite_Block(nn.Module):
212 | def __init__(self, channel, reduction=16):
213 | super(Squeeze_Excite_Block, self).__init__()
214 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
215 | self.fc = nn.Sequential(
216 | nn.Linear(channel, channel // reduction, bias=False),
217 | nn.ReLU(inplace=True),
218 | nn.Linear(channel // reduction, channel, bias=False),
219 | nn.Sigmoid(),
220 | )
221 |
222 | def forward(self, x):
223 | b, c, _, _ = x.size()
224 | y = self.avg_pool(x).view(b, c)
225 | y = self.fc(y).view(b, c, 1, 1)
226 | return x * y.expand_as(x)
227 |
228 |
229 | class ASPP(nn.Module):
230 | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
231 | super(ASPP, self).__init__()
232 |
233 | self.aspp_block1 = nn.Sequential(
234 | nn.Conv2d(
235 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
236 | ),
237 | nn.ReLU(inplace=True),
238 | nn.BatchNorm2d(out_dims),
239 | )
240 | self.aspp_block2 = nn.Sequential(
241 | nn.Conv2d(
242 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
243 | ),
244 | nn.ReLU(inplace=True),
245 | nn.BatchNorm2d(out_dims),
246 | )
247 | self.aspp_block3 = nn.Sequential(
248 | nn.Conv2d(
249 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
250 | ),
251 | nn.ReLU(inplace=True),
252 | nn.BatchNorm2d(out_dims),
253 | )
254 |
255 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
256 | self._init_weights()
257 |
258 | def forward(self, x):
259 | x1 = self.aspp_block1(x)
260 | x2 = self.aspp_block2(x)
261 | x3 = self.aspp_block3(x)
262 | out = torch.cat([x1, x2, x3], dim=1)
263 | return self.output(out)
264 |
265 | def _init_weights(self):
266 | for m in self.modules():
267 | if isinstance(m, nn.Conv2d):
268 | nn.init.kaiming_normal_(m.weight)
269 | elif isinstance(m, nn.BatchNorm2d):
270 | m.weight.data.fill_(1)
271 | m.bias.data.zero_()
272 |
273 |
274 | class Upsample_(nn.Module):
275 | def __init__(self, scale=2):
276 | super(Upsample_, self).__init__()
277 |
278 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
279 |
280 | def forward(self, x):
281 | return self.upsample(x)
282 |
283 |
284 | class AttentionBlock(nn.Module):
285 | def __init__(self, input_encoder, input_decoder, output_dim):
286 | super(AttentionBlock, self).__init__()
287 |
288 | self.conv_encoder = nn.Sequential(
289 | nn.BatchNorm2d(input_encoder),
290 | nn.ReLU(),
291 | nn.Conv2d(input_encoder, output_dim, 3, padding=1),
292 | nn.MaxPool2d(2, 2),
293 | )
294 |
295 | self.conv_decoder = nn.Sequential(
296 | nn.BatchNorm2d(input_decoder),
297 | nn.ReLU(),
298 | nn.Conv2d(input_decoder, output_dim, 3, padding=1),
299 | )
300 |
301 | self.conv_attn = nn.Sequential(
302 | nn.BatchNorm2d(output_dim),
303 | nn.ReLU(),
304 | nn.Conv2d(output_dim, 1, 1),
305 | )
306 |
307 | def forward(self, x1, x2):
308 | out = self.conv_encoder(x1) + self.conv_decoder(x2)
309 | out = self.conv_attn(out)
310 | return out * x2
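The Squeeze_Excite_Block above is the standard SE gate (global average pool, bottleneck MLP, sigmoid), and its (N, C) -> (N, C, 1, 1) gating mirrors the SE_mimic branch in the BE head. A small check of its behavior:

    import torch

    se = Squeeze_Excite_Block(channel=32, reduction=16)
    feat = torch.randn(2, 32, 64, 64)
    gated = se(feat)
    print(gated.shape)   # torch.Size([2, 32, 64, 64]) -- same shape, channels re-weighted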
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/utils/io.py:
--------------------------------------------------------------------------------
1 | """Utility functions for data io."""
2 | import numpy as np
3 | import skimage.io
4 |
5 |
6 | def imread(path, make_8bit=False, rescale=False,
7 | rescale_min='auto', rescale_max='auto'):
8 | """Read in an image file and rescale pixel values (if applicable).
9 |
10 | Note
11 | ----
12 | Because overhead imagery is often either 16-bit or multispectral (i.e. >3
13 | channels or bands that don't directly translate into the RGB scheme of
14 |     photographs), this package uses scikit-image_ ``io`` algorithms. Though
15 | slightly slower, these algorithms are compatible with any bit depth or
16 | channel count.
17 |
18 | .. _scikit-image: https://scikit-image.org
19 |
20 | Arguments
21 | ---------
22 | path : str
23 | Path to the image file to load.
24 | make_8bit : bool, optional
25 | Should the image be converted to an 8-bit format? Defaults to False.
26 | rescale : bool, optional
27 | Should pixel intensities be rescaled? Defaults to no (False).
28 | rescale_min : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
29 | The minimum pixel value(s) for rescaling. If ``rescale=True`` but no
30 | value is provided for `rescale_min`, the minimum pixel intensity in
31 | each channel of the image will be subtracted such that the minimum
32 | value becomes zero. If a single number is provided, that number will be
33 | subtracted from each channel. If a list of values is provided that is
34 | the same length as the number of channels, then those values will be
35 | subtracted from the corresponding channels.
36 | rescale_max : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
37 | The max pixel value(s) for rescaling. If ``rescale=True`` but no
38 | value is provided for `rescale_max`, each channel will be rescaled such
39 | that the maximum value in the channel is set to the bit range's max.
40 | If a single number is provided, that number will be set as the upper
41 | limit for all channels. If a list of values is provided that is the
42 | same length as the number of channels, then those values will be
43 | set to the maximum value in the corresponding channels.
44 |
45 | Returns
46 | -------
47 | im : :func:`numpy.array`
48 |         A NumPy array of shape ``[Y, X, C]`` containing the imagery. The dtype
49 |         is ``uint8`` if ``make_8bit=True``; otherwise the source dtype is retained.
50 |
51 | """
52 | im_arr = skimage.io.imread(path)
53 | # check dtype for preprocessing
54 | if im_arr.dtype == np.uint8:
55 | dtype = 'uint8'
56 | elif im_arr.dtype == np.uint16:
57 | dtype = 'uint16'
58 | elif im_arr.dtype in [np.float16, np.float32, np.float64]:
59 | if np.amax(im_arr) <= 1 and np.amin(im_arr) >= 0:
60 | dtype = 'zero-one normalized' # range = 0-1
61 | elif np.amax(im_arr) > 0 and np.amin(im_arr) < 0:
62 | dtype = 'z-scored'
63 | elif np.amax(im_arr) <= 255:
64 | dtype = '255 float'
65 | elif np.amax(im_arr) <= 65535:
66 | dtype = '65535 float'
67 | else:
68 | raise TypeError('The loaded image array is an unexpected dtype.')
69 | else:
70 | raise TypeError('The loaded image array is an unexpected dtype.')
71 | if make_8bit:
72 | im_arr = preprocess_im_arr(im_arr, dtype, rescale=rescale,
73 | rescale_min=rescale_min,
74 | rescale_max=rescale_max)
75 | return im_arr
76 |
77 |
78 | def preprocess_im_arr(im_arr, im_format, rescale=False,
79 | rescale_min='auto', rescale_max='auto'):
80 | """Convert image to standard shape and dtype for use in the pipeline.
81 |
82 | Notes
83 | -----
84 | This repo will require the following of images:
85 |
86 | - Their shape is of form [X, Y, C]
87 | - Input images are dtype ``uint8``
88 |
89 | This function will take an image array `im_arr` and reshape it accordingly.
90 |
91 | Arguments
92 | ---------
93 | im_arr : :func:`numpy.array`
94 | A numpy array representation of an image. `im_arr` should have either
95 | two or three dimensions.
96 | im_format : str
97 | One of ``'uint8'``, ``'uint16'``, ``'z-scored'``,
98 | ``'zero-one normalized'``, ``'255 float'``, or ``'65535 float'``.
99 | String indicating the dtype of the input, which will dictate the
100 | preprocessing applied.
101 | rescale : bool, optional
102 | Should pixel intensities be rescaled? Defaults to no (False).
103 | rescale_min : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
104 | The minimum pixel value(s) for rescaling. If ``rescale=True`` but no
105 | value is provided for `rescale_min`, the minimum pixel intensity in
106 | each channel of the image will be subtracted such that the minimum
107 | value becomes zero. If a single number is provided, that number will be
108 | subtracted from each channel. If a list of values is provided that is
109 | the same length as the number of channels, then those values will be
110 | subtracted from the corresponding channels.
111 | rescale_max : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
112 | The max pixel value(s) for rescaling. If ``rescale=True`` but no
113 | value is provided for `rescale_max`, each channel will be rescaled such
114 | that the maximum value in the channel is set to the bit range's max.
115 | If a single number is provided, that number will be set as the upper
116 | limit for all channels. If a list of values is provided that is the
117 | same length as the number of channels, then those values will be
118 | set to the maximum value in the corresponding channels.
119 |
120 | Returns
121 | -------
122 | A :func:`numpy.array` with shape ``[X, Y, C]`` and dtype ``uint8``.
123 |
124 | """
125 | # get [Y, X, C] axis order set up
126 | if im_arr.ndim not in [2, 3]:
127 |         raise ValueError('This package can only read two-dimensional '
128 |                          'image data with an optional channel dimension.')
129 | if im_arr.ndim == 2:
130 | im_arr = im_arr[:, :, np.newaxis]
131 | if im_arr.shape[0] < im_arr.shape[2]: # if the channel axis comes first
132 |         im_arr = np.moveaxis(im_arr, 0, -1)  # move 0th axis to last position
133 |
134 | # rescale images (if applicable)
135 | if rescale:
136 | im_arr = rescale_arr(im_arr, im_format, rescale_min, rescale_max)
137 |
138 | if im_format == 'uint8':
139 | return im_arr.astype('uint8') # just to be sure
140 | elif im_format == 'uint16':
141 | im_arr = (im_arr.astype('float64')*255./65535.).astype('uint8')
142 | elif im_format == 'z-scored':
143 | im_arr = ((im_arr+1)*177.5).astype('uint8')
144 | elif im_format == 'zero-one normalized':
145 | im_arr = (im_arr*255).astype('uint8')
146 | elif im_format == '255 float':
147 | im_arr = im_arr.astype('uint8')
148 | elif im_format == '65535 float':
149 | # why are you using this format?
150 | im_arr = (im_arr*255/65535).astype('uint8')
151 | return im_arr
152 |
153 |
154 | def scale_for_model(image, output_type=None):
155 | """Scale an image to a model's required parameters.
156 |
157 | Arguments
158 | ---------
159 | image : :class:`np.array`
160 | The image array to be transformed to a desired output format.
161 | output_type : str, optional
162 |         The data format of the output to pass into the model. There are four
163 |         possible values:
164 |
165 | * ``'normalized'`` : values rescaled to 0-1.
166 | * ``'zscored'`` : image converted to zero mean and unit stdev.
167 | * ``'8bit'`` : image converted to 8-bit format.
168 | * ``'16bit'`` : image converted to 16-bit format.
169 |
170 | If no value is provided, no re-scaling is performed (input array is
171 | returned directly).
172 | """
173 |
174 | if output_type is None:
175 | return image
176 | elif output_type == 'normalized':
177 | out_im = image/image.max()
178 | return out_im
179 | elif output_type == 'zscored':
180 | return (image - np.mean(image))/np.std(image)
181 | elif output_type == '8bit':
182 | if image.max() > 255:
183 | # assume it's 16-bit, rescale to 8-bit scale to min/max
184 | out_im = 255.*image/65535
185 | return out_im.astype('uint8')
186 | elif image.max() <= 1:
187 | out_im = 255.*image
188 | return out_im.astype('uint8')
189 | else:
190 | return image.astype('uint8')
191 | elif output_type == '16bit':
192 | if (image.max() < 255) and (image.max() > 1):
193 | # scale to min/max
194 | out_im = 65535.*image/255
195 | return out_im.astype('uint16')
196 | elif image.max() <= 1:
197 | out_im = 65535.*image
198 | return out_im.astype('uint16')
199 | else:
200 | return image.astype('uint16')
201 | else:
202 | raise ValueError('output_type must be one of'
203 | ' "normalized", "zscored", "8bit", "16bit"')
204 |
205 |
206 | def rescale_arr(im_arr, im_format, rescale_min='auto', rescale_max='auto'):
207 | """Rescale array values in a 3D image array with channel order [Y, X, C].
208 |
209 | Arguments
210 | ---------
211 | im_arr : :class:`numpy.array`
212 | A numpy array representation of an image. `im_arr` should have either
213 | two or three dimensions.
214 | im_format : str
215 | One of ``'uint8'``, ``'uint16'``, ``'z-scored'``,
216 | ``'zero-one normalized'``, ``'255 float'``, or ``'65535 float'``.
217 | String indicating the dtype of the input, which will dictate the
218 | preprocessing applied.
219 | rescale_min : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
220 | The minimum pixel value(s) for rescaling. If ``rescale=True`` but no
221 | value is provided for `rescale_min`, the minimum pixel intensity in
222 | each channel of the image will be subtracted such that the minimum
223 | value becomes zero. If a single number is provided, that number will be
224 | subtracted from each channel. If a list of values is provided that is
225 | the same length as the number of channels, then those values will be
226 | subtracted from the corresponding channels.
227 | rescale_max : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
228 | The max pixel value(s) for rescaling. If ``rescale=True`` but no
229 | value is provided for `rescale_max`, each channel will be rescaled such
230 | that the maximum value in the channel is set to the bit range's max.
231 | If a single number is provided, that number will be set as the upper
232 | limit for all channels. If a list of values is provided that is the
233 | same length as the number of channels, then those values will be
234 | set to the maximum value in the corresponding channels.
235 |
236 | Returns
237 | -------
238 | normalized_arr : :class:`numpy.array`
239 | """
240 |
241 | if isinstance(rescale_min, list):
242 | if len(rescale_min) != im_arr.shape[2]: # if list len != channels
243 | raise ValueError('The channel rescaling parameters must be '
244 | 'either a single value or a list of length '
245 | 'n_channels.')
246 | else:
247 | rescale_min = np.array(rescale_min)
248 | elif isinstance(rescale_min, int) or isinstance(rescale_min, float):
249 | rescale_min = np.array([rescale_min]*im_arr.shape[2])
250 | elif rescale_min == 'auto':
251 | rescale_min = np.amin(im_arr, axis=(0, 1))
252 |
253 | if isinstance(rescale_max, list):
254 | if len(rescale_max) != im_arr.shape[2]: # if list len != channels
255 | raise ValueError('The channel rescaling parameters must be '
256 | 'either a single value or a list of length '
257 | 'n_channels.')
258 | else:
259 | rescale_max = np.array(rescale_max)
260 | elif isinstance(rescale_max, int) or isinstance(rescale_max, float):
261 | rescale_max = np.array([rescale_max]*im_arr.shape[2])
262 | elif rescale_max == 'auto':
263 | rescale_max = np.amax(im_arr, axis=(0, 1))
264 |
265 | scale_factor = None
266 | if im_format in ['uint8', '255 float']:
267 | scale_factor = 255
268 | elif im_format in ['uint16', '65535 float']:
269 | scale_factor = 65535
270 | elif im_format == 'zero-one normalized':
271 | scale_factor = 1
272 |
273 | # set all values above the scale max to the scale max, and all values
274 | # below the scale min to the scale min
275 | for channel in range(im_arr.shape[2]):
276 | subarr = im_arr[:, :, channel]
277 | subarr[subarr < rescale_min[channel]] = rescale_min[channel]
278 | subarr[subarr > rescale_max[channel]] = rescale_max[channel]
279 | im_arr[:, :, channel] = subarr
280 |
281 | if scale_factor is not None:
282 | im_arr = (im_arr-rescale_min)*(
283 | scale_factor/(rescale_max-rescale_min))
284 |
285 | return im_arr
286 |
287 |
288 | def _check_channel_order(im_arr, framework):
289 | im_shape = im_arr.shape
290 | if len(im_shape) == 3: # doesn't matter for 1-channel images
291 | if im_shape[0] > im_shape[2] and framework in ['torch', 'pytorch']:
292 | # in [Y, X, C], needs to be in [C, Y, X]
293 | im_arr = np.moveaxis(im_arr, 2, 0)
294 | elif im_shape[2] > im_shape[0] and framework == 'keras':
295 | # in [C, Y, X], needs to be in [Y, X, C]
296 | im_arr = np.moveaxis(im_arr, 0, 2)
297 | elif len(im_shape) == 4: # for a whole minibatch
298 | if im_shape[1] > im_shape[3] and framework in ['torch', 'pytorch']:
299 | # in [Y, X, C], needs to be in [C, Y, X]
300 | im_arr = np.moveaxis(im_arr, 3, 1)
301 | elif im_shape[3] > im_shape[1] and framework == 'keras':
302 | # in [C, Y, X], needs to be in [Y, X, C]
303 | im_arr = np.moveaxis(im_arr, 1, 3)
304 |
305 | return im_arr
306 |
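An end-to-end sketch of the io helpers above on a synthetic 16-bit image. The temporary file and the call sequence are illustrative only, not part of the module.

    import os
    import tempfile

    import numpy as np
    import skimage.io

    tmp_path = os.path.join(tempfile.mkdtemp(), 'sample.tif')
    skimage.io.imsave(tmp_path, (np.random.rand(128, 128, 3) * 65535).astype('uint16'))

    im = imread(tmp_path, make_8bit=True, rescale=True)        # uint8, shape [Y, X, C]
    model_in = scale_for_model(im, output_type='normalized')   # float values in [0, 1]
    model_in = _check_channel_order(model_in, 'torch')         # reordered to [C, Y, X]
    print(model_in.shape, model_in.dtype)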
--------------------------------------------------------------------------------
/nets/zoo/enru.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from collections import OrderedDict
5 | import math
6 | def conv3x3(in_planes, out_planes, stride=1):
7 | "3x3 convolution with padding"
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9 | padding=1, bias=False)
10 |
11 |
12 | class BasicBlock(nn.Module):
13 | expansion = 1
14 |
15 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
16 | super(BasicBlock, self).__init__()
17 | self.conv1 = conv3x3(inplanes, planes, stride)
18 |         self.bn1 = nn.BatchNorm2d(planes)  # plain BatchNorm2d; no external ModuleHelper is available in this repo
19 | self.relu = nn.ReLU(inplace=True)
20 | self.conv2 = conv3x3(planes, planes)
21 |         self.bn2 = nn.BatchNorm2d(planes)
22 | self.downsample = downsample
23 | self.stride = stride
24 |
25 | def forward(self, x):
26 | residual = x
27 |
28 | out = self.conv1(x)
29 | out = self.bn1(out)
30 | out = self.relu(out)
31 |
32 | out = self.conv2(out)
33 | out = self.bn2(out)
34 |
35 | if self.downsample is not None:
36 | residual = self.downsample(x)
37 |
38 | out += residual
39 | out = self.relu(out)
40 |
41 | return out
42 |
43 |
44 | class Bottleneck(nn.Module):
45 | expansion = 4
46 |
47 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
48 | super(Bottleneck, self).__init__()
49 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
50 | self.bn1 = bn(planes)
51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
52 | padding=1, bias=False)
53 | self.bn2 = bn(planes)
54 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
55 | self.bn3 = bn(planes * 4)
56 | self.relu = nn.ReLU(inplace=True)
57 | self.downsample = downsample
58 | self.stride = stride
59 |
60 | def forward(self, x):
61 | residual = x
62 |
63 | out = self.conv1(x)
64 | out = self.bn1(out)
65 |
66 | out = self.relu(out)
67 |
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv3(out)
73 | out = self.bn3(out)
74 |
75 | if self.downsample is not None:
76 | residual = self.downsample(x)
77 |
78 | out += residual
79 | out = self.relu(out)
80 |
81 | return out
82 |
83 |
84 | class ResNet(nn.Module):
85 |
86 | def __init__(self, block, layers, num_classes=1000, deep_base=False, norm_type=None):
87 | super(ResNet, self).__init__()
88 | self.inplanes = 128 if deep_base else 16
89 | if deep_base:
90 | self.prefix = nn.Sequential(OrderedDict([
91 | ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)),
92 | ('bn1', bn(64)),
93 | ('relu1', nn.ReLU(inplace=False)),
94 | ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)),
95 | ('bn2', bn(64)),
96 | ('relu2', nn.ReLU(inplace=False)),
97 | ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)),
98 | ('bn3', bn(self.inplanes)),
99 | ('relu3', nn.ReLU(inplace=False))]
100 | ))
101 | else:
102 | self.prefix = nn.Sequential(OrderedDict([
103 | ('conv1', nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
104 | ('bn1', bn(self.inplanes)),
105 | ('relu', nn.ReLU(inplace=False))]
106 | ))
107 |
108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change.
109 |
110 | self.layer1 = self._make_layer(block, 16, layers[0], norm_type=norm_type)
111 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2, norm_type=norm_type)
112 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, norm_type=norm_type)
113 | self.layer4 = self._make_layer(block, 128, layers[3], stride=2, norm_type=norm_type)
114 | self.avgpool = nn.AvgPool2d(7, stride=1)
115 | self.fc = nn.Linear(128 * block.expansion, num_classes)
116 |
117 | for m in self.modules():
118 | if isinstance(m, nn.Conv2d):
119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
120 | m.weight.data.normal_(0, math.sqrt(2. / n))
121 | # elif isinstance(m, ModuleHelper.BatchNorm2d(norm_type=norm_type, ret_cls=True)):
122 | # m.weight.data.fill_(1)
123 | # m.bias.data.zero_()
124 |
125 | def _make_layer(self, block, planes, blocks, stride=1, norm_type=None):
126 | downsample = None
127 | if stride != 1 or self.inplanes != planes * block.expansion:
128 | downsample = nn.Sequential(
129 | nn.Conv2d(self.inplanes, planes * block.expansion,
130 | kernel_size=1, stride=stride, bias=False),
131 | bn(planes * block.expansion),
132 | )
133 |
134 | layers = []
135 | layers.append(block(self.inplanes, planes, stride, downsample, norm_type=norm_type))
136 | self.inplanes = planes * block.expansion
137 | for i in range(1, blocks):
138 | layers.append(block(self.inplanes, planes, norm_type=norm_type))
139 |
140 | return nn.Sequential(*layers)
141 |
142 | def forward(self, x):
143 |         # stem layers are wrapped in self.prefix (conv -> bn -> relu)
144 |         x = self.prefix(x)
145 |         x = self.maxpool(x)
146 |
147 |
148 | x = self.layer1(x)
149 |
150 | x = self.layer2(x)
151 |
152 | x = self.layer3(x)
153 | x = self.layer4(x)
154 |
155 | x = self.avgpool(x)
156 | x = x.view(x.size(0), -1)
157 | x = self.fc(x)
158 |
159 | return x
160 |
161 |
162 | class NormalResnetBackbone(nn.Module):
163 | def __init__(self, orig_resnet):
164 | super(NormalResnetBackbone, self).__init__()
165 |
166 | self.num_features = 512
167 | # take pretrained resnet, except AvgPool and FC
168 | self.prefix = orig_resnet.prefix
169 | self.maxpool = orig_resnet.maxpool
170 | self.layer1 = orig_resnet.layer1
171 | self.layer2 = orig_resnet.layer2
172 | self.layer3 = orig_resnet.layer3
173 | self.layer4 = orig_resnet.layer4
174 |
175 | def get_num_features(self):
176 | return self.num_features
177 |
178 | def forward(self, x):
179 | tuple_features = list()
180 | x = self.prefix(x)
181 | x = self.maxpool(x)
182 | x0 = x
183 | x1 = self.layer1(x)
184 | tuple_features.append(x1)
185 |
186 | x2 = self.layer2(x1)
187 | tuple_features.append(x2)
188 | x3 = self.layer3(x2)
189 | tuple_features.append(x3)
190 | x4 = self.layer4(x3)
191 | tuple_features.append(x4)
192 |
193 | return x0, x1, x2, x3, x4
194 |
195 | def resnet50(**kwargs):
196 | """Constructs a ResNet-50 model.
197 | Args:
198 | pretrained (bool): If True, returns a model pre-trained on Places
199 | """
200 | model = ResNet(Bottleneck, [3, 4, 6, 3], deep_base=False, **kwargs)
201 |
202 | return model
203 |
204 |
205 | def bn(num_features):
206 | return nn.Sequential(
207 | nn.BatchNorm2d(num_features),
208 | nn.ReLU()
209 | )
210 |
211 | class PSPModule(nn.Module):
212 | # pyramid pooling over several grid sizes; (1, 2, 3, 6) is the original PSPNet setting, (1, 3, 6, 8) is used here
213 | def __init__(self, sizes=(1, 3, 6, 8), dimension=2):
214 | super(PSPModule, self).__init__()
215 | self.stages = nn.ModuleList([self._make_stage(size, dimension) for size in sizes])
216 |
217 | def _make_stage(self, size, dimension=2):
218 | if dimension == 1:
219 | prior = nn.AdaptiveAvgPool1d(output_size=size)
220 | elif dimension == 2:
221 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
222 | elif dimension == 3:
223 | prior = nn.AdaptiveAvgPool3d(output_size=(size, size, size))
224 | return prior
225 |
226 | def forward(self, feats):
227 | n, c, _, _ = feats.size()
228 | priors = [stage(feats).view(n, c, -1) for stage in self.stages]
229 | center = torch.cat(priors, -1)
230 | return center
231 |
232 |
233 | class _SelfAttentionBlock(nn.Module):
234 | '''
235 | Basic implementation of a self-attention (non-local) block.
236 | Input:
237 | N x C x H x W
238 | Parameters:
239 | in_channels : number of channels of the input feature map
240 | key_channels : number of channels after the key/query transform
241 | value_channels : number of channels after the value transform
242 | scale : downsampling factor applied to the input feature map (to save memory)
243 | Return:
244 | N x C x H x W
245 | position-aware context features (without concatenation or addition with the input)
246 | '''
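# Shape walkthrough (a sketch, assuming input x of shape (B, C, H, W) and the default
# psp_size=(1, 3, 6, 8), i.e. S = 1 + 9 + 36 + 64 = 110 pooled anchor positions):
#   value   = psp(f_value(x))        -> (B, value_channels, S), permuted to (B, S, value_channels)
#   query   = f_query(x), flattened  -> (B, key_channels, H*W), permuted to (B, H*W, key_channels)
#   key     = psp(f_key(x))          -> (B, key_channels, S)
#   sim_map = softmax(query @ key * key_channels**-0.5)  -> (B, H*W, S)
#   context = (sim_map @ value) reshaped to (B, value_channels, H, W), then W (1x1 conv) -> (B, out_channels, H, W)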
247 |
248 | def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1,psp_size=(1,3,6,8)):
249 | super(_SelfAttentionBlock, self).__init__()
250 | self.scale = scale
251 | self.in_channels = in_channels
252 | self.out_channels = out_channels
253 | self.key_channels = key_channels
254 | self.value_channels = value_channels
255 | if out_channels is None:
256 | self.out_channels = in_channels
257 | self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
258 | self.f_key = nn.Sequential(
259 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
260 | kernel_size=1, stride=1, padding=0),
261 | bn(self.key_channels),
262 | # ModuleHelper.BNReLU(self.key_channels, norm_type=norm_type),
263 | )
264 | self.f_query = self.f_key
265 | self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
266 | kernel_size=1, stride=1, padding=0)
267 | self.W = nn.Conv2d(in_channels=self.value_channels, out_channels=self.out_channels,
268 | kernel_size=1, stride=1, padding=0)
269 |
270 | self.psp = PSPModule(psp_size)
271 | nn.init.constant_(self.W.weight, 0)
272 | nn.init.constant_(self.W.bias, 0)
273 |
274 | def forward(self, x):
275 | batch_size, h, w = x.size(0), x.size(2), x.size(3)
276 | if self.scale > 1:
277 | x = self.pool(x)
278 |
279 | value = self.psp(self.f_value(x))
280 |
281 | query = self.f_query(x).view(batch_size, self.key_channels, -1)
282 | query = query.permute(0, 2, 1)
283 | key = self.f_key(x)
284 | # value=self.psp(value)#.view(batch_size, self.value_channels, -1)
285 | value = value.permute(0, 2, 1)
286 | key = self.psp(key) # .view(batch_size, self.key_channels, -1)
287 | sim_map = torch.matmul(query, key)
288 | sim_map = (self.key_channels ** -.5) * sim_map
289 | sim_map = F.softmax(sim_map, dim=-1)
290 |
291 | context = torch.matmul(sim_map, value)
292 | context = context.permute(0, 2, 1).contiguous()
293 | context = context.view(batch_size, self.value_channels, *x.size()[2:])
294 | context = self.W(context)
295 | return context
296 |
297 |
298 | class SelfAttentionBlock2D(_SelfAttentionBlock):
299 | def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1,psp_size=(1,3,6,8)):
300 | super(SelfAttentionBlock2D, self).__init__(in_channels,
301 | key_channels,
302 | value_channels,
303 | out_channels,
304 | scale,
305 | psp_size=psp_size)
306 |
307 |
308 |
309 | class APNB(nn.Module):
310 | """
311 | Parameters:
312 | in_features / out_features: the channels of the input / output feature maps.
313 | dropout: we choose 0.05 as the default value.
314 | size: you can apply multiple sizes. Here we only use one size.
315 | Return:
316 | features fused with Object context information.
317 | """
318 |
319 | def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1]), psp_size=(1,3,6,8)):
320 | super(APNB, self).__init__()
321 | self.psp_size = psp_size
322 | # build the attention stages; the same channel configuration is reused for every size in `sizes`
323 |
324 | self.stages = nn.ModuleList(
325 | [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
326 | self.conv_bn_dropout = nn.Sequential(
327 | nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0),
328 | # ModuleHelper.BNReLU(out_channels, norm_type=norm_type),
329 | bn(out_channels),
330 | nn.Dropout2d(dropout)
331 | )
332 |
333 | def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
334 | return SelfAttentionBlock2D(in_channels,
335 | key_channels,
336 | value_channels,
337 | output_channels,
338 | size,
339 |
340 | self.psp_size)
341 |
342 | def forward(self, feats):
343 | priors = [stage(feats) for stage in self.stages]
344 | context = priors[0]
345 | for i in range(1, len(priors)):
346 | context += priors[i]
347 | output = self.conv_bn_dropout(torch.cat([context, feats], 1))
348 | return output
349 |
350 |
351 | def double_conv(in_channels, out_channels):
352 | return nn.Sequential(
353 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
354 | nn.ReLU(inplace=True),
355 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
356 | nn.ReLU(inplace=True)
357 | )
358 |
359 |
360 | class ENRUNet(nn.Module):  # custom forward, so nn.Module (not nn.Sequential) is the appropriate base class
361 | def __init__(self, pretrained=False, mode='Train'):
362 | super(ENRUNet, self).__init__()
363 | self.mode=mode
364 | self.backbone = NormalResnetBackbone(resnet50())
365 | # low_in_channels, high_in_channels, out_channels, key_channels, value_channels, dropout
366 | self.dconv_up4 = double_conv(512+256, 256)
367 | self.dconv_up3 = double_conv(256+128, 128)
368 | self.dconv_up2 = double_conv(128+64, 64)
369 | self.dconv_up1 = double_conv(64 + 16, 64)
370 | self.APNB = nn.Sequential(
371 | APNB(in_channels=64, out_channels=64, key_channels=32, value_channels=32,
372 | dropout=0.05, sizes=([1]))
373 | )
374 |
375 | self.conv_last = nn.Conv2d(64, 1, 1)
376 |
377 | def forward(self, x_):
378 | x0, x1, x2, x3, x4 = self.backbone(x_)
379 | up4 = F.interpolate(x4, size=(x3.size(2), x3.size(3)), mode="bilinear", align_corners=True)
380 | x = torch.cat([up4, x3], dim=1)
381 | x = self.dconv_up4(x)
382 | up3 = F.interpolate(x, size=(x2.size(2), x2.size(3)), mode="bilinear", align_corners=True)
383 | x = torch.cat([up3, x2], dim=1)
384 | x = self.dconv_up3(x)
385 | up2 = F.interpolate(x, size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True)
386 | x = torch.cat([up2, x1], dim=1)
387 | x = self.dconv_up2(x)
388 | up1 = F.interpolate(x, size=(x0.size(2), x0.size(3)), mode="bilinear", align_corners=True)
389 | x = torch.cat([up1, x0], dim=1)
390 | x = self.dconv_up1(x)
391 | x = F.interpolate(x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
392 | x = self.APNB(x)
393 | out = self.conv_last(x)
394 | if self.mode == 'Train':
395 | return torch.sigmoid(out)  # F.sigmoid is deprecated; sigmoid is applied only in 'Train' mode
396 | elif self.mode == 'Infer':
397 | return out  # raw logits for inference
--------------------------------------------------------------------------------
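A minimal usage sketch for the ENRUNet defined above. The import path `nets.zoo.enru` is an assumption based on the repository layout; the model expects a 3-channel input and produces a single-channel building mask at the input resolution.

```python
import torch
from nets.zoo.enru import ENRUNet  # assumed import path, matching the repo layout

# 'Train' mode returns sigmoid probabilities, 'Infer' mode returns raw logits.
model = ENRUNet(pretrained=False, mode='Train')
model.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)  # (batch, channels, height, width)
    out = model(dummy)                   # (1, 1, 256, 256), values in (0, 1)
print(out.shape)
```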
/nets/_torch_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from torch import nn
6 | from math import exp
7 | try:
8 | from itertools import ifilterfalse
9 | except ImportError: # py3k
10 | from itertools import filterfalse as ifilterfalse
11 |
12 | epsilon_ = 1e-15
13 |
14 | class TorchDiceLoss(nn.Module):
15 | def __init__(self, weight=None, size_average=True,
16 | per_image=False, logits=False):
17 | super().__init__()
18 | self.size_average = size_average
19 | self.register_buffer('weight', weight)
20 | self.per_image = per_image
21 | self.logits = logits
22 |
23 | def forward(self, input, target):
24 | if self.logits:
25 | input = torch.sigmoid(input)
26 | return soft_dice_loss(input, target, per_image=self.per_image)
27 |
28 |
29 | class TorchFocalLoss(nn.Module):
30 | """Implementation of Focal Loss[1]_ modified from Catalyst [2]_ .
31 |
32 | Arguments
33 | ---------
34 | gamma : :class:`int` or :class:`float`
35 | Focusing parameter. See [1]_ .
36 | reduce : :class:`bool`
37 | If ``True`` (default), return the mean of the per-element loss.
38 |
39 | References
40 | ----------
41 | .. [1] https://arxiv.org/pdf/1708.02002.pdf
42 | .. [2] https://catalyst-team.github.io/catalyst/
43 | """
44 |
45 | def __init__(self, gamma=2, reduce=True, logits=False):
46 | super().__init__()
47 | self.gamma = gamma
48 | self.reduce = reduce
49 | self.logits = logits
50 |
51 | # TODO refactor
52 | def forward(self, outputs, targets):
53 | """Calculate the loss function between `outputs` and `targets`.
54 |
55 | Arguments
56 | ---------
57 | outputs : :class:`torch.Tensor`
58 | The output tensor from a model.
59 | targets : :class:`torch.Tensor`
60 | The training target.
61 |
62 | Returns
63 | -------
64 | loss : :class:`torch.Variable`
65 | The loss value.
66 | """
67 |
68 | if self.logits:
69 | BCE_loss = F.binary_cross_entropy_with_logits(outputs, targets,
70 | reduction='none')
71 | else:
72 | BCE_loss = F.binary_cross_entropy(outputs, targets,
73 | reduction='none')
74 | pt = torch.exp(-BCE_loss)
75 | F_loss = (1-pt)**self.gamma * BCE_loss
76 | if self.reduce:
77 | return torch.mean(F_loss)
78 | else:
79 | return F_loss
80 |
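# A minimal usage sketch (hypothetical tensors; with logits=True the raw network outputs
# are passed through binary_cross_entropy_with_logits internally):
#   criterion = TorchFocalLoss(gamma=2, logits=True)
#   loss = criterion(model_output, target_mask)   # scalar tensor when reduce=True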
81 |
82 | def torch_lovasz_hinge(logits, labels, per_image=False, ignore=None):
83 | """Lovasz Hinge Loss. Implementation edited from Maxim Berman's GitHub.
84 |
85 | References
86 | ----------
87 | https://github.com/bermanmaxim/LovaszSoftmax/
88 | https://arxiv.org/abs/1705.08790
89 |
90 | Arguments
91 | ---------
92 | logits: :class:`torch.Variable`
93 | logits at each pixel (between -inf and +inf)
94 | labels: :class:`torch.Tensor`
95 | binary ground truth masks (0 or 1)
96 | per_image: bool, optional
97 | compute the loss per image instead of per batch. Defaults to ``False``.
98 | ignore: optional void class id.
99 |
100 | Returns
101 | -------
102 | loss : :class:`torch.Variable`
103 | Lovasz loss value for the input logits and labels. Compatible with
104 | ``loss.backward()`` as it is a :class:`torch.Variable` .
105 | """
106 | # TODO: Restructure into a class like TorchFocalLoss for compatibility
107 | if per_image:
108 | loss = mean(
109 | lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0),
110 | lab.unsqueeze(0),
111 | ignore))
112 | for log, lab in zip(logits, labels))
113 | else:
114 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits,
115 | labels,
116 | ignore))
117 | return loss
118 |
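# Usage sketch (hypothetical tensors): `logits` are raw, un-sigmoided scores and `labels`
# are 0/1 masks of the same shape, e.g.
#   loss = torch_lovasz_hinge(model_output.squeeze(1), target_mask, per_image=True)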
119 |
120 | def lovasz_hinge_flat(logits, labels):
121 | """Binary Lovasz hinge loss.
122 |
123 | Arguments
124 | ---------
125 | logits: :class:`torch.Variable`
126 | Logits at each prediction (between -inf and +inf)
127 | labels: :class:`torch.Tensor`
128 | binary ground truth labels (0 or 1)
129 |
130 | Returns
131 | -------
132 | loss : :class:`torch.Variable`
133 | Lovasz loss value for the input logits and labels.
134 | """
135 | if len(labels) == 0:
136 | # only void pixels, the gradients should be 0
137 | return logits.sum() * 0.
138 | signs = 2. * labels.float() - 1.
139 | errors = (1. - logits * Variable(signs))
140 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
141 | perm = perm.data
142 | gt_sorted = labels[perm]
143 | grad = lovasz_grad(gt_sorted)
144 | loss = torch.dot(F.relu(errors_sorted), Variable(grad))
145 | return loss
146 |
147 |
148 | def flatten_binary_scores(scores, labels, ignore=None):
149 | """
150 | Flattens predictions in the batch (binary case)
151 | Remove labels equal to 'ignore'
152 | """
153 | scores = scores.view(-1)
154 | labels = labels.view(-1)
155 | if ignore is None:
156 | return scores, labels
157 | valid = (labels != ignore)
158 | vscores = scores[valid]
159 | vlabels = labels[valid]
160 | return vscores, vlabels
161 |
162 |
163 | class TorchJaccardLoss(torch.nn.modules.Module):
164 | # modified from XD_XD's implementation
165 | def __init__(self):
166 | super(TorchJaccardLoss, self).__init__()
167 |
168 | def forward(self, outputs, targets):
169 | eps = 1e-15
170 |
171 | jaccard_target = (targets == 1).float()
172 | jaccard_output = torch.sigmoid(outputs)
173 | # jaccard_output = outputs  # alternative: skip the sigmoid when `outputs` are already probabilities
174 | intersection = (jaccard_output * jaccard_target).sum()
175 | union = jaccard_output.sum() + jaccard_target.sum()
176 | jaccard_score = ((intersection + eps) / (union - intersection + eps))
177 | self._stash_jaccard = jaccard_score
178 | loss = 1. - jaccard_score
179 |
180 | return loss
181 |
182 |
183 | class TorchStableBCELoss(torch.nn.modules.Module):
184 | def __init__(self):
185 | super(TorchStableBCELoss, self).__init__()
186 |
187 | def forward(self, input, target):
188 | neg_abs = - input.abs()
189 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
190 | return loss.mean()
191 |
192 |
193 | def binary_xloss(logits, labels, ignore=None):
194 | """
195 | Binary Cross entropy loss
196 | logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
197 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
198 | ignore: void class id
199 | """
200 | logits, labels = flatten_binary_scores(logits, labels, ignore)
201 | loss = TorchStableBCELoss()(logits, Variable(labels.float()))
202 | return loss
203 |
204 |
205 | def lovasz_grad(gt_sorted):
206 | """
207 | Computes gradient of the Lovasz extension w.r.t sorted errors
208 | See Alg. 1 in paper
209 | """
210 | p = len(gt_sorted)
211 | gts = gt_sorted.sum()
212 | intersection = gts - gt_sorted.float().cumsum(0)
213 | union = gts + (1 - gt_sorted).float().cumsum(0)
214 | jaccard = 1. - intersection / union
215 | if p > 1:  # cover the 1-pixel case
216 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
217 | return jaccard
218 |
219 |
220 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
221 | """
222 | IoU for foreground class
223 | binary: 1 foreground, 0 background
224 | """
225 | if not per_image:
226 | preds, labels = (preds,), (labels,)
227 | ious = []
228 | for pred, label in zip(preds, labels):
229 | intersection = ((label == 1) & (pred == 1)).sum()
230 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
231 | if not union:
232 | iou = EMPTY
233 | else:
234 | iou = float(intersection) / float(union)
235 | ious.append(iou)
236 | iou = mean(ious)  # mean across images if per_image
237 | return 100 * iou
238 |
239 |
240 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
241 | """
242 | Array of IoU for each (non ignored) class
243 | """
244 | if not per_image:
245 | preds, labels = (preds,), (labels,)
246 | ious = []
247 | for pred, label in zip(preds, labels):
248 | iou = []
249 | for i in range(C):
250 | if i != ignore:
251 | intersection = ((label == i) & (pred == i)).sum()
252 | union = ((label == i) | ((pred == i) & (label != ignore))).sum()
253 | if not union:
254 | iou.append(EMPTY)
255 | else:
256 | iou.append(float(intersection) / float(union))
257 | ious.append(iou)
258 | ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image
259 | return 100 * np.array(ious)
260 |
261 |
262 | # helper functions
263 | def isnan(x):
264 | return x != x
265 |
266 |
267 | def mean(l, ignore_nan=False, empty=0):
268 | """
269 | nanmean compatible with generators.
270 | """
271 | l = iter(l)
272 | if ignore_nan:
273 | l = ifilterfalse(isnan, l)
274 | try:
275 | n = 1
276 | acc = next(l)
277 | except StopIteration:
278 | if empty == 'raise':
279 | raise ValueError('Empty mean')
280 | return empty
281 | for n, v in enumerate(l, 2):
282 | acc += v
283 | if n == 1:
284 | return acc
285 | return acc / n
286 |
287 |
288 | def dice_round(preds, trues):
289 | preds = preds.float()  # note: despite the name, predictions are not rounded/thresholded here
290 | return soft_dice_loss(preds, trues)
291 |
292 |
293 | def soft_dice_loss(outputs, targets, per_image=False):
294 | batch_size = outputs.size()[0]
295 | eps = 1e-5
296 | if not per_image:
297 | batch_size = 1
298 | dice_target = targets.contiguous().view(batch_size, -1).float()
299 | dice_output = outputs.contiguous().view(batch_size, -1)
300 | intersection = torch.sum(dice_output * dice_target, dim=1)
301 | union = torch.sum(dice_output, dim=1) + torch.sum(dice_target, dim=1) + eps
302 | loss = (1 - (2 * intersection + eps) / union).mean()
303 |
304 | return loss
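# Note on soft_dice_loss above: it computes 1 - (2*sum(P*T) + eps) / (sum(P) + sum(T) + eps)
# over flattened probability maps P and binary targets T (per image when per_image=True),
# so perfect overlap drives the loss towards 0.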
305 | class MSSSIM(torch.nn.Module):
306 | def __init__(self, window_size=11, size_average=True, channel=1):
307 | super(MSSSIM, self).__init__()
308 | self.window_size = window_size
309 | self.size_average = size_average
310 | self.channel = channel
311 |
312 | def forward(self, img1, img2):
313 | # TODO: store window between calls if possible
314 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
315 | class SSIM(torch.nn.Module):
316 | def __init__(self, window_size=11, size_average=True, val_range=None):
317 | super(SSIM, self).__init__()
318 | self.window_size = window_size
319 | self.size_average = size_average
320 | self.val_range = val_range
321 |
322 | # Assume 1 channel for SSIM
323 | self.channel = 1
324 | self.window = create_window(window_size)
325 |
326 | def forward(self, img1, img2):
327 | (_, channel, _, _) = img1.size()
328 |
329 | if channel == self.channel and self.window.dtype == img1.dtype:
330 | window = self.window
331 | else:
332 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
333 | self.window = window
334 | self.channel = channel
335 |
336 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
337 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=True):
338 | device = img1.device
339 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
340 |
341 |
342 |
343 | levels = weights.size()[0]
344 |
345 | ssims = []
346 | mcs = []
347 |
348 | for _ in range(levels):
349 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
350 |
351 | # Relu normalize (not compliant with original definition)
352 | if normalize == "relu":
353 | ssims.append(torch.relu(sim))
354 | mcs.append(torch.relu(cs))
355 | else:
356 | ssims.append(sim)
357 | mcs.append(cs)
358 |
359 | img1 = F.avg_pool2d(img1, (2, 2))
360 | img2 = F.avg_pool2d(img2, (2, 2))
361 |
362 | ssims = torch.stack(ssims)
363 | mcs = torch.stack(mcs)
364 |
365 | # Simple normalize (not compliant with original definition)
366 | # TODO: remove support for normalize == True (kept for backward support)
367 | if normalize == "simple" or normalize == True:
368 | ssims = (ssims + 1) / 2
369 | mcs = (mcs + 1) / 2
370 |
371 | pow1 = mcs ** weights
372 | pow2 = ssims ** weights
373 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
374 | output = torch.prod(pow1[:-1] * pow2[-1])
375 | return output
376 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
377 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
378 | if val_range is None:
379 | if torch.max(img1) > 128:
380 | max_val = 255
381 | else:
382 | max_val = 1
383 |
384 | if torch.min(img1) < -0.5:
385 | min_val = -1
386 | else:
387 | min_val = 0
388 | L = max_val - min_val
389 | else:
390 | L = val_range
391 |
392 | padd = 0
393 | (_, channel, height, width) = img1.size()
394 | if window is None:
395 | real_size = min(window_size, height, width)
396 | window = create_window(real_size, channel=channel).to(img1.device)
397 |
398 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
399 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
400 |
401 | mu1_sq = mu1.pow(2)
402 | mu2_sq = mu2.pow(2)
403 | mu1_mu2 = mu1 * mu2
404 |
405 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
406 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
407 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
408 |
409 | C1 = (0.01 * L) ** 2
410 | C2 = (0.03 * L) ** 2
411 |
412 | v1 = 2.0 * sigma12 + C2
413 | v2 = sigma1_sq + sigma2_sq + C2
414 | cs = torch.mean(v1 / (v2 + epsilon_))  # contrast sensitivity; epsilon_ guards against division by zero
415 |
416 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2+epsilon_)
417 |
418 | if size_average:
419 | ret = ssim_map.mean()
420 | else:
421 | ret = ssim_map.mean(1).mean(1).mean(1)
422 |
423 | if full:
424 | return ret, cs
425 | return ret
426 | def gaussian(window_size, sigma):
427 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
428 | return gauss/(gauss.sum()+epsilon_)
429 |
430 |
431 | def create_window(window_size, channel=1):
432 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
433 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
434 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
435 | return window
436 |
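# Usage sketch for the (MS-)SSIM helpers above (hypothetical tensors, values in [0, 1]):
#   similarity = SSIM(window_size=11)(pred, target)    # higher means more similar
#   ms_sim = MSSSIM(window_size=11)(pred, target)      # multi-scale variant (5 levels)
# To minimise during training, a wrapper such as `1 - similarity` would be needed.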
437 | class BCEDiceLoss(nn.Module):
438 | def __init__(self, weight=None, size_average=True):
439 | super().__init__()
440 |
441 | def forward(self, input, target):
442 |
443 | truth = target.view(-1)
444 | pred = input.view(-1)
445 | # pred = input
446 | # BCE loss
447 | bce_loss = nn.BCELoss()(pred, truth).double()
448 |
449 | # Dice Loss
450 | dice_coef = (2.0 * (pred * truth).double().sum() + 1) / (
451 | pred.double().sum() + truth.double().sum() + 1
452 | )
453 |
454 | return bce_loss + (1 - dice_coef)
455 |
456 | class BCEDiceLoss2(nn.Module):
457 | def __init__(self, weight=None, size_average=True):
458 | super().__init__()
459 |
460 | def forward(self, input, target):
461 |
462 | truth = target.view(-1)
463 | pred = input.view(-1)
464 | # pred = input
465 | # BCE loss
466 | bce_loss = nn.BCELoss()(pred, truth).double()
467 | eps = 1e-5
468 | # Dice Loss
469 | dice_coef = (2.0 * (pred * truth).double().sum() + eps) / (
470 | pred.double().sum() + truth.double().sum() + eps
471 | )
472 |
473 | return bce_loss + (1 - dice_coef)
474 | torch_losses = {
475 | 'l1loss': nn.L1Loss,
476 | 'l1': nn.L1Loss,
477 | 'mae': nn.L1Loss,
478 | 'mean_absolute_error': nn.L1Loss,
479 | 'smoothl1loss': nn.SmoothL1Loss,
480 | 'smoothl1': nn.SmoothL1Loss,
481 | 'mean_squared_error': nn.MSELoss,
482 | 'mse': nn.MSELoss,
483 | 'mseloss': nn.MSELoss,
484 | 'categorical_crossentropy': nn.CrossEntropyLoss,
485 | 'cce': nn.CrossEntropyLoss,
486 | 'crossentropyloss': nn.CrossEntropyLoss,
487 | 'negative_log_likelihood': nn.NLLLoss,
488 | 'nll': nn.NLLLoss,
489 | 'nllloss': nn.NLLLoss,
490 | 'poisson_negative_log_likelihood': nn.PoissonNLLLoss,
491 | 'poisson_nll': nn.PoissonNLLLoss,
492 | 'poissonnll': nn.PoissonNLLLoss,
493 | 'kullback_leibler_divergence': nn.KLDivLoss,
494 | 'kld': nn.KLDivLoss,
495 | 'kldivloss': nn.KLDivLoss,
496 | 'binary_crossentropy': nn.BCELoss,
497 | 'bce': nn.BCELoss,
498 | 'bceloss': nn.BCELoss,
499 | 'bcewithlogits': nn.BCEWithLogitsLoss,
500 | 'bcewithlogitsloss': nn.BCEWithLogitsLoss,
501 | 'hinge': nn.HingeEmbeddingLoss,
502 | 'hingeembeddingloss': nn.HingeEmbeddingLoss,
503 | 'multiclass_hinge': nn.MultiMarginLoss,
504 | 'multimarginloss': nn.MultiMarginLoss,
505 | 'softmarginloss': nn.SoftMarginLoss,
506 | 'softmargin': nn.SoftMarginLoss,
507 | 'multiclass_softmargin': nn.MultiLabelSoftMarginLoss,
508 | 'multilabelsoftmarginloss': nn.MultiLabelSoftMarginLoss,
509 | 'cosine': nn.CosineEmbeddingLoss,
510 | 'cosineloss': nn.CosineEmbeddingLoss,
511 | 'cosineembeddingloss': nn.CosineEmbeddingLoss,
512 | 'lovaszhinge': torch_lovasz_hinge,
513 | 'focalloss': TorchFocalLoss,
514 | 'focal': TorchFocalLoss,
515 | 'jaccard': TorchJaccardLoss,
516 | 'jaccardloss': TorchJaccardLoss,
517 | 'dice': TorchDiceLoss,
518 | 'diceloss': TorchDiceLoss,
519 | 'msssim': MSSSIM, 'bcedice': BCEDiceLoss, 'bcedice2': BCEDiceLoss2
520 | }
521 |
--------------------------------------------------------------------------------
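The `torch_losses` dictionary maps lowercase loss names to loss classes (or, for the Lovasz hinge, a plain function). Below is a minimal sketch of resolving a loss by name, assuming the import path `nets._torch_losses` and an illustrative config key; the actual lookup lives in the training code elsewhere in the repository.

```python
import torch
from nets._torch_losses import torch_losses  # assumed import path, matching the repo layout

loss_name = 'bcedice'                          # e.g. read from a YAML config (illustrative)
criterion = torch_losses[loss_name.lower()]()  # most entries are nn.Module subclasses

pred = torch.rand(2, 1, 64, 64)                # probabilities in [0, 1), as BCEDiceLoss expects
target = torch.randint(0, 2, (2, 1, 64, 64)).float()
loss = criterion(pred, target)
print(loss.item())
```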