├── net
│   ├── utils
│   │   ├── __init__.py
│   │   ├── cmap.npy
│   │   ├── cs_cmap.npy
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── helpers.cpython-36.pyc
│   │   │   ├── helpers.cpython-37.pyc
│   │   │   ├── layer_factory.cpython-36.pyc
│   │   │   └── layer_factory.cpython-37.pyc
│   │   ├── helpers.py
│   │   └── layer_factory.py
│   ├── Ours
│   │   ├── __pycache__
│   │   │   └── Module.cpython-37.pyc
│   │   ├── lib
│   │   │   ├── __pycache__
│   │   │   │   ├── non_local_dot_product.cpython-36.pyc
│   │   │   │   ├── non_local_dot_product.cpython-37.pyc
│   │   │   │   └── non_local_embedded_gaussian.cpython-36.pyc
│   │   │   ├── non_local_gaussian.py
│   │   │   ├── non_local_dot_product.py
│   │   │   ├── non_local_embedded_gaussian.py
│   │   │   └── non_local_concatenation.py
│   │   ├── EffLA.py
│   │   ├── base.py
│   │   ├── DMNet.py
│   │   ├── SpNet.py
│   │   └── Module.py
│   └── LSTM
│       ├── __pycache__
│       │   ├── grouplstm.cpython-36.pyc
│       │   ├── grouplstm.cpython-37.pyc
│       │   ├── bottlenecklstm.cpython-36.pyc
│       │   ├── bottlenecklstm.cpython-37.pyc
│       │   ├── torch_convlstm.cpython-36.pyc
│       │   └── torch_convlstm.cpython-37.pyc
│       ├── torch_convlstmv2.py
│       ├── bottlenecklstm.py
│       ├── torch_convlstm.py
│       └── grouplstm.py
├── framework.jpg
├── .gitignore
├── src
│   ├── tests
│   │   ├── test_setup_network.py
│   │   ├── test_setup_data_loaders.py
│   │   ├── test_setup_optimisers_and_schedulers.py
│   │   ├── test_networks.py
│   │   └── test_transforms.py
│   ├── network.py
│   ├── optimisers.py
│   ├── arguments.py
│   ├── train.py
│   └── data.py
├── utils
│   ├── metrics.py
│   ├── EndoMetric.py
│   ├── LoadModel.py
│   ├── EndoLoss.py
│   ├── pytorch_modelsize.py
│   ├── losses.py
│   ├── summary.py
│   └── image.py
├── README.md
├── dataset
│   └── Endovis2017.py
└── train.py
/net/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/framework.jpg
--------------------------------------------------------------------------------
/net/utils/cmap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/utils/cmap.npy
--------------------------------------------------------------------------------
/net/utils/cs_cmap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/utils/cs_cmap.npy
--------------------------------------------------------------------------------
/net/Ours/__pycache__/Module.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/Ours/__pycache__/Module.cpython-37.pyc
--------------------------------------------------------------------------------
/net/LSTM/__pycache__/grouplstm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/LSTM/__pycache__/grouplstm.cpython-36.pyc
--------------------------------------------------------------------------------
/net/LSTM/__pycache__/grouplstm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/LSTM/__pycache__/grouplstm.cpython-37.pyc
--------------------------------------------------------------------------------
/net/utils/__pycache__/__init__.cpython-36.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /net/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /net/utils/__pycache__/helpers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/utils/__pycache__/helpers.cpython-36.pyc -------------------------------------------------------------------------------- /net/utils/__pycache__/helpers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/utils/__pycache__/helpers.cpython-37.pyc -------------------------------------------------------------------------------- /net/LSTM/__pycache__/bottlenecklstm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/LSTM/__pycache__/bottlenecklstm.cpython-36.pyc -------------------------------------------------------------------------------- /net/LSTM/__pycache__/bottlenecklstm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/LSTM/__pycache__/bottlenecklstm.cpython-37.pyc -------------------------------------------------------------------------------- /net/LSTM/__pycache__/torch_convlstm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/LSTM/__pycache__/torch_convlstm.cpython-36.pyc -------------------------------------------------------------------------------- /net/LSTM/__pycache__/torch_convlstm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/LSTM/__pycache__/torch_convlstm.cpython-37.pyc -------------------------------------------------------------------------------- /net/utils/__pycache__/layer_factory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/utils/__pycache__/layer_factory.cpython-36.pyc -------------------------------------------------------------------------------- /net/utils/__pycache__/layer_factory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/utils/__pycache__/layer_factory.cpython-37.pyc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | net/BiseNet 2 | net/light-weight-refinenet-master 3 | net/MobileNetRefine 4 | net/TDNet 5 | net/TernausNet 6 | net/TorchSeg 7 | net/unet 8 | net/utils 9 | scripts/ -------------------------------------------------------------------------------- /net/Ours/lib/__pycache__/non_local_dot_product.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/Ours/lib/__pycache__/non_local_dot_product.cpython-36.pyc -------------------------------------------------------------------------------- /net/Ours/lib/__pycache__/non_local_dot_product.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/Ours/lib/__pycache__/non_local_dot_product.cpython-37.pyc -------------------------------------------------------------------------------- /net/Ours/lib/__pycache__/non_local_embedded_gaussian.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcwang123/DMNet/HEAD/net/Ours/lib/__pycache__/non_local_embedded_gaussian.cpython-36.pyc -------------------------------------------------------------------------------- /src/tests/test_setup_network.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | 4 | import densetorch as dt 5 | 6 | from arguments import get_arguments 7 | from train import setup_network 8 | 9 | 10 | def test_setup_network(): 11 | # NOTE: Removing any sys.argv to get default arguments 12 | sys.argv = [""] 13 | args = get_arguments() 14 | device = "cuda" if torch.cuda.is_available() else "cpu" 15 | segmenter, training_loss, validation_loss = setup_network(args, device) 16 | assert isinstance(segmenter, torch.nn.Module) 17 | assert isinstance(training_loss, torch.nn.CrossEntropyLoss) 18 | assert isinstance(validation_loss, dt.engine.MeanIoU) 19 | -------------------------------------------------------------------------------- /src/tests/test_setup_data_loaders.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from torchvision.datasets import FakeData 4 | 5 | import train 6 | from arguments import get_arguments 7 | 8 | 9 | def get_fake_datasets(**kwargs): 10 | return FakeData(), FakeData() 11 | 12 | 13 | def test_setup_data_loaders(mocker): 14 | # NOTE: Removing any sys.argv to get default arguments 15 | sys.argv = [""] 16 | args = get_arguments() 17 | mocker.patch.object(train, "get_datasets", side_effect=get_fake_datasets) 18 | train_loaders, val_loader = train.setup_data_loaders(args) 19 | assert len(train_loaders) == args.num_stages 20 | for train_loader in train_loaders: 21 | assert isinstance(train_loader, torch.utils.data.DataLoader) 22 | assert isinstance(val_loader, torch.utils.data.DataLoader) 23 | -------------------------------------------------------------------------------- /net/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | IMG_SCALE = 1.0 / 255 5 | IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) 6 | IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) 7 | 8 | 9 | def maybe_download(model_name, model_url, model_dir=None, map_location=None): 10 | import os 11 | import sys 12 | from six.moves import urllib 13 | 14 | if model_dir is None: 15 | torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch")) 16 | model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models")) 17 | if not os.path.exists(model_dir): 18 | os.makedirs(model_dir) 19 | filename = "{}.pth.tar".format(model_name) 20 | cached_file = os.path.join(model_dir, filename) 21 | 
if not os.path.exists(cached_file): 22 | url = model_url 23 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 24 | urllib.request.urlretrieve(url, cached_file) 25 | return torch.load(cached_file, map_location=map_location) 26 | 27 | 28 | def prepare_img(img): 29 | return (img * IMG_SCALE - IMG_MEAN) / IMG_STD 30 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def compute_dice(pre,ref,return_all=False): 4 | # n*c*w*h 5 | pre = pre>0.5 6 | ref = ref>0.5 7 | assert pre.shape==ref.shape 8 | class_num = pre.shape[1] 9 | dice = np.zeros((class_num-1,)) 10 | for c in range(1, class_num): 11 | index = list(np.sum(ref[:,c],axis=(1,2))>10) 12 | p = pre[:,c][index] 13 | r = ref[:,c][index] 14 | dice[c-1] = np.mean(2*np.sum(p*r,axis=(1,2))/(np.sum(p,axis=(1,2))+np.sum(r,axis=(1,2)))) 15 | if return_all: 16 | return dice 17 | else: 18 | return np.mean(dice[dice>-1]) 19 | 20 | def compute_iou(pre,ref, return_all=False): 21 | pre = pre>0.5 22 | ref = ref>0.5 23 | assert pre.shape==ref.shape 24 | class_num = pre.shape[1] 25 | iou = np.zeros((class_num-1,)) 26 | for c in range(1, class_num): 27 | index = list(np.sum(ref[:,c],axis=(1,2))>10) 28 | p = pre[:,c][index] 29 | r = ref[:,c][index] 30 | iou[c-1] = np.mean(np.sum(p*r,axis=(1,2))/(np.sum(p,axis=(1,2))+np.sum(r,axis=(1,2))-np.sum(p*r,axis=(1,2)))) 31 | if return_all: 32 | return iou 33 | else: 34 | return np.mean(iou[iou>-1]) -------------------------------------------------------------------------------- /utils/EndoMetric.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | import argparse 4 | import cv2 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | def general_dice(y_true, y_pred): 9 | result = [] 10 | 11 | if y_true.sum() == 0: 12 | if y_pred.sum() == 0: 13 | return 1 14 | else: 15 | return 0 16 | 17 | for instrument_id in set(y_true.flatten()): 18 | if instrument_id == 0: 19 | continue 20 | result += [[instrument_id,dice(y_true == instrument_id, y_pred == instrument_id)]] 21 | 22 | return result 23 | 24 | 25 | def general_jaccard(y_true, y_pred): 26 | result = [] 27 | 28 | if y_true.sum() == 0: 29 | if y_pred.sum() == 0: 30 | return 1 31 | else: 32 | return 0 33 | 34 | for instrument_id in set(y_true.flatten()): 35 | if instrument_id == 0: 36 | continue 37 | result += [[instrument_id,jaccard(y_true == instrument_id, y_pred == instrument_id)]] 38 | 39 | return result 40 | 41 | def jaccard(y_true, y_pred): 42 | intersection = (y_true * y_pred).sum() 43 | union = y_true.sum() + y_pred.sum() - intersection 44 | return (intersection + 1e-15) / (union + 1e-15) 45 | 46 | 47 | def dice(y_true, y_pred): 48 | return (2 * (y_true * y_pred).sum() + 1e-15) / (y_true.sum() + y_pred.sum() + 1e-15) -------------------------------------------------------------------------------- /src/tests/test_setup_optimisers_and_schedulers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | 4 | from arguments import get_arguments 5 | from train import setup_optimisers_and_schedulers 6 | 7 | 8 | class DummyEncDecModel(torch.nn.Module): 9 | def __init__(self): 10 | super(DummyEncDecModel, self).__init__() 11 | self.layer1 = torch.nn.Parameter(torch.FloatTensor(1, 2)) 12 | self.dec1 = torch.nn.Parameter(torch.FloatTensor(1, 2)) 13 | 14 | 15 | def 
test_setup_optimisers_and_schedulers(): 16 | # NOTE: Removing any sys.argv to get default arguments 17 | sys.argv = [""] 18 | args = get_arguments() 19 | model = DummyEncDecModel() 20 | optimisers, schedulers = setup_optimisers_and_schedulers(args, model) 21 | assert len(optimisers) == 2 22 | assert len(schedulers) == 2 23 | for optimiser in optimisers: 24 | assert isinstance(optimiser, torch.optim.Optimizer) 25 | assert hasattr(optimiser, "state_dict") 26 | assert hasattr(optimiser, "load_state_dict") 27 | assert hasattr(optimiser, "step") 28 | assert hasattr(optimiser, "zero_grad") 29 | for scheduler in schedulers: 30 | assert isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler) 31 | assert hasattr(scheduler, "state_dict") 32 | assert hasattr(scheduler, "load_state_dict") 33 | assert hasattr(scheduler, "step") 34 | -------------------------------------------------------------------------------- /src/network.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | from models.mobilenet import mbv2 5 | # from models.resnet import rf_lw50, rf_lw101, rf_lw152 6 | 7 | 8 | def get_segmenter( 9 | enc_backbone, enc_pretrained, num_classes, 10 | ): 11 | """Create Encoder-Decoder; for now only ResNet [50,101,152] Encoders are supported""" 12 | if enc_backbone == "50": 13 | return rf_lw50(num_classes, imagenet=enc_pretrained) 14 | elif enc_backbone == "101": 15 | return rf_lw101(num_classes, imagenet=enc_pretrained) 16 | elif enc_backbone == "152": 17 | return rf_lw152(num_classes, imagenet=enc_pretrained) 18 | elif enc_backbone == "mbv2": 19 | return mbv2(num_classes, imagenet=enc_pretrained) 20 | else: 21 | raise ValueError("{} is not supported".format(str(enc_backbone))) 22 | 23 | 24 | def get_encoder_and_decoder_params(model): 25 | """Filter model parameters into two groups: encoder and decoder.""" 26 | logger = logging.getLogger(__name__) 27 | enc_params = [] 28 | dec_params = [] 29 | for k, v in model.named_parameters(): 30 | if bool(re.match(".*conv1.*|.*bn1.*|.*layer.*", k)): 31 | enc_params.append(v) 32 | logger.info(" Enc. parameter: {}".format(k)) 33 | else: 34 | dec_params.append(v) 35 | logger.info(" Dec. parameter: {}".format(k)) 36 | return enc_params, dec_params 37 | -------------------------------------------------------------------------------- /utils/LoadModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | 5 | def load_model(model, pretrain_dir, log=True): 6 | state_dict_ = torch.load(pretrain_dir, map_location='cuda:0') 7 | print('loaded pretrained weights form %s !' 
% pretrain_dir)
8 |     state_dict = OrderedDict()
9 | 
10 |     # convert data_parallel to model
11 |     for key in state_dict_:
12 |         if key.startswith('module') and not key.startswith('module_list'):
13 |             state_dict[key[7:]] = state_dict_[key]
14 |         else:
15 |             state_dict[key] = state_dict_[key]
16 | 
17 |     # check loaded parameters and created model parameters
18 |     model_state_dict = model.state_dict()
19 |     for key in state_dict:
20 |         if key in model_state_dict:
21 |             # print(key,state_dict[key].shape,model_state_dict[key].shape)
22 |             if state_dict[key].shape != model_state_dict[key].shape:
23 |                 if log:
24 |                     print('Skip loading parameter {}, required shape{}, loaded shape{}.'.format(key, model_state_dict[key].shape, state_dict[key].shape))
25 |                 state_dict[key] = model_state_dict[key]
26 |         else:
27 |             if log:
28 |                 print('Drop parameter {}.'.format(key))
29 |     for key in model_state_dict:
30 |         if key not in state_dict:
31 |             if log:
32 |                 print('No param {}.'.format(key))
33 |             state_dict[key] = model_state_dict[key]
34 |     model.load_state_dict(state_dict, strict=False)
35 | 
36 |     return model
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Global-Local Memory for Real-time Instrument Segmentation of Robotic Surgical Video
2 | We propose, on the one hand, an efficient local memory that takes the complementary advantages of convolutional LSTM and non-local mechanisms with respect to the receptive field. On the other hand, we develop an active global memory that relates long-range global semantic correlation to the current frame, gathering the most informative frames as selected by model uncertainty and frame similarity.
3 | 
4 | This paper has been accepted at [MICCAI](https://link.springer.com/chapter/10.1007/978-3-030-87202-1_33).
5 | The full paper is available on [arXiv](https://arxiv.org/abs/2109.13593).
6 | 
7 | ![framework](./framework.jpg)
8 | Fig. 1. Structure of DMNet.
9 | 
10 | ## Message
11 | We have updated the code of Efficient LA and GA. As the active selection is used only at inference and is currently written in Jupyter, we will update this part later. -- by 10/12
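Until the active selection notebook is released, the DMNet model itself can already be built and run; the snippet below mirrors the self-test at the bottom of `net/Ours/DMNet.py` (an illustrative sketch rather than a full training command — the input clip packs the `global_n` memory frames first, followed by the `t` local frames):

```python
import torch
from net.Ours.DMNet import DMNet

# 11 classes, a local window of t=5 frames plus global_n=4 global memory frames,
# so each sample is a clip of 9 frames shaped (batch, t + global_n, 3, H, W).
net = DMNet(11, batch_size=8, tag='convlstm', group=1, t=5, global_n=4).cuda()
with torch.no_grad():
    out_segm = net(torch.randn(2, 9, 3, 512, 640).cuda())
```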
12 | 
13 | 
14 | ## Code List
15 | 
16 | - [x] Pre-processing
17 | - [x] Training Code
18 | - [x] Network
19 | 
20 | For more details or any questions, please feel free to contact us by email ^\_^
21 | 
22 | ## Usage
23 | 
24 | 
25 | ## Citation
26 | If you find DMNet useful in your research, please consider citing:
27 | 
28 | ```
29 | @inproceedings{wang2021efficient,
30 |   title={Efficient Global-Local Memory for Real-Time Instrument Segmentation of Robotic Surgical Video},
31 |   author={Wang, Jiacheng and Jin, Yueming and Wang, Liansheng and Cai, Shuntian and Heng, Pheng-Ann and Qin, Jing},
32 |   booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
33 |   pages={341--351},
34 |   year={2021},
35 |   organization={Springer}
36 | }
37 | ```
38 | 
39 | 
--------------------------------------------------------------------------------
/net/Ours/EffLA.py:
--------------------------------------------------------------------------------
1 | import torch, math
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import sys, time, os
5 | 
6 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
7 | from net.Ours.Module import *
8 | 
9 | 
10 | class EffLA(nn.Module):
11 |     def __init__(self, num_classes, tag):
12 |         super(EffLA, self).__init__()
13 |         self.encoder = MobileEncoder()
14 |         self.decoder = RefineDecoder(num_classes)
15 |         self.lstm = TimeProcesser(256, 256, (16, 20), 1, tag, 1)
16 |         self.memory = Memory(256)
17 | 
18 |     def forward(self, x):
19 |         tic = time.perf_counter()
20 | 
21 |         b, t, _, w, h = x.size()
22 | 
23 |         seq = []
24 |         for i in range(t):
25 |             tensor = self.encoder(x[:, i])
26 |             seq.append(tensor[-1].unsqueeze(1))
27 |         seq = torch.cat(seq, dim=1)
28 | 
29 |         temporal_output = self.lstm(seq)[:, -1:]
30 |         densest_output, p = self.memory(temporal_output, temporal_output[:, 0])
31 | 
32 |         tensor[-1] = densest_output
33 |         out_segm = self.decoder(tensor)
34 |         return out_segm
35 | 
36 |     def _initialize_weights(self):
37 |         for m in self.modules():
38 |             if isinstance(m, nn.Conv2d):
39 |                 m.weight.data.normal_(0, 0.01)
40 |                 if m.bias is not None:
41 |                     m.bias.data.zero_()
42 |             elif isinstance(m, nn.BatchNorm2d):
43 |                 m.weight.data.fill_(1)
44 |                 m.bias.data.zero_()
45 | 
46 | 
47 | if __name__ == '__main__':
48 |     import torch
49 |     net = EffLA(11, tag='convlstm').cuda()
50 | 
51 |     print('CALculate..')
52 |     with torch.no_grad():
53 |         y = net(torch.randn(2, 5, 3, 512, 640).cuda())
54 | 
--------------------------------------------------------------------------------
/src/optimisers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import densetorch as dt
4 | 
5 | from network import get_encoder_and_decoder_params
6 | 
7 | 
8 | def get_lr_schedulers(
9 |     enc_optim,
10 |     dec_optim,
11 |     enc_lr_gamma,
12 |     dec_lr_gamma,
13 |     enc_scheduler_type,
14 |     dec_scheduler_type,
15 |     epochs_per_stage,
16 | ):
17 |     milestones = np.cumsum(epochs_per_stage)
18 |     max_epochs = milestones[-1]
19 |     schedulers = [
20 |         dt.misc.create_scheduler(
21 |             scheduler_type=enc_scheduler_type,
22 |             optim=enc_optim,
23 |             gamma=enc_lr_gamma,
24 |             milestones=milestones,
25 |             max_epochs=max_epochs,
26 |         ),
27 |         dt.misc.create_scheduler(
28 |             scheduler_type=dec_scheduler_type,
29 |             optim=dec_optim,
30 |             gamma=dec_lr_gamma,
31 |             milestones=milestones,
32 |             max_epochs=max_epochs,
33 |         ),
34 |     ]
35 |     return schedulers
36 | 
37 | 
38 | def get_optimisers(
39 |     model,
40 |     enc_optim_type,
41 |     enc_lr,
42 |
enc_weight_decay, 43 | enc_momentum, 44 | dec_optim_type, 45 | dec_lr, 46 | dec_weight_decay, 47 | dec_momentum, 48 | ): 49 | enc_params, dec_params = get_encoder_and_decoder_params(model) 50 | optimisers = [ 51 | dt.misc.create_optim( 52 | optim_type=enc_optim_type, 53 | parameters=enc_params, 54 | lr=enc_lr, 55 | weight_decay=enc_weight_decay, 56 | momentum=enc_momentum, 57 | ), 58 | dt.misc.create_optim( 59 | optim_type=dec_optim_type, 60 | parameters=dec_params, 61 | lr=dec_lr, 62 | weight_decay=dec_weight_decay, 63 | momentum=dec_momentum, 64 | ), 65 | ] 66 | return optimisers 67 | -------------------------------------------------------------------------------- /utils/EndoLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import utils 5 | import numpy as np 6 | 7 | 8 | class LossBinary: 9 | """ 10 | Loss defined as \alpha BCE - (1 - \alpha) SoftJaccard 11 | """ 12 | 13 | def __init__(self, jaccard_weight=0): 14 | self.nll_loss = nn.BCEWithLogitsLoss() 15 | self.jaccard_weight = jaccard_weight 16 | 17 | def __call__(self, outputs, targets): 18 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets) 19 | 20 | if self.jaccard_weight: 21 | eps = 1e-15 22 | jaccard_target = (targets == 1).float() 23 | jaccard_output = F.sigmoid(outputs) 24 | 25 | intersection = (jaccard_output * jaccard_target).sum() 26 | union = jaccard_output.sum() + jaccard_target.sum() 27 | 28 | loss -= self.jaccard_weight * torch.log((intersection + eps) / (union - intersection + eps)) 29 | return loss 30 | 31 | 32 | class LossMulti: 33 | def __init__(self, jaccard_weight=0, class_weights=None, num_classes=1): 34 | if class_weights is not None: 35 | nll_weight = utils.cuda( 36 | torch.from_numpy(class_weights.astype(np.float32))) 37 | else: 38 | nll_weight = None 39 | self.nll_loss = nn.NLLLoss2d(weight=nll_weight) 40 | self.jaccard_weight = jaccard_weight 41 | self.num_classes = num_classes 42 | 43 | def __call__(self, outputs, targets): 44 | loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets) 45 | 46 | if self.jaccard_weight: 47 | eps = 1e-15 48 | for cls in range(self.num_classes): 49 | jaccard_target = (targets == cls).float() 50 | jaccard_output = outputs[:, cls].exp() 51 | intersection = (jaccard_output * jaccard_target).sum() 52 | 53 | union = jaccard_output.sum() + jaccard_target.sum() 54 | loss -= torch.log((intersection + eps) / (union - intersection + eps)) * self.jaccard_weight 55 | return loss -------------------------------------------------------------------------------- /net/Ours/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys, time, os 4 | 5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../')) 6 | 7 | # from net.utils.helpers import maybe_download 8 | # from net.utils.layer_factory import conv1x1, conv3x3, convbnrelu, CRPBlock 9 | from net.LSTM.torch_convlstm import ConvLSTM 10 | from net.LSTM.bottlenecklstm import BottleneckLSTM 11 | from net.LSTM.grouplstm import GroupLSTM 12 | from net.Ours.Module import * 13 | 14 | from net.Ours.lib.non_local_dot_product import NONLocalBlock2D 15 | 16 | 17 | class TemporalNet(nn.Module): 18 | def __init__(self, num_classes, batch_size, tag, group): 19 | super(TemporalNet, self).__init__() 20 | self.encoder = MobileEncoder() 21 | self.decoder = RefineDecoder(num_classes) 22 | self.lstm = 
TimeProcesser(256, 256, (16, 20), batch_size, tag, group) 23 | 24 | def forward(self, x): 25 | tic = time.perf_counter() 26 | b, t, _, w, h = x.size() # 27 | 28 | seq = [] 29 | 30 | for i in range(t): 31 | tensor = self.encoder(x[:, i]) 32 | seq.append(tensor[-1].unsqueeze(1)) 33 | tem = torch.cat(seq, dim=1) # b,t,c,w,h 34 | 35 | temporal_output = self.lstm(tem)[:, -1] 36 | 37 | tensor[-1] = temporal_output 38 | out_segm = self.decoder(tensor) 39 | return out_segm 40 | 41 | def _initialize_weights(self): 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | m.weight.data.normal_(0, 0.01) 45 | if m.bias is not None: 46 | m.bias.data.zero_() 47 | elif isinstance(m, nn.BatchNorm2d): 48 | m.weight.data.fill_(1) 49 | m.bias.data.zero_() 50 | 51 | 52 | if __name__ == '__main__': 53 | import torch 54 | net = TemporalNet(11, batch_size=8, tag='btnlstm', group=1).cuda() 55 | 56 | def hook(self, input, output): 57 | print(output.data.cpu().numpy().shape) 58 | 59 | print('CALculate..') 60 | with torch.no_grad(): 61 | y = net(torch.randn(2, 5, 3, 512, 640).cuda()) 62 | -------------------------------------------------------------------------------- /src/tests/test_networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import random 4 | import torch 5 | 6 | import densetorch as dt 7 | 8 | from network import get_segmenter, get_encoder_and_decoder_params 9 | 10 | 11 | NUMBER_OF_PARAMETERS_WITH_21_CLASSES = { 12 | "152": 61993301, 13 | "101": 46349653, 14 | "50": 27357525, 15 | "mbv2": 3284565, 16 | } 17 | 18 | NUMBER_OF_ENCODER_DECODER_LAYERS = { 19 | "152": (465, 28), 20 | "101": (312, 28), 21 | "50": (159, 28), 22 | "mbv2": (156, 27), 23 | } 24 | 25 | 26 | def get_dummy_input_tensor(height, width, channels=3, batch=4): 27 | input_tensor = torch.FloatTensor(batch, channels, height, width).float() 28 | return input_tensor 29 | 30 | 31 | def get_network_output_shape(h, w, output_stride=4): 32 | return np.ceil(h / output_stride), np.ceil(w / output_stride) 33 | 34 | 35 | @pytest.fixture() 36 | def num_classes(): 37 | return random.randint(1, 40) 38 | 39 | 40 | @pytest.fixture() 41 | def input_height(): 42 | return random.randint(33, 320) 43 | 44 | 45 | @pytest.fixture() 46 | def input_width(): 47 | return random.randint(33, 320) 48 | 49 | 50 | @pytest.mark.parametrize("enc_backbone", ["50", "101", "152", "mbv2"]) 51 | @pytest.mark.parametrize("enc_pretrained", [False, True]) 52 | def test_networks(enc_backbone, enc_pretrained, num_classes, input_height, input_width): 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | network = ( 55 | get_segmenter( 56 | enc_backbone=enc_backbone, 57 | enc_pretrained=enc_pretrained, 58 | num_classes=num_classes, 59 | ) 60 | .eval() 61 | .to(device) 62 | ) 63 | if num_classes == 21: 64 | assert ( 65 | dt.misc.compute_params(network) 66 | == NUMBER_OF_PARAMETERS_WITH_21_CLASSES[enc_backbone] 67 | ) 68 | 69 | enc_params, dec_params = get_encoder_and_decoder_params(network) 70 | n_enc_layers, n_dec_layers = NUMBER_OF_ENCODER_DECODER_LAYERS[enc_backbone] 71 | assert len(enc_params) == n_enc_layers 72 | assert len(dec_params) == n_dec_layers 73 | 74 | with torch.no_grad(): 75 | input_tensor = get_dummy_input_tensor( 76 | height=input_height, width=input_width 77 | ).to(device) 78 | output_h, output_w = get_network_output_shape(*input_tensor.shape[-2:]) 79 | output = network(input_tensor) 80 | assert output.size(1) == num_classes 81 | assert output.size(2) == 
output_h 82 | assert output.size(3) == output_w 83 | -------------------------------------------------------------------------------- /net/Ours/DMNet.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys, time, os 5 | 6 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../')) 7 | 8 | from net.Ours.Module import * 9 | from net.Ours.lib.non_local_dot_product import NONLocalBlock2D 10 | 11 | 12 | class DMNet(nn.Module): 13 | def __init__(self, 14 | num_classes, 15 | batch_size, 16 | tag, 17 | group, 18 | t, 19 | global_n, 20 | fusion_type='tandem'): 21 | super(DMNet, self).__init__() 22 | self.encoder = MobileEncoder() 23 | self.decoder = RefineDecoder(num_classes) 24 | self.lstm = TimeProcesser(256, 256, (16, 20), batch_size, tag, group) 25 | self.memory = Memory(256) 26 | self.t = t 27 | self.g = global_n 28 | self.ft = fusion_type 29 | 30 | def forward(self, x): 31 | tic = time.perf_counter() 32 | g = self.g 33 | t = self.t 34 | b, n, _, w, h = x.size() # 35 | assert self.g + self.t == n 36 | 37 | seq = [] 38 | 39 | for i in range(g): 40 | tensor = self.encoder(x[:, i]) 41 | seq.append(tensor[-1].unsqueeze(1)) 42 | global_mem = torch.cat(seq, dim=1) # b,g,c,w,h 43 | 44 | seq = [] 45 | for i in range(g, n): 46 | tensor = self.encoder(x[:, i]) 47 | seq.append(tensor[-1].unsqueeze(1)) 48 | local_mem = torch.cat(seq, dim=1) # b,g,c,w,h 49 | 50 | if self.ft == 'tandem': 51 | local_output = self.lstm(local_mem)[:, -1] 52 | final_output, gdst_p = self.memory(global_mem, local_output) 53 | else: 54 | local_output = self.lstm(local_mem)[:, -1] 55 | global_output, _ = self.memory(global_mem, local_mem[:, -1]) 56 | if self.ft == 'add': 57 | final_output = global_output + local_output 58 | else: 59 | raise NotImplementedError 60 | tensor[-1] = final_output 61 | out_segm = self.decoder(tensor) 62 | return out_segm 63 | 64 | def _initialize_weights(self): 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | m.weight.data.normal_(0, 0.01) 68 | if m.bias is not None: 69 | m.bias.data.zero_() 70 | elif isinstance(m, nn.BatchNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | 75 | if __name__ == '__main__': 76 | import torch 77 | net = DMNet(11, batch_size=8, tag='convlstm', group=1, t=5, 78 | global_n=4).cuda() 79 | 80 | print('CALculate..') 81 | with torch.no_grad(): 82 | y = net(torch.randn(2, 9, 3, 512, 640).cuda()) 83 | -------------------------------------------------------------------------------- /utils/pytorch_modelsize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | class SizeEstimator(object): 7 | 8 | def __init__(self, model, input_size=(1,1,32,32), bits=32): 9 | ''' 10 | Estimates the size of PyTorch models in memory 11 | for a given input size 12 | ''' 13 | self.model = model 14 | self.input_size = input_size 15 | self.bits = bits 16 | return 17 | 18 | def get_parameter_sizes(self): 19 | '''Get sizes of all parameters in `model`''' 20 | mods = list(self.model.modules()) 21 | sizes = [] 22 | 23 | for i in range(1,len(mods)): 24 | m = mods[i] 25 | p = list(m.parameters()) 26 | for j in range(len(p)): 27 | sizes.append(np.array(p[j].size())) 28 | 29 | self.param_sizes = sizes 30 | return 31 | 32 | def get_output_sizes(self): 33 | '''Run sample input through each layer 
to get output sizes''' 34 | input_ = Variable(torch.FloatTensor(*self.input_size), volatile=True) 35 | mods = list(self.model.modules()) 36 | out_sizes = [] 37 | for i in range(1, len(mods)): 38 | m = mods[i] 39 | out = m(input_) 40 | out_sizes.append(np.array(out.size())) 41 | input_ = out 42 | 43 | self.out_sizes = out_sizes 44 | return 45 | 46 | def calc_param_bits(self): 47 | '''Calculate total number of bits to store `model` parameters''' 48 | total_bits = 0 49 | for i in range(len(self.param_sizes)): 50 | s = self.param_sizes[i] 51 | bits = np.prod(np.array(s))*self.bits 52 | total_bits += bits 53 | self.param_bits = total_bits 54 | return 55 | 56 | def calc_forward_backward_bits(self): 57 | '''Calculate bits to store forward and backward pass''' 58 | total_bits = 0 59 | for i in range(len(self.out_sizes)): 60 | s = self.out_sizes[i] 61 | bits = np.prod(np.array(s))*self.bits 62 | total_bits += bits 63 | # multiply by 2 for both forward AND backward 64 | self.forward_backward_bits = (total_bits*2) 65 | return 66 | 67 | def calc_input_bits(self): 68 | '''Calculate bits to store input''' 69 | self.input_bits = np.prod(np.array(self.input_size))*self.bits 70 | return 71 | 72 | def estimate_size(self): 73 | '''Estimate model size in memory in megabytes and bits''' 74 | self.get_parameter_sizes() 75 | self.get_output_sizes() 76 | self.calc_param_bits() 77 | self.calc_forward_backward_bits() 78 | self.calc_input_bits() 79 | total = self.param_bits + self.forward_backward_bits + self.input_bits 80 | 81 | total_megabytes = (total/8)/(1024**2) 82 | return total_megabytes, total -------------------------------------------------------------------------------- /net/Ours/SpNet.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys, time, os 5 | 6 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../')) 7 | 8 | from net.utils.helpers import maybe_download 9 | from net.utils.layer_factory import conv1x1, conv3x3, convbnrelu, CRPBlock 10 | from net.LSTM.torch_convlstm import ConvLSTM 11 | from net.LSTM.bottlenecklstm import BottleneckLSTM 12 | from net.LSTM.grouplstm import GroupLSTM 13 | from net.Ours.Module import * 14 | 15 | 16 | class SPNet(nn.Module): 17 | def __init__(self, num_classes, global_n, spatial_layer): 18 | super(SPNet, self).__init__() 19 | c = 256 if spatial_layer == -1 else 96 20 | self.memory = Memory(c) 21 | self.global_n = global_n 22 | self.encoder = MobileEncoder() 23 | self.decoder = RefineDecoder(num_classes) 24 | self.spatial_layer = spatial_layer 25 | 26 | def forward(self, x): 27 | tic = time.perf_counter() 28 | 29 | b, t, _, w, h = x.size() 30 | 31 | seq = [] 32 | for i in range(t): 33 | tensor = self.encoder(x[:, i]) 34 | seq.append(tensor[self.spatial_layer].unsqueeze(1)) 35 | seq = torch.cat(seq, dim=1) 36 | 37 | global_context = seq[:, :-1] 38 | current_context = seq[:, -1] 39 | if self.global_n > 0: 40 | st_outputs, st_p = self.memory(global_context, current_context) 41 | else: 42 | st_outputs = current_context 43 | tensor[self.spatial_layer] = st_outputs 44 | out_segm = self.decoder(tensor) 45 | return out_segm 46 | 47 | def _initialize_weights(self): 48 | for m in self.modules(): 49 | if isinstance(m, nn.Conv2d): 50 | m.weight.data.normal_(0, 0.01) 51 | if m.bias is not None: 52 | m.bias.data.zero_() 53 | elif isinstance(m, nn.BatchNorm2d): 54 | m.weight.data.fill_(1) 55 | m.bias.data.zero_() 56 | 57 | 58 | def 
spnet(num_classes, imagenet=False, pretrained=True, **kwargs): 59 | """Constructs the network. 60 | 61 | Args: 62 | 63 | num_classes (int): the number of classes for the segmentation head to output. 64 | 65 | """ 66 | model = SPNet(num_classes, **kwargs) 67 | if imagenet: 68 | key = "mbv2_imagenet" 69 | url = models_urls[key] 70 | model.load_state_dict(maybe_download(key, url), strict=False) 71 | elif pretrained: 72 | dataset = data_info.get(num_classes, None) 73 | if dataset: 74 | bname = "mbv2_" + dataset.lower() 75 | key = "rf_lw" + bname 76 | url = models_urls[bname] 77 | model.load_state_dict(maybe_download(key, url), strict=False) 78 | return model 79 | 80 | 81 | if __name__ == '__main__': 82 | import torch 83 | net = spnet(11, imagenet=True, global_n=4, spatial_layer=-1).cuda() 84 | 85 | print('CALculate..') 86 | with torch.no_grad(): 87 | y = net(torch.randn(2, 5, 3, 512, 640).cuda()) 88 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | 10 | def make_one_hot(input, num_classes): 11 | """Convert class index tensor to one hot encoding tensor. 12 | Args: 13 | input: A tensor of shape [N, 1, *] 14 | num_classes: An int of number of class 15 | Returns: 16 | A tensor of shape [N, num_classes, *] 17 | """ 18 | shape = np.array(input.shape) 19 | shape[1] = num_classes 20 | shape = tuple(shape) 21 | result = torch.zeros(shape) 22 | result = result.scatter_(1, input.cpu(), 1) 23 | 24 | return result 25 | 26 | # class BCELoss(nn.Module): 27 | # def __init__(self): 28 | # super(BCELoss, self).__init__() 29 | # def forward(self, predict, target): 30 | # predict = F.softmax(predict,dim=1) 31 | # return F.binary_cross_entropy(predict, target) 32 | 33 | class BinaryCrossEntropyLoss(nn.Module): 34 | """Dice loss of binary class 35 | Args: 36 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 37 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 38 | predict: A tensor of shape [N, *] 39 | target: A tensor of shape same with predict 40 | reduction: Reduction method to apply, return mean over batch if 'mean', 41 | return sum if 'sum', return a tensor of shape [N,] if 'none' 42 | Returns: 43 | Loss tensor according to arg reduction 44 | Raise: 45 | Exception if unexpected reduction 46 | """ 47 | def __init__(self, smooth=1, p=2, reduction='mean'): 48 | super(BinaryCrossEntropyLoss, self).__init__() 49 | self.smooth = smooth 50 | self.p = p 51 | self.reduction = reduction 52 | 53 | def forward(self, predict, target): 54 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 55 | predict = predict.contiguous().view(predict.shape[0], -1) 56 | target = target.contiguous().view(target.shape[0], -1) 57 | return F.binary_cross_entropy(predict, target) 58 | 59 | class BCELoss(nn.Module): 60 | """Dice loss, need one hot encode input 61 | Args: 62 | weight: An array of shape [num_classes,] 63 | ignore_index: class index to ignore 64 | predict: A tensor of shape [N, C, *] 65 | target: A tensor of same shape with predict 66 | other args pass to BinaryDiceLoss 67 | Return: 68 | same as BinaryDiceLoss 69 | """ 70 | def __init__(self, weight=None, ignore_index=None, **kwargs): 71 | super(BCELoss, self).__init__() 72 | self.kwargs = kwargs 73 | self.weight = 
weight 74 | self.ignore_index = ignore_index 75 | 76 | def forward(self, predict, target): 77 | assert predict.shape == target.shape, 'predict & target shape do not match' 78 | dice = BinaryCrossEntropyLoss(**self.kwargs) 79 | total_loss = [] 80 | predict = F.softmax(predict, dim=1) 81 | avg_loss = 0 82 | for i in range(target.shape[1]): 83 | if i != self.ignore_index: 84 | loss = dice(predict[:, i], target[:, i]) 85 | if self.weight is not None: 86 | assert self.weight.shape[0] == target.shape[1], \ 87 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) 88 | loss *= self.weights[i] 89 | total_loss.append(loss) 90 | avg_loss += loss 91 | return avg_loss/target.shape[1] -------------------------------------------------------------------------------- /utils/summary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import logging 5 | from datetime import datetime 6 | 7 | # return a fake summarywriter if tensorbaordX is not installed 8 | 9 | try: 10 | from tensorboardX import SummaryWriter 11 | except ImportError: 12 | class SummaryWriter: 13 | def __init__(self, log_dir=None, comment='', **kwargs): 14 | print('\nunable to import tensorboardX, log will be recorded by pytorch!\n') 15 | self.log_dir = log_dir if log_dir is not None else './logs' 16 | os.makedirs('./logs', exist_ok=True) 17 | self.logs = {'comment': comment} 18 | return 19 | 20 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None): 21 | if tag in self.logs: 22 | self.logs[tag].append((scalar_value, global_step, walltime)) 23 | else: 24 | self.logs[tag] = [(scalar_value, global_step, walltime)] 25 | return 26 | 27 | def close(self): 28 | timestamp = str(datetime.now()).replace(' ', '_').replace(':', '_') 29 | torch.save(self.logs, os.path.join(self.log_dir, 'log_%s.pickle' % timestamp)) 30 | return 31 | 32 | 33 | class EmptySummaryWriter: 34 | def __init__(self, **kwargs): 35 | pass 36 | 37 | def add_scalar(self, tag, scalar_value, global_step=None, walltime=None): 38 | pass 39 | 40 | def close(self): 41 | pass 42 | 43 | 44 | def create_summary(distributed_rank=0, **kwargs): 45 | if distributed_rank > 0: 46 | return EmptySummaryWriter(**kwargs) 47 | else: 48 | return SummaryWriter(**kwargs) 49 | 50 | 51 | def create_logger(distributed_rank=0, save_dir=None): 52 | logger = logging.getLogger('logger') 53 | logger.setLevel(logging.DEBUG) 54 | 55 | filename = "log_%s.txt" % (datetime.now().strftime("%Y_%m_%d_%H_%M_%S")) 56 | 57 | # don't log results for the non-master process 58 | if distributed_rank > 0: 59 | return logger 60 | ch = logging.StreamHandler(stream=sys.stdout) 61 | ch.setLevel(logging.DEBUG) 62 | # formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 63 | formatter = logging.Formatter("%(message)s [%(asctime)s]") 64 | ch.setFormatter(formatter) 65 | logger.addHandler(ch) 66 | 67 | if save_dir is not None: 68 | fh = logging.FileHandler(os.path.join(save_dir, filename)) 69 | fh.setLevel(logging.DEBUG) 70 | fh.setFormatter(formatter) 71 | logger.addHandler(fh) 72 | 73 | return logger 74 | 75 | 76 | class Saver: 77 | def __init__(self, distributed_rank, save_dir): 78 | self.distributed_rank = distributed_rank 79 | self.save_dir = save_dir 80 | os.makedirs(self.save_dir, exist_ok=True) 81 | return 82 | 83 | def save(self, obj, save_name): 84 | if self.distributed_rank == 0: 85 | torch.save(obj, os.path.join(self.save_dir, save_name + '.t7')) 86 | return 
'checkpoint saved in %s !' % os.path.join(self.save_dir, save_name) 87 | else: 88 | return '' 89 | 90 | 91 | def create_saver(distributed_rank, save_dir): 92 | return Saver(distributed_rank, save_dir) 93 | 94 | 95 | class DisablePrint: 96 | def __init__(self, local_rank=0): 97 | self.local_rank = local_rank 98 | 99 | def __enter__(self): 100 | if self.local_rank != 0: 101 | self._original_stdout = sys.stdout 102 | sys.stdout = open(os.devnull, 'w') 103 | else: 104 | pass 105 | 106 | def __exit__(self, exc_type, exc_val, exc_tb): 107 | if self.local_rank != 0: 108 | sys.stdout.close() 109 | sys.stdout = self._original_stdout 110 | else: 111 | pass 112 | 113 | 114 | if __name__ == '__main__': 115 | sw = SummaryWriter() 116 | sw.close() 117 | -------------------------------------------------------------------------------- /net/utils/layer_factory.py: -------------------------------------------------------------------------------- 1 | """RefineNet-LightWeight-CRP Block 2 | 3 | RefineNet-LigthWeight PyTorch for non-commercial purposes 4 | 5 | Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au) 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | * Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | * Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | """ 29 | 30 | import torch.nn as nn 31 | 32 | 33 | def batchnorm(in_planes): 34 | "batch norm 2d" 35 | return nn.BatchNorm2d(in_planes, affine=True, eps=1e-5, momentum=0.1) 36 | 37 | 38 | def conv3x3(in_planes, out_planes, stride=1, bias=False): 39 | "3x3 convolution with padding" 40 | return nn.Conv2d( 41 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=bias 42 | ) 43 | 44 | 45 | def conv1x1(in_planes, out_planes, stride=1, bias=False): 46 | "1x1 convolution" 47 | return nn.Conv2d( 48 | in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=bias 49 | ) 50 | 51 | 52 | def convbnrelu(in_planes, out_planes, kernel_size, stride=1, groups=1, act=True): 53 | "conv-batchnorm-relu" 54 | if act: 55 | return nn.Sequential( 56 | nn.Conv2d( 57 | in_planes, 58 | out_planes, 59 | kernel_size, 60 | stride=stride, 61 | padding=int(kernel_size / 2.0), 62 | groups=groups, 63 | bias=False, 64 | ), 65 | batchnorm(out_planes), 66 | nn.ReLU6(inplace=True), 67 | ) 68 | else: 69 | return nn.Sequential( 70 | nn.Conv2d( 71 | in_planes, 72 | out_planes, 73 | kernel_size, 74 | stride=stride, 75 | padding=int(kernel_size / 2.0), 76 | groups=groups, 77 | bias=False, 78 | ), 79 | batchnorm(out_planes), 80 | ) 81 | 82 | class CRPBlock(nn.Module): 83 | def __init__(self, in_planes, out_planes, n_stages): 84 | super(CRPBlock, self).__init__() 85 | for i in range(n_stages): 86 | setattr( 87 | self, 88 | "{}_{}".format(i + 1, "outvar_dimred"), 89 | conv1x1( 90 | in_planes if (i == 0) else out_planes, 91 | out_planes, 92 | stride=1, 93 | bias=False, 94 | ), 95 | ) 96 | self.stride = 1 97 | self.n_stages = n_stages 98 | self.maxpool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) 99 | 100 | def forward(self, x): 101 | top = x 102 | for i in range(self.n_stages): 103 | top = self.maxpool(top) 104 | top = getattr(self, "{}_{}".format(i + 1, "outvar_dimred"))(top) 105 | x = top + x 106 | return x 107 | -------------------------------------------------------------------------------- /src/tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import random 4 | import torch 5 | 6 | from densetorch.misc import broadcast 7 | 8 | from data import get_transforms 9 | 10 | 11 | def get_dummy_image_and_mask(size=(512, 512)): 12 | image = np.random.randint(low=0, high=255, size=size + (3,)).astype(np.float32) 13 | mask = np.random.randint(low=0, high=15, size=size, dtype=np.uint8) 14 | return image, mask 15 | 16 | 17 | def pack_sample(image, mask, dataset_type): 18 | image = image.copy() 19 | mask = mask.copy() 20 | if dataset_type == "densetorch": 21 | sample = ({"image": image, "mask": mask, "names": ("mask",)},) 22 | elif dataset_type == "torchvision": 23 | sample = (image, mask) 24 | return sample 25 | 26 | 27 | def unpack_sample(sample, dataset_type): 28 | if dataset_type == "densetorch": 29 | image = sample["image"] 30 | mask = sample["mask"] 31 | elif dataset_type == "torchvision": 32 | image, mask = sample 33 | return image, mask 34 | 35 | 36 | @pytest.fixture() 37 | def num_stages(): 38 | return random.randint(1, 5) 39 | 40 | 41 | @pytest.fixture() 42 | def crop_size(): 43 | crop_size = random.randint(160, 960) 44 | if crop_size % 2 == 1: 45 | # NOTE: In DenseTorch, the crop is always even. 
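        # e.g. a sampled value of 531 is reduced to 530 before being used as the crop size.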
46 | crop_size -= 1 47 | return crop_size 48 | 49 | 50 | @pytest.fixture() 51 | def shorter_side(): 52 | return random.randint(160, 960) 53 | 54 | 55 | @pytest.fixture() 56 | def low_scale(): 57 | return random.random() 58 | 59 | 60 | @pytest.fixture() 61 | def high_scale(): 62 | return random.random() 63 | 64 | 65 | @pytest.mark.parametrize("augmentations_type", ["densetorch", "albumentations"]) 66 | @pytest.mark.parametrize("dataset_type", ["densetorch", "torchvision"]) 67 | def test_transforms( 68 | augmentations_type, 69 | crop_size, 70 | dataset_type, 71 | num_stages, 72 | shorter_side, 73 | low_scale, 74 | high_scale, 75 | img_mean=(0.5, 0.5, 0.5), 76 | img_std=(0.5, 0.5, 0.5), 77 | img_scale=1.0 / 255, 78 | ignore_label=255, 79 | ): 80 | train_transforms, val_transforms = get_transforms( 81 | crop_size=broadcast(crop_size, num_stages), 82 | shorter_side=broadcast(shorter_side, num_stages), 83 | low_scale=broadcast(low_scale, num_stages), 84 | high_scale=broadcast(high_scale, num_stages), 85 | img_mean=(0.5, 0.5, 0.5), 86 | img_std=(0.5, 0.5, 0.5), 87 | img_scale=1.0 / 255, 88 | ignore_label=255, 89 | num_stages=num_stages, 90 | augmentations_type=augmentations_type, 91 | dataset_type=dataset_type, 92 | ) 93 | assert len(train_transforms) == num_stages 94 | for is_val, transform in zip( 95 | [False] * num_stages + [True], train_transforms + [val_transforms] 96 | ): 97 | image, mask = get_dummy_image_and_mask() 98 | sample = pack_sample(image=image, mask=mask, dataset_type=dataset_type) 99 | output = transform(*sample) 100 | image_output, mask_output = unpack_sample( 101 | sample=output, dataset_type=dataset_type 102 | ) 103 | # Test shape 104 | if not is_val: 105 | assert ( 106 | image_output.shape[-2:] 107 | == mask_output.shape[-2:] 108 | == (crop_size, crop_size) 109 | ) 110 | # Test that the outputs are torch tensors 111 | assert isinstance(image_output, torch.Tensor) 112 | assert isinstance(mask_output, torch.Tensor) 113 | # Test that there are no new segmentation classes, except for probably ignore_label 114 | uq_classes_before = np.unique(mask) 115 | uq_classes_after = np.unique(mask_output.numpy()) 116 | assert ( 117 | len( 118 | np.setdiff1d( 119 | uq_classes_after, uq_classes_before.tolist() + [ignore_label] 120 | ) 121 | ) 122 | == 0 123 | ) 124 | if is_val: 125 | # Test that for validation transformation the output shape has not changed 126 | assert ( 127 | image_output.shape[-2:] 128 | == image.shape[:2] 129 | == mask_output.shape[-2:] 130 | == mask.shape[:2] 131 | ) 132 | # Test that there were no changes to the classes at all 133 | assert all(uq_classes_before == uq_classes_after) 134 | -------------------------------------------------------------------------------- /src/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from densetorch.misc import broadcast 4 | 5 | 6 | def get_arguments(): 7 | """Parse all the arguments provided from the CLI.""" 8 | parser = argparse.ArgumentParser( 9 | description="Arguments for Light-Weight-RefineNet Training Pipeline" 10 | ) 11 | 12 | # Common transformations 13 | parser.add_argument("--img-scale", type=float, default=1.0 / 255) 14 | parser.add_argument( 15 | "--img-mean", type=float, nargs=3, default=(0.485, 0.456, 0.406) 16 | ) 17 | parser.add_argument("--img-std", type=float, nargs=3, default=(0.229, 0.224, 0.225)) 18 | 19 | # Training augmentations 20 | parser.add_argument( 21 | "--augmentations-type", 22 | type=str, 23 | choices=["densetorch", 
"albumentations"], 24 | default="densetorch", 25 | ) 26 | 27 | # Dataset 28 | parser.add_argument( 29 | "--val-list-path", type=str, default="./data/val.nyu", 30 | ) 31 | parser.add_argument( 32 | "--val-dir", type=str, default="./datasets/nyud/", 33 | ) 34 | parser.add_argument("--val-batch-size", type=int, default=1) 35 | 36 | # Optimisation 37 | parser.add_argument("--random-seed", type=int, default=42) 38 | 39 | # Training / validation setup 40 | parser.add_argument( 41 | "--enc-backbone", type=str, choices=["50", "101", "152", "mbv2"], default="mbv2" 42 | ) 43 | parser.add_argument("--enc-pretrained", type=int, choices=[0, 1], default=1) 44 | parser.add_argument( 45 | "--num-stages", 46 | type=int, 47 | default=3, 48 | help="Number of training stages. All other arguments with nargs='+' must " 49 | "have the number of arguments equal to this value. Otherwise, the given " 50 | "arguments will be broadcasted to have the required length.", 51 | ) 52 | parser.add_argument("--num-classes", type=int, default=40) 53 | parser.add_argument( 54 | "--dataset-type", 55 | type=str, 56 | default="densetorch", 57 | choices=["densetorch", "torchvision"], 58 | ) 59 | parser.add_argument( 60 | "--val-download", 61 | type=int, 62 | choices=[0, 1], 63 | default=0, 64 | help="Only used if dataset_type == torchvision.", 65 | ) 66 | 67 | # Checkpointing configuration 68 | parser.add_argument("--ckpt-dir", type=str, default="./checkpoints/") 69 | parser.add_argument( 70 | "--ckpt-path", 71 | type=str, 72 | default="./checkpoints/checkpoint.pth.tar", 73 | help="Path to the checkpoint file.", 74 | ) 75 | 76 | # Arguments broadcastable across training stages 77 | stage_parser = parser.add_argument_group("stage-parser") 78 | stage_parser.add_argument( 79 | "--crop-size", type=int, nargs="+", default=(500, 500, 500,) 80 | ) 81 | stage_parser.add_argument( 82 | "--shorter-side", type=int, nargs="+", default=(350, 350, 350,) 83 | ) 84 | stage_parser.add_argument( 85 | "--low-scale", type=float, nargs="+", default=(0.5, 0.5, 0.5,) 86 | ) 87 | stage_parser.add_argument( 88 | "--high-scale", type=float, nargs="+", default=(2.0, 2.0, 2.0,) 89 | ) 90 | stage_parser.add_argument( 91 | "--train-list-path", type=str, nargs="+", default=("./data/train.nyu",) 92 | ) 93 | stage_parser.add_argument( 94 | "--train-dir", type=str, nargs="+", default=("./datasets/nyud/",) 95 | ) 96 | stage_parser.add_argument( 97 | "--train-batch-size", type=int, nargs="+", default=(6, 6, 6,) 98 | ) 99 | stage_parser.add_argument( 100 | "--freeze-bn", type=int, choices=[0, 1], nargs="+", default=(1, 1, 1,) 101 | ) 102 | stage_parser.add_argument( 103 | "--epochs-per-stage", type=int, nargs="+", default=(100, 100, 100), 104 | ) 105 | stage_parser.add_argument("--val-every", type=int, nargs="+", default=(5, 5, 5,)) 106 | stage_parser.add_argument( 107 | "--stage-names", 108 | type=str, 109 | nargs="+", 110 | choices=["SBD", "VOC"], 111 | default=("SBD", "VOC",), 112 | help="Only used if dataset_type == torchvision.", 113 | ) 114 | stage_parser.add_argument( 115 | "--train-download", 116 | type=int, 117 | nargs="+", 118 | choices=[0, 1], 119 | default=(0, 0,), 120 | help="Only used if dataset_type == torchvision.", 121 | ) 122 | stage_parser.add_argument( 123 | "--grad-norm", 124 | type=float, 125 | nargs="+", 126 | default=(0.0,), 127 | help="If > 0.0, clip gradients' norm to this value.", 128 | ) 129 | args = parser.parse_args() 130 | # Broadcast all arguments in stage-parser 131 | for group_action in stage_parser._group_actions: 132 | 
argument_name = group_action.dest 133 | setattr( 134 | args, 135 | argument_name, 136 | broadcast(getattr(args, argument_name), args.num_stages), 137 | ) 138 | return args 139 | -------------------------------------------------------------------------------- /net/LSTM/torch_convlstmv2.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | class ConvLSTMCell(nn.Module): 8 | def __init__(self, input_channels, hidden_channels, kernel_size): 9 | super(ConvLSTMCell, self).__init__() 10 | 11 | assert hidden_channels % 2 == 0 12 | 13 | self.input_channels = input_channels 14 | self.hidden_channels = hidden_channels 15 | self.kernel_size = kernel_size 16 | self.num_features = 4 17 | 18 | self.padding = int((kernel_size - 1) / 2) 19 | 20 | self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True) 21 | self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False) 22 | self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True) 23 | self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False) 24 | self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True) 25 | self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False) 26 | self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True) 27 | self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False) 28 | 29 | self.Wci = None 30 | self.Wcf = None 31 | self.Wco = None 32 | 33 | def forward(self, x, h, c): 34 | ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci) 35 | cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf) 36 | cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h)) 37 | co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc * self.Wco) 38 | ch = co * torch.tanh(cc) 39 | return ch, cc 40 | 41 | def init_hidden(self, batch_size, hidden, shape): 42 | if self.Wci is None: 43 | self.Wci = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda() 44 | self.Wcf = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda() 45 | self.Wco = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda() 46 | else: 47 | assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!' 48 | assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!' 49 | return (Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda(), 50 | Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda()) 51 | 52 | 53 | class ConvLSTM(nn.Module): 54 | # input_channels corresponds to the first input feature map 55 | # hidden state is a list of succeeding lstm layers. 
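    # A stack of ConvLSTMCells: cell i consumes the hidden state of cell i-1, so the
    # per-cell input sizes are [input_channels] + hidden_channels[:-1]. The same input
    # is fed for `step` iterations, and only the outputs of the steps listed in
    # `effective_step` are collected into the returned list.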
56 | def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=[1]): 57 | super(ConvLSTM, self).__init__() 58 | self.input_channels = [input_channels] + hidden_channels 59 | self.hidden_channels = hidden_channels 60 | self.kernel_size = kernel_size 61 | self.num_layers = len(hidden_channels) 62 | self.step = step 63 | self.effective_step = effective_step 64 | self._all_layers = [] 65 | for i in range(self.num_layers): 66 | name = 'cell{}'.format(i) 67 | cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size) 68 | setattr(self, name, cell) 69 | self._all_layers.append(cell) 70 | 71 | def forward(self, input): 72 | internal_state = [] 73 | outputs = [] 74 | for step in range(self.step): 75 | x = input 76 | for i in range(self.num_layers): 77 | # all cells are initialized in the first step 78 | name = 'cell{}'.format(i) 79 | if step == 0: 80 | bsize, _, height, width = x.size() 81 | (h, c) = getattr(self, name).init_hidden(batch_size=bsize, hidden=self.hidden_channels[i], 82 | shape=(height, width)) 83 | internal_state.append((h, c)) 84 | 85 | # do forward 86 | (h, c) = internal_state[i] 87 | x, new_c = getattr(self, name)(x, h, c) 88 | internal_state[i] = (x, new_c) 89 | # only record effective steps 90 | if step in self.effective_step: 91 | outputs.append(x) 92 | 93 | return outputs, (x, new_c) 94 | 95 | 96 | if __name__ == '__main__': 97 | # gradient check 98 | convlstm = ConvLSTM(input_channels=512, hidden_channels=[128, 64, 64, 32, 32], kernel_size=3, step=5, 99 | effective_step=[4]).cuda() 100 | loss_fn = torch.nn.MSELoss() 101 | 102 | input = Variable(torch.randn(1, 512, 64, 32)).cuda() 103 | target = Variable(torch.randn(1, 32, 64, 32)).double().cuda() 104 | 105 | output = convlstm(input) 106 | output = output[0][0].double() 107 | res = torch.autograd.gradcheck(loss_fn, (output, target), eps=1e-6, raise_exception=True) 108 | print(res) -------------------------------------------------------------------------------- /net/Ours/lib/non_local_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 
0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | if sub_sample: 54 | self.g = nn.Sequential(self.g, max_pool_layer) 55 | self.phi = max_pool_layer 56 | 57 | def forward(self, x, return_nl_map=False): 58 | """ 59 | :param x: (b, c, t, h, w) 60 | :param return_nl_map: if True return z, nl_map, else only return z. 61 | :return: 62 | """ 63 | 64 | batch_size = x.size(0) 65 | 66 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 67 | 68 | g_x = g_x.permute(0, 2, 1) 69 | 70 | theta_x = x.view(batch_size, self.in_channels, -1) 71 | theta_x = theta_x.permute(0, 2, 1) 72 | 73 | if self.sub_sample: 74 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1) 75 | else: 76 | phi_x = x.view(batch_size, self.in_channels, -1) 77 | 78 | f = torch.matmul(theta_x, phi_x) 79 | f_div_C = F.softmax(f, dim=-1) 80 | 81 | # if self.store_last_batch_nl_map: 82 | # self.nl_map = f_div_C 83 | 84 | y = torch.matmul(f_div_C, g_x) 85 | y = y.permute(0, 2, 1).contiguous() 86 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 87 | W_y = self.W(y) 88 | z = W_y + x 89 | 90 | if return_nl_map: 91 | return z, f_div_C 92 | return z 93 | 94 | 95 | class NONLocalBlock1D(_NonLocalBlockND): 96 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 97 | super(NONLocalBlock1D, self).__init__(in_channels, 98 | inter_channels=inter_channels, 99 | dimension=1, sub_sample=sub_sample, 100 | bn_layer=bn_layer) 101 | 102 | 103 | class NONLocalBlock2D(_NonLocalBlockND): 104 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 105 | super(NONLocalBlock2D, self).__init__(in_channels, 106 | inter_channels=inter_channels, 107 | dimension=2, sub_sample=sub_sample, 108 | bn_layer=bn_layer) 109 | 110 | 111 | class NONLocalBlock3D(_NonLocalBlockND): 112 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 113 | super(NONLocalBlock3D, self).__init__(in_channels, 114 | inter_channels=inter_channels, 115 | dimension=3, sub_sample=sub_sample, 116 | bn_layer=bn_layer) 117 | 118 | 119 | if __name__ == '__main__': 120 | import torch 121 | 122 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 123 | img = torch.zeros(2, 3, 20) 124 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 125 | out = net(img) 126 | print(out.size()) 127 | 128 | img = torch.zeros(2, 3, 20, 20) 129 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 130 | out = net(img) 131 | print(out.size()) 132 | 133 | img = torch.randn(2, 3, 8, 20, 20) 134 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 135 | out = net(img) 136 | print(out.size()) 137 | 138 | 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /net/Ours/lib/non_local_dot_product.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | 
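        # inter_channels is the bottleneck width used by the g / theta / phi
        # 1x1 projections; it defaults to in_channels // 2 just below.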
self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | 56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | 59 | if sub_sample: 60 | self.g = nn.Sequential(self.g, max_pool_layer) 61 | self.phi = nn.Sequential(self.phi, max_pool_layer) 62 | 63 | def forward(self, x, return_nl_map=False): 64 | """ 65 | :param x: (b, c, t, h, w) 66 | :param return_nl_map: if True return z, nl_map, else only return z. 
67 | :return: 68 | """ 69 | 70 | batch_size = x.size(0) 71 | 72 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 73 | g_x = g_x.permute(0, 2, 1) 74 | 75 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 76 | theta_x = theta_x.permute(0, 2, 1) 77 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 78 | f = torch.matmul(theta_x, phi_x) 79 | N = f.size(-1) 80 | f_div_C = f / N 81 | 82 | y = torch.matmul(f_div_C, g_x) 83 | y = y.permute(0, 2, 1).contiguous() 84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 85 | W_y = self.W(y) 86 | z = W_y + x 87 | 88 | if return_nl_map: 89 | return z, f_div_C 90 | return z 91 | 92 | 93 | class NONLocalBlock1D(_NonLocalBlockND): 94 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 95 | super(NONLocalBlock1D, self).__init__(in_channels, 96 | inter_channels=inter_channels, 97 | dimension=1, sub_sample=sub_sample, 98 | bn_layer=bn_layer) 99 | 100 | 101 | class NONLocalBlock2D(_NonLocalBlockND): 102 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 103 | super(NONLocalBlock2D, self).__init__(in_channels, 104 | inter_channels=inter_channels, 105 | dimension=2, sub_sample=sub_sample, 106 | bn_layer=bn_layer) 107 | 108 | 109 | class NONLocalBlock3D(_NonLocalBlockND): 110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 111 | super(NONLocalBlock3D, self).__init__(in_channels, 112 | inter_channels=inter_channels, 113 | dimension=3, sub_sample=sub_sample, 114 | bn_layer=bn_layer) 115 | 116 | 117 | if __name__ == '__main__': 118 | import torch 119 | 120 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 121 | img = torch.zeros(2, 3, 20) 122 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 123 | out = net(img) 124 | print(out.size()) 125 | 126 | img = torch.zeros(2, 3, 20, 20) 127 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 128 | out = net(img) 129 | print(out.size()) 130 | 131 | img = torch.randn(2, 3, 8, 20, 20) 132 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 133 | out = net(img) 134 | print(out.size()) 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /net/Ours/lib/non_local_embedded_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | """ 9 | :param in_channels: 10 | :param inter_channels: 11 | :param dimension: 12 | :param sub_sample: 13 | :param bn_layer: 14 | """ 15 | 16 | super(_NonLocalBlockND, self).__init__() 17 | 18 | assert dimension in [1, 2, 3] 19 | 20 | self.dimension = dimension 21 | self.sub_sample = sub_sample 22 | 23 | self.in_channels = in_channels 24 | self.inter_channels = inter_channels 25 | 26 | if self.inter_channels is None: 27 | self.inter_channels = in_channels // 2 28 | if self.inter_channels == 0: 29 | self.inter_channels = 1 30 | 31 | if dimension == 3: 32 | conv_nd = nn.Conv3d 33 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 34 | bn = nn.BatchNorm3d 35 | elif dimension == 2: 36 | conv_nd = nn.Conv2d 37 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 38 | bn = nn.BatchNorm2d 39 | else: 40 | conv_nd = nn.Conv1d 41 | 
max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 42 | bn = nn.BatchNorm1d 43 | 44 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 45 | kernel_size=1, stride=1, padding=0) 46 | 47 | if bn_layer: 48 | self.W = nn.Sequential( 49 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 50 | kernel_size=1, stride=1, padding=0), 51 | bn(self.in_channels) 52 | ) 53 | nn.init.constant_(self.W[1].weight, 0) 54 | nn.init.constant_(self.W[1].bias, 0) 55 | else: 56 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | nn.init.constant_(self.W.weight, 0) 59 | nn.init.constant_(self.W.bias, 0) 60 | 61 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 62 | kernel_size=1, stride=1, padding=0) 63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 64 | kernel_size=1, stride=1, padding=0) 65 | 66 | if sub_sample: 67 | self.g = nn.Sequential(self.g, max_pool_layer) 68 | self.phi = nn.Sequential(self.phi, max_pool_layer) 69 | 70 | def forward(self, x, return_nl_map=False): 71 | """ 72 | :param x: (b, c, t, h, w) 73 | :param return_nl_map: if True return z, nl_map, else only return z. 74 | :return: 75 | """ 76 | 77 | batch_size = x.size(0) 78 | 79 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 80 | g_x = g_x.permute(0, 2, 1) 81 | 82 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 83 | theta_x = theta_x.permute(0, 2, 1) 84 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 85 | f = torch.matmul(theta_x, phi_x) 86 | f_div_C = F.softmax(f, dim=-1) 87 | 88 | y = torch.matmul(f_div_C, g_x) 89 | y = y.permute(0, 2, 1).contiguous() 90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 91 | W_y = self.W(y) 92 | z = W_y + x 93 | 94 | if return_nl_map: 95 | return z, f_div_C 96 | return z 97 | 98 | 99 | class NONLocalBlock1D(_NonLocalBlockND): 100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 101 | super(NONLocalBlock1D, self).__init__(in_channels, 102 | inter_channels=inter_channels, 103 | dimension=1, sub_sample=sub_sample, 104 | bn_layer=bn_layer) 105 | 106 | 107 | class NONLocalBlock2D(_NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NONLocalBlock2D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=2, sub_sample=sub_sample, 112 | bn_layer=bn_layer,) 113 | 114 | 115 | class NONLocalBlock3D(_NonLocalBlockND): 116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 117 | super(NONLocalBlock3D, self).__init__(in_channels, 118 | inter_channels=inter_channels, 119 | dimension=3, sub_sample=sub_sample, 120 | bn_layer=bn_layer,) 121 | 122 | 123 | if __name__ == '__main__': 124 | import torch 125 | 126 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 127 | img = torch.zeros(2, 3, 20) 128 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 129 | out = net(img) 130 | print(out.size()) 131 | 132 | img = torch.zeros(2, 3, 20, 20) 133 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 134 | out = net(img) 135 | print(out.size()) 136 | 137 | img = torch.randn(2, 3, 8, 20, 20) 138 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 139 | out = net(img) 140 | print(out.size()) 141 | 142 | 143 | 
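The non-local blocks above share the same constructor and forward signature, so a
minimal usage sketch (the channel count and feature-map size are illustrative
assumptions, not values taken from this repository) looks like:

    import torch
    from net.Ours.lib.non_local_embedded_gaussian import NONLocalBlock2D

    feats = torch.randn(2, 256, 32, 40)        # assumed (B, C, H, W) backbone features
    block = NONLocalBlock2D(in_channels=256)   # inter_channels defaults to C // 2
    out = block(feats)                         # residual output, same shape as the input
    assert out.shape == feats.shape
    out, nl_map = block(feats, return_nl_map=True)  # optionally inspect the attention map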
-------------------------------------------------------------------------------- /net/LSTM/bottlenecklstm.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | import numpy as np 7 | class BottleneckLSTMCell(nn.Module): 8 | """ Creates a LSTM layer cell 9 | Arguments: 10 | input_channels : variable used to contain value of number of channels in input 11 | hidden_channels : variable used to contain value of number of channels in the hidden state of LSTM cell 12 | """ 13 | def __init__(self, input_channels, hidden_channels): 14 | super(BottleneckLSTMCell, self).__init__() 15 | 16 | assert hidden_channels % 2 == 0 17 | 18 | self.input_channels = int(input_channels) 19 | self.hidden_channels = int(hidden_channels) 20 | self.num_features = 4 21 | self.W = nn.Conv2d(in_channels=self.input_channels, out_channels=self.input_channels, kernel_size=3, groups=self.input_channels, stride=1, padding=1) 22 | self.Wy = nn.Conv2d(int(self.input_channels+self.hidden_channels), self.hidden_channels, kernel_size=1) 23 | self.Wi = nn.Conv2d(self.hidden_channels, self.hidden_channels, 3, 1, 1, groups=self.hidden_channels, bias=False) 24 | self.Wbi = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False) 25 | self.Wbf = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False) 26 | self.Wbc = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False) 27 | self.Wbo = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False) 28 | self.ac = nn.ReLU6() 29 | 30 | # self.Wci = None 31 | # self.Wcf = None 32 | # self.Wco = None 33 | # logging.info("Initializing weights of lstm") 34 | self._initialize_weights() 35 | 36 | def _initialize_weights(self): 37 | """ 38 | Returns: 39 | initialized weights of the model 40 | """ 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv2d): 43 | nn.init.xavier_uniform_(m.weight) 44 | if m.bias is not None: 45 | m.bias.data.zero_() 46 | elif isinstance(m, nn.BatchNorm2d): 47 | m.weight.data.fill_(1) 48 | m.bias.data.zero_() 49 | 50 | def forward(self, x, h, c): #implemented as mentioned in paper here the only difference is Wbi, Wbf, Wbc & Wbo are commuted all together in paper 51 | """ 52 | Arguments: 53 | x : input tensor 54 | h : hidden state tensor 55 | c : cell state tensor 56 | Returns: 57 | output tensor after LSTM cell 58 | """ 59 | x = self.W(x) 60 | y = torch.cat((x, h),1) #concatenate input and hidden layers 61 | i = self.Wy(y) #reduce to hidden layer size 62 | b = self.Wi(i) #depth wise 3*3 63 | ci = torch.sigmoid(self.Wbi(b)) 64 | cf = torch.sigmoid(self.Wbf(b)) 65 | cc = cf * c + ci * F.tanh(self.Wbc(b)) 66 | co = torch.sigmoid(self.Wbo(b)) 67 | ch = co * F.tanh(cc) 68 | return ch, cc 69 | 70 | def init_hidden(self, batch_size, hidden, shape): 71 | """ 72 | Arguments: 73 | batch_size : an int variable having value of batch size while training 74 | hidden : an int variable having value of number of channels in hidden state 75 | shape : an array containing shape of the hidden and cell state 76 | Returns: 77 | cell state and hidden state 78 | """ 79 | # if self.Wci is None: 80 | # self.Wci = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda() 81 | # self.Wcf = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda() 82 | # self.Wco = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda() 83 | # else: 84 | # assert 
shape[0] == self.Wci.size()[2], 'Input Height Mismatched!' 85 | # assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!' 86 | return (Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda(), 87 | Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda() 88 | ) 89 | 90 | class BottleneckLSTM(nn.Module): 91 | def __init__(self, input_channels, hidden_channels, height, width): 92 | """ Creates Bottleneck LSTM layer 93 | Arguments: 94 | input_channels : variable having value of number of channels of input to this layer 95 | hidden_channels : variable having value of number of channels of hidden state of this layer 96 | height : an int variable having value of height of the input 97 | width : an int variable having value of width of the input 98 | batch_size : an int variable having value of batch_size of the input 99 | Returns: 100 | Output tensor of LSTM layer 101 | """ 102 | super(BottleneckLSTM, self).__init__() 103 | self.input_channels = int(input_channels) 104 | self.hidden_channels = int(hidden_channels) 105 | self.cell = BottleneckLSTMCell(self.input_channels, self.hidden_channels) 106 | 107 | # self.hidden_state = h 108 | # self.cell_state = c 109 | # print(h.size(),c.size() 110 | 111 | def forward(self, input): 112 | # new_h, new_c = self.cell(input, self.hidden_state, self.cell_state) 113 | # self.hidden_state = new_h 114 | # self.cell_state = new_c 115 | batch_size,seq_len,_,height, width = input.size() 116 | h, c = self.cell.init_hidden(batch_size, hidden=self.hidden_channels, shape=(height, width)) 117 | # h, c = self.cell(input[:,0], self.hidden_state, self.cell_state) 118 | output_inner = [] 119 | for t in range(seq_len): 120 | h, c = self.cell(input[:, t, :, :, :], h, c) 121 | output_inner.append(h) 122 | layer_output = torch.stack(output_inner, dim=1) 123 | return layer_output -------------------------------------------------------------------------------- /net/Ours/lib/non_local_concatenation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, 
out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | 56 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | 59 | self.concat_project = nn.Sequential( 60 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False), 61 | nn.ReLU() 62 | ) 63 | 64 | if sub_sample: 65 | self.g = nn.Sequential(self.g, max_pool_layer) 66 | self.phi = nn.Sequential(self.phi, max_pool_layer) 67 | 68 | def forward(self, x, return_nl_map=False): 69 | ''' 70 | :param x: (b, c, t, h, w) 71 | :param return_nl_map: if True return z, nl_map, else only return z. 72 | :return: 73 | ''' 74 | 75 | batch_size = x.size(0) 76 | 77 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 78 | g_x = g_x.permute(0, 2, 1) 79 | 80 | # (b, c, N, 1) 81 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 82 | # (b, c, 1, N) 83 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 84 | 85 | h = theta_x.size(2) 86 | w = phi_x.size(3) 87 | theta_x = theta_x.repeat(1, 1, 1, w) 88 | phi_x = phi_x.repeat(1, 1, h, 1) 89 | 90 | concat_feature = torch.cat([theta_x, phi_x], dim=1) 91 | f = self.concat_project(concat_feature) 92 | b, _, h, w = f.size() 93 | f = f.view(b, h, w) 94 | 95 | N = f.size(-1) 96 | f_div_C = f / N 97 | 98 | y = torch.matmul(f_div_C, g_x) 99 | y = y.permute(0, 2, 1).contiguous() 100 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 101 | W_y = self.W(y) 102 | z = W_y + x 103 | 104 | if return_nl_map: 105 | return z, f_div_C 106 | return z 107 | 108 | 109 | class NONLocalBlock1D(_NonLocalBlockND): 110 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 111 | super(NONLocalBlock1D, self).__init__(in_channels, 112 | inter_channels=inter_channels, 113 | dimension=1, sub_sample=sub_sample, 114 | bn_layer=bn_layer) 115 | 116 | 117 | class NONLocalBlock2D(_NonLocalBlockND): 118 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 119 | super(NONLocalBlock2D, self).__init__(in_channels, 120 | inter_channels=inter_channels, 121 | dimension=2, sub_sample=sub_sample, 122 | bn_layer=bn_layer) 123 | 124 | 125 | class NONLocalBlock3D(_NonLocalBlockND): 126 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,): 127 | super(NONLocalBlock3D, self).__init__(in_channels, 128 | inter_channels=inter_channels, 129 | dimension=3, sub_sample=sub_sample, 130 | bn_layer=bn_layer) 131 | 132 | 133 | if __name__ == '__main__': 134 | import torch 135 | 136 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 137 | img = torch.zeros(2, 3, 20) 138 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 139 | out = net(img) 140 | print(out.size()) 141 | 142 | img = torch.zeros(2, 3, 20, 20) 143 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 144 | out = net(img) 145 | print(out.size()) 146 | 147 | img = torch.randn(2, 3, 8, 20, 20) 148 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 149 | out = net(img) 150 | print(out.size()) 151 | -------------------------------------------------------------------------------- /net/LSTM/torch_convlstm.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class ConvLSTMCell(nn.Module): 7 | 8 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias): 9 | """ 10 | Initialize ConvLSTM cell. 11 | Parameters 12 | ---------- 13 | input_size: (int, int) 14 | Height and width of input tensor as (height, width). 15 | input_dim: int 16 | Number of channels of input tensor. 17 | hidden_dim: int 18 | Number of channels of hidden state. 19 | kernel_size: (int, int) 20 | Size of the convolutional kernel. 21 | bias: bool 22 | Whether or not to add the bias. 23 | """ 24 | super(ConvLSTMCell, self).__init__() 25 | 26 | self.height, self.width = input_size 27 | self.input_dim = input_dim 28 | self.hidden_dim = hidden_dim 29 | 30 | self.kernel_size = kernel_size 31 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 32 | self.bias = bias 33 | 34 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 35 | out_channels=4 * self.hidden_dim, 36 | kernel_size=self.kernel_size, 37 | padding=self.padding, 38 | bias=self.bias) 39 | 40 | def forward(self, input, prev_state): 41 | h_prev, c_prev = prev_state 42 | combined = torch.cat((input, h_prev), dim=1) # concatenate along channel axis 43 | 44 | combined_conv = self.conv(combined) 45 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 46 | 47 | i = F.sigmoid(cc_i) 48 | f = F.sigmoid(cc_f) 49 | o = F.sigmoid(cc_o) 50 | g = F.tanh(cc_g) 51 | 52 | c_cur = f * c_prev + i * g 53 | h_cur = o * F.tanh(c_cur) 54 | 55 | return h_cur, c_cur 56 | 57 | def init_hidden(self, batch_size, cuda=True): 58 | state = (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)), 59 | Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width))) 60 | if cuda: 61 | state = (state[0].cuda(), state[1].cuda()) 62 | return state 63 | 64 | class ConvLSTM(nn.Module): 65 | 66 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers, 67 | batch_first=False, bias=True, return_all_layers=False): 68 | super(ConvLSTM, self).__init__() 69 | 70 | self._check_kernel_size_consistency(kernel_size) 71 | 72 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 73 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 74 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 75 | if not len(kernel_size) == len(hidden_dim) == num_layers: 76 | raise ValueError('Inconsistent list length.') 77 | 78 | self.height, self.width = input_size 79 | 80 | self.input_dim = input_dim 81 | self.hidden_dim = hidden_dim 82 | self.kernel_size = kernel_size 83 | self.num_layers = num_layers 84 | self.batch_first = batch_first 85 | self.bias = bias 86 | self.return_all_layers = return_all_layers 87 | 88 | cell_list = [] 89 | for i in range(0, self.num_layers): 90 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1] 91 | 92 | cell_list.append(ConvLSTMCell(input_size=(self.height, self.width), 93 | input_dim=cur_input_dim, 94 | hidden_dim=self.hidden_dim[i], 95 | kernel_size=self.kernel_size[i], 96 | bias=self.bias)) 97 | 98 | self.cell_list = nn.ModuleList(cell_list) 99 | 100 | def forward(self, input, hidden_state=None): 101 | """ 102 | Parameters 103 | ---------- 104 | input_tensor: todo 105 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 106 | hidden_state: todo 
107 | None. todo implement stateful 108 | Returns 109 | ------- 110 | last_state_list, layer_output 111 | """ 112 | if not self.batch_first: 113 | # (t, b, c, h, w) -> (b, t, c, h, w) 114 | input = input.permute(1, 0, 2, 3, 4) 115 | 116 | 117 | if hidden_state is None: 118 | hidden_state = self.get_init_states(batch_size=input.size(0)) 119 | 120 | layer_output_list = [] 121 | last_state_list = [] 122 | 123 | seq_len = input.size(1) 124 | cur_layer_input = input 125 | 126 | for layer_idx in range(self.num_layers): 127 | h, c = hidden_state[layer_idx] 128 | output_inner = [] 129 | for t in range(seq_len): 130 | h, c = self.cell_list[layer_idx](input=cur_layer_input[:, t, :, :, :], 131 | prev_state=[h, c]) 132 | output_inner.append(h) 133 | 134 | layer_output = torch.stack(output_inner, dim=1) 135 | # # print(layer_output.size()) 136 | # return layer_output 137 | cur_layer_input = layer_output 138 | 139 | layer_output_list.append(layer_output) 140 | last_state_list.append((h, c)) 141 | 142 | 143 | layer_output = torch.cat(layer_output_list,dim=1) 144 | if not self.batch_first: 145 | layer_output = layer_output.permute(1, 0, 2, 3, 4) 146 | 147 | return layer_output 148 | 149 | def get_init_states(self, batch_size, cuda=True): 150 | init_states = [] 151 | for i in range(self.num_layers): 152 | init_states.append(self.cell_list[i].init_hidden(batch_size, cuda)) 153 | return init_states 154 | 155 | @staticmethod 156 | def _check_kernel_size_consistency(kernel_size): 157 | if not (isinstance(kernel_size, tuple) or 158 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 159 | raise ValueError('`kernel_size` must be tuple or list of tuples') 160 | 161 | @staticmethod 162 | def _extend_for_multilayer(param, num_layers): 163 | if not isinstance(param, list): 164 | param = [param] * num_layers 165 | return param -------------------------------------------------------------------------------- /net/LSTM/grouplstm.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | import numpy as np 7 | class GroupLSTMCell(nn.Module): 8 | """ Creates a LSTM layer cell 9 | Arguments: 10 | input_channels : variable used to contain value of number of channels in input 11 | hidden_channels : variable used to contain value of number of channels in the hidden state of LSTM cell 12 | """ 13 | def __init__(self, input_channels, hidden_channels, ac='tanh'): 14 | super(GroupLSTMCell, self).__init__() 15 | 16 | assert hidden_channels % 2 == 0 17 | 18 | self.input_channels = int(input_channels) 19 | self.hidden_channels = int(hidden_channels) 20 | self.num_features = 4 21 | self.W = nn.Conv2d(in_channels=self.input_channels, out_channels=self.input_channels, kernel_size=3, groups=self.input_channels, stride=1, padding=1) 22 | self.Wy = nn.Conv2d(int(self.input_channels+self.hidden_channels), self.hidden_channels, kernel_size=1) 23 | self.Wi = nn.Conv2d(self.hidden_channels, self.hidden_channels, 3, 1, 1, groups=self.hidden_channels, bias=False) 24 | self.Wbi = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False) 25 | self.Wbf = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False) 26 | self.Wbc = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False) 27 | self.Wbo = nn.Conv2d(self.hidden_channels, self.hidden_channels, 1, 1, 0, bias=False) 28 | if ac=='tanh': 
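            # NOTE: F.tanh is deprecated on recent PyTorch releases; torch.tanh is a
            # drop-in replacement for this activation.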
29 | self._activation = F.tanh 30 | else: 31 | self._activation = F.relu 32 | 33 | # self.Wci = None 34 | # self.Wcf = None 35 | # self.Wco = None 36 | # logging.info("Initializing weights of lstm") 37 | self._initialize_weights() 38 | 39 | def _initialize_weights(self): 40 | """ 41 | Returns: 42 | initialized weights of the model 43 | """ 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | nn.init.xavier_uniform_(m.weight) 47 | if m.bias is not None: 48 | m.bias.data.zero_() 49 | elif isinstance(m, nn.BatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | 53 | def forward(self, x, h, c): #implemented as mentioned in paper here the only difference is Wbi, Wbf, Wbc & Wbo are commuted all together in paper 54 | """ 55 | Arguments: 56 | x : input tensor 57 | h : hidden state tensor 58 | c : cell state tensor 59 | Returns: 60 | output tensor after LSTM cell 61 | """ 62 | x = self.W(x) 63 | y = torch.cat((x, h),1) #concatenate input and hidden layers 64 | i = self.Wy(y) #reduce to hidden layer size 65 | b = self.Wi(i) #depth wise 3*3 66 | ci = torch.sigmoid(self.Wbi(b)) 67 | cf = torch.sigmoid(self.Wbf(b)) 68 | cc = cf * c + ci * self._activation(self.Wbc(b)) 69 | co = torch.sigmoid(self.Wbo(b)) 70 | ch = co * self._activation(cc) 71 | o = ch+cc 72 | return o, ch, cc 73 | 74 | def init_hidden(self, batch_size, hidden, shape): 75 | """ 76 | Arguments: 77 | batch_size : an int variable having value of batch size while training 78 | hidden : an int variable having value of number of channels in hidden state 79 | shape : an array containing shape of the hidden and cell state 80 | Returns: 81 | cell state and hidden state 82 | """ 83 | # if self.Wci is None: 84 | # self.Wci = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda() 85 | # self.Wcf = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda() 86 | # self.Wco = Variable(torch.zeros(1, hidden, shape[0], shape[1])).cuda() 87 | # else: 88 | # assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!' 89 | # assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!' 
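        # The fresh zero states below are allocated directly on the GPU; on
        # PyTorch >= 0.4 `Variable` is a no-op wrapper, so
        # torch.zeros(..., device='cuda') is the modern equivalent.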
90 | return (Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda(), 91 | Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda() 92 | ) 93 | 94 | class GroupLSTM(nn.Module): 95 | def __init__(self, input_channels, hidden_channels, height, width, group): 96 | """ Creates Bottleneck LSTM layer 97 | Arguments: 98 | input_channels : variable having value of number of channels of input to this layer 99 | hidden_channels : variable having value of number of channels of hidden state of this layer 100 | height : an int variable having value of height of the input 101 | width : an int variable having value of width of the input 102 | batch_size : an int variable having value of batch_size of the input 103 | Returns: 104 | Output tensor of LSTM layer 105 | """ 106 | super(GroupLSTM, self).__init__() 107 | assert input_channels%group==0 and hidden_channels%group==0 108 | self.input_channels = int(input_channels) 109 | self.hidden_channels = int(hidden_channels) 110 | self.group = group 111 | self.cell = GroupLSTMCell(self.input_channels//self.group, self.hidden_channels//self.group) 112 | 113 | # self.hidden_state = h 114 | # self.cell_state = c 115 | # print(h.size(),c.size() 116 | 117 | def forward(self, input): 118 | # new_h, new_c = self.cell(input, self.hidden_state, self.cell_state) 119 | # self.hidden_state = new_h 120 | # self.cell_state = new_c 121 | batch_size,seq_len,channels,height, width = input.size() 122 | split_inputs = torch.split(input, channels//self.group, dim=2) 123 | outputs = [] 124 | for split_input in split_inputs: 125 | h, c = self.cell.init_hidden(batch_size, hidden=self.hidden_channels//self.group, shape=(height, width)) 126 | # h, c = self.cell(input[:,0], self.hidden_state, self.cell_state) 127 | output_inner = [] 128 | for t in range(seq_len): 129 | o, h, c = self.cell(split_input[:, t, :, :, :], h, c) 130 | output_inner.append(o) 131 | layer_output = torch.stack(output_inner, dim=1) 132 | outputs.append(layer_output) 133 | outputs = torch.cat(outputs,dim=1) 134 | return layer_output -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | # general libs 2 | import logging 3 | import numpy as np 4 | 5 | # pytorch libs 6 | import torch 7 | import torch.nn as nn 8 | 9 | # densetorch wrapper 10 | import densetorch as dt 11 | 12 | # configuration for light-weight refinenet 13 | from arguments import get_arguments 14 | from data import get_datasets, get_transforms 15 | from network import get_segmenter 16 | from optimisers import get_optimisers, get_lr_schedulers 17 | 18 | 19 | def setup_network(args, device): 20 | logger = logging.getLogger(__name__) 21 | segmenter = get_segmenter( 22 | enc_backbone=args.enc_backbone, 23 | enc_pretrained=args.enc_pretrained, 24 | num_classes=args.num_classes, 25 | ).to(device) 26 | if device == "cuda": 27 | segmenter = nn.DataParallel(segmenter) 28 | logger.info( 29 | " Loaded Segmenter {}, ImageNet-Pre-Trained={}, #PARAMS={:3.2f}M".format( 30 | args.enc_backbone, 31 | args.enc_pretrained, 32 | dt.misc.compute_params(segmenter) / 1e6, 33 | ) 34 | ) 35 | training_loss = nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device) 36 | validation_loss = dt.engine.MeanIoU(num_classes=args.num_classes) 37 | return segmenter, training_loss, validation_loss 38 | 39 | 40 | def setup_checkpoint_and_maybe_restore(args, model, optimisers, schedulers): 41 | saver = dt.misc.Saver( 42 | 
args=vars(args), 43 | ckpt_dir=args.ckpt_dir, 44 | best_val=0, 45 | condition=lambda x, y: x > y, 46 | ) # keep checkpoint with the best validation score 47 | ( 48 | epoch_start, 49 | _, 50 | model_state_dict, 51 | optims_state_dict, 52 | scheds_state_dict, 53 | ) = saver.maybe_load( 54 | ckpt_path=args.ckpt_path, 55 | keys_to_load=["epoch", "best_val", "model", "optimisers", "schedulers"], 56 | ) 57 | if epoch_start is None: 58 | epoch_start = 0 59 | dt.misc.load_state_dict(model, model_state_dict) 60 | if optims_state_dict is not None: 61 | for optim, optim_state_dict in zip(optimisers, optims_state_dict): 62 | optim.load_state_dict(optim_state_dict) 63 | if scheds_state_dict is not None: 64 | for sched, sched_state_dict in zip(schedulers, scheds_state_dict): 65 | sched.load_state_dict(sched_state_dict) 66 | return saver, epoch_start 67 | 68 | 69 | def setup_data_loaders(args): 70 | train_transforms, val_transforms = get_transforms( 71 | crop_size=args.crop_size, 72 | shorter_side=args.shorter_side, 73 | low_scale=args.low_scale, 74 | high_scale=args.high_scale, 75 | img_mean=args.img_mean, 76 | img_std=args.img_std, 77 | img_scale=args.img_scale, 78 | ignore_label=args.ignore_label, 79 | num_stages=args.num_stages, 80 | augmentations_type=args.augmentations_type, 81 | dataset_type=args.dataset_type, 82 | ) 83 | train_sets, val_set = get_datasets( 84 | train_dir=args.train_dir, 85 | val_dir=args.val_dir, 86 | train_list_path=args.train_list_path, 87 | val_list_path=args.val_list_path, 88 | train_transforms=train_transforms, 89 | val_transforms=val_transforms, 90 | masks_names=("segm",), 91 | dataset_type=args.dataset_type, 92 | stage_names=args.stage_names, 93 | train_download=args.train_download, 94 | val_download=args.val_download, 95 | ) 96 | train_loaders, val_loader = dt.data.get_loaders( 97 | train_batch_size=args.train_batch_size, 98 | val_batch_size=args.val_batch_size, 99 | train_set=train_sets, 100 | val_set=val_set, 101 | num_stages=args.num_stages, 102 | ) 103 | return train_loaders, val_loader 104 | 105 | 106 | def setup_optimisers_and_schedulers(args, model): 107 | optimisers = get_optimisers( 108 | model=model, 109 | enc_optim_type=args.enc_optim_type, 110 | enc_lr=args.enc_lr, 111 | enc_weight_decay=args.enc_weight_decay, 112 | enc_momentum=args.enc_momentum, 113 | dec_optim_type=args.dec_optim_type, 114 | dec_lr=args.dec_lr, 115 | dec_weight_decay=args.dec_weight_decay, 116 | dec_momentum=args.dec_momentum, 117 | ) 118 | schedulers = get_lr_schedulers( 119 | enc_optim=optimisers[0], 120 | dec_optim=optimisers[1], 121 | enc_lr_gamma=args.enc_lr_gamma, 122 | dec_lr_gamma=args.dec_lr_gamma, 123 | enc_scheduler_type=args.enc_scheduler_type, 124 | dec_scheduler_type=args.dec_scheduler_type, 125 | epochs_per_stage=args.epochs_per_stage, 126 | ) 127 | return optimisers, schedulers 128 | 129 | 130 | def main(): 131 | args = get_arguments() 132 | logger = logging.getLogger(__name__) 133 | torch.backends.cudnn.deterministic = True 134 | dt.misc.set_seed(args.random_seed) 135 | device = "cuda" if torch.cuda.is_available() else "cpu" 136 | # Network 137 | segmenter, training_loss, validation_loss = setup_network(args, device=device) 138 | # Data 139 | train_loaders, val_loader = setup_data_loaders(args) 140 | # Optimisers 141 | optimisers, schedulers = setup_optimisers_and_schedulers(args, model=segmenter) 142 | # Checkpoint 143 | saver, restart_epoch = setup_checkpoint_and_maybe_restore( 144 | args, model=segmenter, optimisers=optimisers, schedulers=schedulers, 145 | ) 146 | # 
Calculate from which stage and which epoch to restart the training 147 | total_epoch = restart_epoch 148 | all_epochs = np.cumsum(args.epochs_per_stage) 149 | restart_stage = sum(restart_epoch >= all_epochs) 150 | if restart_stage > 0: 151 | restart_epoch -= all_epochs[restart_stage - 1] 152 | for stage in range(restart_stage, args.num_stages): 153 | if stage > restart_stage: 154 | restart_epoch = 0 155 | for epoch in range(restart_epoch, args.epochs_per_stage[stage]): 156 | logger.info(f"Training: stage {stage} epoch {epoch}") 157 | dt.engine.train( 158 | model=segmenter, 159 | opts=optimisers, 160 | crits=training_loss, 161 | dataloader=train_loaders[stage], 162 | freeze_bn=args.freeze_bn[stage], 163 | grad_norm=args.grad_norm[stage], 164 | ) 165 | total_epoch += 1 166 | for scheduler in schedulers: 167 | scheduler.step(total_epoch) 168 | if (epoch + 1) % args.val_every[stage] == 0: 169 | logger.info(f"Validation: stage {stage} epoch {epoch}") 170 | vals = dt.engine.validate( 171 | model=segmenter, metrics=validation_loss, dataloader=val_loader, 172 | ) 173 | saver.maybe_save( 174 | new_val=vals, 175 | dict_to_save={ 176 | "model": segmenter.state_dict(), 177 | "epoch": total_epoch, 178 | "optimisers": [ 179 | optimiser.state_dict() for optimiser in optimisers 180 | ], 181 | "schedulers": [ 182 | scheduler.state_dict() for scheduler in schedulers 183 | ], 184 | }, 185 | ) 186 | 187 | 188 | if __name__ == "__main__": 189 | logging.basicConfig( 190 | format="%(asctime)s :: %(levelname)s :: %(name)s :: %(message)s", 191 | level=logging.INFO, 192 | ) 193 | main() 194 | -------------------------------------------------------------------------------- /dataset/Endovis2017.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import cv2 4 | import json 5 | import math 6 | import time 7 | import numpy as np 8 | import sys,random 9 | from PIL import Image 10 | sys.path.append('/raid/wjc/code/RealtimeSegmentation/') 11 | 12 | import torch 13 | import torch.utils.data as data 14 | import torch.nn.functional as F 15 | 16 | # from utils.image import get_border, get_affine_transform, affine_transform, color_aug 17 | 18 | import torchvision.transforms.functional as TF 19 | from torchvision import transforms 20 | MEAN = [0.40789654, 0.44719302, 0.47026115] 21 | STD = [0.28863828, 0.27408164, 0.27809835] 22 | 23 | # Folds = [[1,3],[2,6],[4,8],[5,7]] 24 | Folds = [[1,3],[2,5],[4,8],[6,7]] 25 | ins_types = ['Bipolar_Forceps','Prograsp_Forcep','Large_Needle_Driver','Vessel_Sealer', 'Grasping_Retractor', 26 | 'Monopolar_Curved_Scissors','Other'] 27 | 28 | 29 | class endovis2017(data.Dataset): 30 | def __init__(self, split, t=1, fold=0, rate=4, tag='part', global_n=0, test=False): 31 | super(endovis2017, self).__init__() 32 | self.split = split 33 | self.mean = np.array(MEAN, dtype=np.float32)[None, None, :] 34 | self.std = np.array(STD, dtype=np.float32)[None, None, :] 35 | self.img_size = {'h': 512, 'w': 640} 36 | self.t = t 37 | self.tag = tag 38 | self.test = test 39 | self.class_num = 7 if tag=='type' else 4 40 | # /raid/wjc/data/ead/endovis2017/training/instrument_dataset_*/frame.png 41 | # 8 * 225 frames 42 | # 0 for valid and the last 25 of the rest are used for valid 43 | 44 | if test: 45 | self.images = [[j,i] for j in range(1,9) for i in range(225,300)] + [[j,i] for j in range(9,11) for i in range(300)] 46 | else: 47 | self.images = [] 48 | train_images = [] 49 | valid_images = [] 50 | for f in range(4): 51 | if f==fold: 52 | 
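                    # cross-validation split: the instruments listed in Folds[fold]
                    # are held out for validation; the other three folds form the
                    # training split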
valid_images += [[j,i] for j in Folds[f] for i in range(225)] 53 | else: 54 | train_images += [[j,i] for j in Folds[f] for i in range(225)] 55 | self.images = train_images if self.split=='train' else valid_images 56 | print('Loaded {}frames'.format(len(self.images))) 57 | self.num_samples = len(self.images) 58 | self.rate = rate 59 | self.global_n = global_n 60 | 61 | 62 | def load_data(self, ins,frame,t=1,global_n=0): 63 | image = [] 64 | if global_n: 65 | global_images_index = (np.random.rand(global_n)*225).astype('int') 66 | image += [np.load('/raid/wjc/data/ead/endovis2017/training/instrument_dataset_{}/processed_v1/image{:03d}.npy'\ 67 | .format(ins,ind)) for ind in global_images_index] 68 | if t>frame: 69 | image += list([np.load('/raid/wjc/data/ead/endovis2017/training/instrument_dataset_{}/processed_v1/image{:03d}.npy'\ 70 | .format(ins,i)) for i in range(frame+t-1,frame-1,-1)]) 71 | else: 72 | image += list([np.load('/raid/wjc/data/ead/endovis2017/training/instrument_dataset_{}/processed_v1/image{:03d}.npy'\ 73 | .format(ins,i)) for i in range(frame-t+1,frame+1)]) 74 | label = np.load('/raid/wjc/data/ead/endovis2017/training/instrument_dataset_{}/processed_v1/{}{:03d}.npy'.format(ins,self.tag,frame)) 75 | return image, label 76 | 77 | def transform(self, images, masks): 78 | # Resize 79 | scale = random.random()*0.4+1 80 | resize = transforms.Resize(size=(int(512*scale), int(640*scale))) 81 | 82 | 83 | # image = resize(image) 84 | # mask = resize(mask) 85 | images = list([resize(image) for image in images]) 86 | masks = list([resize(mask) for mask in masks]) 87 | 88 | 89 | # Random crop 90 | i, j, h, w = transforms.RandomCrop.get_params( 91 | images[0], output_size=(512, 640)) 92 | # image = TF.crop(image, i, j, h, w) 93 | # mask = resize(mask) 94 | images = list([TF.crop(image, i, j, h, w) for image in images]) 95 | masks = list([TF.crop(mask, i, j, h, w) for mask in masks]) 96 | 97 | # Random horizontal flipping 98 | if random.random() > 0.5: 99 | # image = TF.hflip(image) 100 | images = list([TF.hflip(image) for image in images]) 101 | masks = list([TF.hflip(mask) for mask in masks]) 102 | 103 | # Random vertical flipping 104 | if random.random() > 0.5: 105 | # image = TF.vflip(image) 106 | images = list([TF.vflip(image) for image in images]) 107 | masks = list([TF.vflip(mask) for mask in masks]) 108 | 109 | return images, masks 110 | 111 | def __getitem__(self, index): 112 | 113 | ins,frame = self.images[index] 114 | # print(img_path) 115 | # st = time.perf_counter() 116 | imgs, label = self.load_data(ins, frame, self.t, global_n=self.global_n) 117 | # print('Load data:',time.perf_counter()-st) 118 | # st = time.perf_counter() 119 | 120 | label = (label/30+0.5).astype('int') # w * h 121 | masks = [] 122 | 123 | if self.split=='train': 124 | # img = Image.fromarray(np.uint8(img)) 125 | imgs = [Image.fromarray(np.uint8(img)) for img in imgs] 126 | classes = np.unique(label) 127 | masks = [] 128 | for cls in classes: 129 | if cls: 130 | masks.append(Image.fromarray(np.uint8(label==cls))) 131 | imgs,masks = self.transform(imgs,masks) 132 | imgs = [np.asarray(img) for img in imgs] 133 | label = np.zeros((imgs[0].shape[0],imgs[0].shape[1])) 134 | for i in range(1, len(classes)): 135 | mask = np.asarray(masks[i-1]) 136 | label[mask>0] = classes[i] 137 | # print('transform data:',time.perf_counter()-st) 138 | # st = time.perf_counter() 139 | imgs = np.array(imgs) 140 | # print('img2numpy:',time.perf_counter()-st) 141 | # st = time.perf_counter() 142 | img2 = imgs - np.min(imgs) 143 | 
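        # min-max normalise the whole clip to [0, 1], then standardise with the
        # dataset MEAN / STD arrays defined at the top of this file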
img2 = img2 / np.max(img2) 144 | # img2 = imgs/255 145 | img2 = (img2 - self.mean) / self.std 146 | # print('imgmean:',time.perf_counter()-st) 147 | if (self.t+self.global_n)==1: 148 | img = img2[0].transpose(2,0,1) # c w h 149 | else: 150 | img = img2.transpose(0,3,1,2) # t c w h 151 | # print('Processed:',time.perf_counter()-st) 152 | # st = time.perf_counter() 153 | img = torch.from_numpy(img) 154 | # print('Img2tensor:',time.perf_counter()-st) 155 | # st = time.perf_counter() 156 | label = label[::self.rate,::self.rate] 157 | if self.tag=='part': 158 | label[label>self.class_num]=0 159 | label = torch.from_numpy(label) 160 | # print('Label2tensor:',time.perf_counter()-st) 161 | # st = time.perf_counter() 162 | label = F.one_hot(label.to(torch.int64),num_classes=self.class_num+1).permute(2,0,1) 163 | # print('final data:',time.perf_counter()-st) 164 | return {'path': [ins,frame],'image': img,'label': label} 165 | 166 | def __len__(self): 167 | return self.num_samples 168 | 169 | 170 | 171 | if __name__ == '__main__': 172 | from tqdm import tqdm 173 | import pickle 174 | 175 | dataset = endovis2017('train',fold=0,t=5,rate=1) 176 | for d in dataset: 177 | b1 = d 178 | print(d['image'].shape) 179 | print(d['label'].shape) 180 | break 181 | train_loader = torch.utils.data.DataLoader(dataset, batch_size=8, 182 | shuffle=True, num_workers=2, 183 | pin_memory=True, drop_last=True) -------------------------------------------------------------------------------- /utils/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import random 4 | 5 | 6 | def flip(img): 7 | return img[:, :, ::-1].copy() 8 | 9 | # todo what the hell is this? 10 | def get_border(border, size): 11 | i = 1 12 | while size - border // i <= border // i: 13 | i *= 2 14 | return border // i 15 | 16 | def transform_preds(coords, center, scale, output_size): 17 | target_coords = np.zeros(coords.shape) 18 | trans = get_affine_transform(center, scale, 0, output_size, inv=1) 19 | for p in range(coords.shape[0]): 20 | target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) 21 | return target_coords 22 | 23 | 24 | def get_affine_transform(center, 25 | scale, 26 | rot, 27 | output_size, 28 | shift=np.array([0, 0], dtype=np.float32), 29 | inv=0): 30 | if not isinstance(scale, np.ndarray) and not isinstance(scale, list): 31 | scale = np.array([scale, scale], dtype=np.float32) 32 | 33 | scale_tmp = scale 34 | src_w = scale_tmp[0] 35 | dst_w = output_size[0] 36 | dst_h = output_size[1] 37 | 38 | rot_rad = np.pi * rot / 180 39 | src_dir = get_dir([0, src_w * -0.5], rot_rad) 40 | dst_dir = np.array([0, dst_w * -0.5], np.float32) 41 | 42 | src = np.zeros((3, 2), dtype=np.float32) 43 | dst = np.zeros((3, 2), dtype=np.float32) 44 | src[0, :] = center + scale_tmp * shift 45 | src[1, :] = center + src_dir + scale_tmp * shift 46 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 47 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir 48 | 49 | src[2:, :] = get_3rd_point(src[0, :], src[1, :]) 50 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) 51 | 52 | if inv: 53 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 54 | else: 55 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 56 | 57 | return trans 58 | 59 | 60 | def affine_transform(pt, t): 61 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T 62 | new_pt = np.dot(t, new_pt) 63 | return new_pt[:2] 64 | 65 | 66 | def get_3rd_point(a, b): 67 | direct = a - b 
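    # rotate the (a - b) direction by 90 degrees to get a third, non-collinear
    # point; cv2.getAffineTransform needs three point pairs to fix the transform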
68 | return b + np.array([-direct[1], direct[0]], dtype=np.float32) 69 | 70 | 71 | def get_dir(src_point, rot_rad): 72 | _sin, _cos = np.sin(rot_rad), np.cos(rot_rad) 73 | 74 | src_result = [0, 0] 75 | src_result[0] = src_point[0] * _cos - src_point[1] * _sin 76 | src_result[1] = src_point[0] * _sin + src_point[1] * _cos 77 | 78 | return src_result 79 | 80 | 81 | def crop(img, center, scale, output_size, rot=0): 82 | trans = get_affine_transform(center, scale, rot, output_size) 83 | 84 | dst_img = cv2.warpAffine(img, 85 | trans, 86 | (int(output_size[0]), int(output_size[1])), 87 | flags=cv2.INTER_LINEAR) 88 | 89 | return dst_img 90 | 91 | 92 | # todo what's this? 93 | def gaussian_radius(det_size, min_overlap=0.7): 94 | height, width = det_size 95 | 96 | a1 = 1 97 | b1 = (height + width) 98 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 99 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) 100 | r1 = (b1 + sq1) / 2 101 | 102 | a2 = 4 103 | b2 = 2 * (height + width) 104 | c2 = (1 - min_overlap) * width * height 105 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) 106 | r2 = (b2 + sq2) / 2 107 | 108 | a3 = 4 * min_overlap 109 | b3 = -2 * min_overlap * (height + width) 110 | c3 = (min_overlap - 1) * width * height 111 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) 112 | r3 = (b3 + sq3) / 2 113 | return min(r1, r2, r3) 114 | 115 | 116 | def gaussian2D(shape, sigma=1): 117 | m, n = [(ss - 1.) / 2. for ss in shape] 118 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 119 | 120 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 121 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 122 | return h 123 | 124 | 125 | def draw_umich_gaussian(heatmap, center, radius, k=1): 126 | diameter = 2 * radius + 1 127 | gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) 128 | 129 | x, y = int(center[0]), int(center[1]) 130 | 131 | height, width = heatmap.shape[0:2] 132 | 133 | left, right = min(x, radius), min(width - x, radius + 1) 134 | top, bottom = min(y, radius), min(height - y, radius + 1) 135 | 136 | masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] 137 | masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right] 138 | if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug 139 | np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) 140 | return heatmap 141 | 142 | 143 | def draw_dense_reg(regmap, heatmap, center, value, radius, is_offset=False): 144 | diameter = 2 * radius + 1 145 | gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) 146 | value = np.array(value, dtype=np.float32).reshape(-1, 1, 1) 147 | dim = value.shape[0] 148 | reg = np.ones((dim, diameter * 2 + 1, diameter * 2 + 1), dtype=np.float32) * value 149 | if is_offset and dim == 2: 150 | delta = np.arange(diameter * 2 + 1) - radius 151 | reg[0] = reg[0] - delta.reshape(1, -1) 152 | reg[1] = reg[1] - delta.reshape(-1, 1) 153 | 154 | x, y = int(center[0]), int(center[1]) 155 | 156 | height, width = heatmap.shape[0:2] 157 | 158 | left, right = min(x, radius), min(width - x, radius + 1) 159 | top, bottom = min(y, radius), min(height - y, radius + 1) 160 | 161 | masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] 162 | masked_regmap = regmap[:, y - top:y + bottom, x - left:x + right] 163 | masked_gaussian = gaussian[radius - top:radius + bottom, 164 | radius - left:radius + right] 165 | masked_reg = reg[:, radius - top:radius + bottom, 166 | radius - left:radius + right] 167 | if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO 
debug 168 | idx = (masked_gaussian >= masked_heatmap).reshape( 169 | 1, masked_gaussian.shape[0], masked_gaussian.shape[1]) 170 | masked_regmap = (1 - idx) * masked_regmap + idx * masked_reg 171 | regmap[:, y - top:y + bottom, x - left:x + right] = masked_regmap 172 | return regmap 173 | 174 | 175 | def draw_msra_gaussian(heatmap, center, sigma): 176 | tmp_size = sigma * 3 177 | mu_x = int(center[0] + 0.5) 178 | mu_y = int(center[1] + 0.5) 179 | w, h = heatmap.shape[0], heatmap.shape[1] 180 | ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] 181 | br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] 182 | if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0: 183 | return heatmap 184 | size = 2 * tmp_size + 1 185 | x = np.arange(0, size, 1, np.float32) 186 | y = x[:, np.newaxis] 187 | x0 = y0 = size // 2 188 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 189 | g_x = max(0, -ul[0]), min(br[0], h) - ul[0] 190 | g_y = max(0, -ul[1]), min(br[1], w) - ul[1] 191 | img_x = max(0, ul[0]), min(br[0], h) 192 | img_y = max(0, ul[1]), min(br[1], w) 193 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum( 194 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]], 195 | g[g_y[0]:g_y[1], g_x[0]:g_x[1]]) 196 | return heatmap 197 | 198 | 199 | def grayscale(image): 200 | return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 201 | 202 | 203 | def lighting_(data_rng, image, alphastd, eigval, eigvec): 204 | alpha = data_rng.normal(scale=alphastd, size=(3,)) 205 | image += np.dot(eigvec, eigval * alpha) 206 | 207 | 208 | def blend_(alpha, image1, image2): 209 | image1 *= alpha 210 | image2 *= (1 - alpha) 211 | image1 += image2 212 | 213 | 214 | def saturation_(data_rng, image, gs, gs_mean, var): 215 | alpha = 1. + data_rng.uniform(low=-var, high=var) 216 | blend_(alpha, image, gs[:, :, None]) 217 | 218 | 219 | def brightness_(data_rng, image, gs, gs_mean, var): 220 | alpha = 1. + data_rng.uniform(low=-var, high=var) 221 | image *= alpha 222 | 223 | 224 | def contrast_(data_rng, image, gs, gs_mean, var): 225 | alpha = 1. 
+ data_rng.uniform(low=-var, high=var) 226 | blend_(alpha, image, gs_mean) 227 | 228 | 229 | def color_aug(data_rng, image, eig_val, eig_vec): 230 | functions = [brightness_, contrast_, saturation_] 231 | random.shuffle(functions) 232 | 233 | gs = grayscale(image) 234 | gs_mean = gs.mean() 235 | for f in functions: 236 | f(data_rng, image, gs, gs_mean, 0.4) 237 | lighting_(data_rng, image, 0.1, eig_val, eig_vec) 238 | -------------------------------------------------------------------------------- /net/Ours/Module.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys, time, os 5 | 6 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../')) 7 | 8 | from net.utils.helpers import maybe_download 9 | from net.utils.layer_factory import conv1x1, conv3x3, convbnrelu, CRPBlock 10 | from net.LSTM.torch_convlstm import ConvLSTM 11 | from net.LSTM.bottlenecklstm import BottleneckLSTM 12 | from net.LSTM.grouplstm import GroupLSTM 13 | 14 | data_info = {21: "VOC"} 15 | 16 | models_urls = { 17 | "mbv2_voc": 18 | "https://cloudstor.aarnet.edu.au/plus/s/nQ6wDnTEFhyidot/download", 19 | "mbv2_imagenet": 20 | "https://cloudstor.aarnet.edu.au/plus/s/uRgFbkaRjD3qOg5/download", 21 | } 22 | 23 | 24 | class TimeProcesser(nn.Module): 25 | def __init__(self, inplanes, planes, size, batch_size, tag, group=1): 26 | super(TimeProcesser, self).__init__() 27 | self.inplanes = inplanes 28 | self.planes = planes 29 | self.size = size 30 | self.tag = tag 31 | self.batch_size = batch_size 32 | if not inplanes == planes: 33 | self.refine = conv1x1(planes, inplanes) 34 | if self.tag == 'convlstm': 35 | self.processer = ConvLSTM(input_size=size, 36 | input_dim=inplanes, 37 | hidden_dim=[planes], 38 | kernel_size=(3, 3), 39 | num_layers=1, 40 | batch_first=True, 41 | bias=True, 42 | return_all_layers=False) 43 | elif self.tag == 'btnlstm': 44 | self.processer = BottleneckLSTM(inplanes, planes, size[0], size[1]) 45 | elif self.tag == 'group': 46 | self.processer = GroupLSTM(inplanes, planes, size[0], size[1], 47 | group) 48 | else: 49 | pass 50 | 51 | def forward(self, x): 52 | x = self.processer(x) 53 | return x 54 | 55 | 56 | class InvertedResidualBlock(nn.Module): 57 | """Inverted Residual Block from https://arxiv.org/abs/1801.04381""" 58 | def __init__(self, in_planes, out_planes, expansion_factor, stride=1): 59 | super(InvertedResidualBlock, self).__init__() 60 | intermed_planes = in_planes * expansion_factor 61 | self.residual = (in_planes == out_planes) and (stride == 1) 62 | self.output = nn.Sequential( 63 | convbnrelu(in_planes, intermed_planes, 1), 64 | convbnrelu( 65 | intermed_planes, 66 | intermed_planes, 67 | 3, 68 | stride=stride, 69 | groups=intermed_planes, 70 | ), 71 | convbnrelu(intermed_planes, out_planes, 1, act=False), 72 | ) 73 | 74 | def forward(self, x): 75 | residual = x 76 | out = self.output(x) 77 | if self.residual: 78 | return out + residual 79 | else: 80 | return out 81 | 82 | 83 | class MemoryCore(nn.Module): 84 | def __init__(self): 85 | super(MemoryCore, self).__init__() 86 | self.dropout = nn.Dropout(0.1) 87 | 88 | def forward(self, m_in, m_out, q_in, q_out): # m_in: o,t,c,h,w 89 | B, T, D_e, H, W = m_in.size() 90 | _, _, D_o, _, _ = m_out.size() 91 | 92 | mi = m_in.transpose(1, 2).contiguous().view(B, D_e, T * H * W) 93 | mi = torch.transpose(mi, 1, 2).contiguous() # b, THW, emb 94 | 95 | qi = q_in.view(B, D_e, H * W) # b, emb, HW 96 | 97 | p = 
torch.bmm(mi, qi) # b, THW, HW 98 | p = p / math.sqrt(D_e) 99 | p = F.softmax(p, dim=1) # b, THW, HW 100 | p = self.dropout(p) 101 | 102 | mo = m_out.view(B, D_o, T * H * W) 103 | mem = torch.bmm(mo, p) # Weighted-sum B, D_o, HW 104 | mem = mem.view(B, D_o, H, W) 105 | 106 | mem_out = torch.cat([mem, q_out], dim=1) 107 | 108 | return mem_out, p 109 | 110 | 111 | class KeyValue(nn.Module): 112 | def __init__(self, indim, keydim, valdim): 113 | super(KeyValue, self).__init__() 114 | self.Key = nn.Conv2d(indim, 115 | keydim, 116 | kernel_size=(3, 3), 117 | padding=(1, 1), 118 | stride=1) 119 | self.Value = nn.Conv2d(indim, 120 | valdim, 121 | kernel_size=(3, 3), 122 | padding=(1, 1), 123 | stride=1) 124 | 125 | def forward(self, x): 126 | return [self.Key(x), self.Value(x)] 127 | 128 | 129 | class Memory(nn.Module): 130 | def __init__(self, c): 131 | super(Memory, self).__init__() 132 | self.mem_core = MemoryCore() 133 | self.kv = KeyValue(c, c // 4, c // 2) 134 | 135 | def forward(self, mem, query): 136 | _, T, _, _, _ = mem.size() 137 | # print('Memory:{}'.format(T)) 138 | keys = [] 139 | values = [] 140 | for t in range(T): 141 | k, v = self.kv(mem[:, t]) 142 | keys.append(k.unsqueeze(1)) 143 | values.append(v.unsqueeze(1)) 144 | MemoryKeys = torch.cat(keys, dim=1) 145 | MemoryValues = torch.cat(values, dim=1) 146 | CurrentKey, CurrentValue = self.kv(query) 147 | mem_out, p = self.mem_core(MemoryKeys, MemoryValues, CurrentKey, 148 | CurrentValue) 149 | return mem_out, p 150 | 151 | 152 | class MobileEncoder(nn.Module): 153 | """Encoder mobilev2""" 154 | 155 | mobilenet_config = [ 156 | [1, 16, 1, 157 | 1], # expansion rate, output channels, number of repeats, stride 158 | [6, 24, 2, 2], 159 | [6, 32, 3, 2], 160 | [6, 64, 4, 2], 161 | [6, 96, 3, 1], 162 | [6, 160, 3, 2], 163 | [6, 320, 1, 1], 164 | ] 165 | in_planes = 32 # number of input channels 166 | num_layers = len(mobilenet_config) 167 | 168 | def __init__(self): 169 | super(MobileEncoder, self).__init__() 170 | self.layer1 = convbnrelu(3, self.in_planes, kernel_size=3, stride=2) 171 | c_layer = 2 172 | for t, c, n, s in self.mobilenet_config: 173 | layers = [] 174 | for idx in range(n): 175 | layers.append( 176 | InvertedResidualBlock( 177 | self.in_planes, 178 | c, 179 | expansion_factor=t, 180 | stride=s if idx == 0 else 1, 181 | )) 182 | self.in_planes = c 183 | setattr(self, "layer{}".format(c_layer), nn.Sequential(*layers)) 184 | c_layer += 1 185 | 186 | ## Light-Weight RefineNet ## 187 | self.conv8 = conv1x1(320, 256, bias=False) 188 | self.conv7 = conv1x1(160, 256, bias=False) 189 | 190 | self.relu = nn.ReLU6(inplace=True) 191 | 192 | def forward(self, x): 193 | l1 = self.layer1(x) 194 | l2 = self.layer2(l1) # x / 2 195 | l3 = self.layer3(l2) # 24, x / 4 196 | l4 = self.layer4(l3) # 32, x / 8 197 | l5 = self.layer5(l4) # 64, x / 16 198 | l6 = self.layer6(l5) # 96, x / 16 199 | l7 = self.layer7(l6) # 160, x / 32 200 | l8 = self.layer8(l7) # 320, x / 32 201 | l8 = self.conv8(l8) 202 | l7 = self.conv7(l7) 203 | l7 = self.relu(l8 + l7) # 256, x/32 204 | return [l3, l4, l5, l6, l7] 205 | 206 | 207 | class RefineDecoder(nn.Module): 208 | def __init__(self, num_classes): 209 | super(RefineDecoder, self).__init__() 210 | self.conv6 = conv1x1(96, 256, bias=False) 211 | self.conv5 = conv1x1(64, 256, bias=False) 212 | self.conv4 = conv1x1(32, 256, bias=False) 213 | self.conv3 = conv1x1(24, 256, bias=False) 214 | self.crp4 = self._make_crp(256, 256, 4) 215 | self.crp3 = self._make_crp(256, 256, 4) 216 | self.crp2 = self._make_crp(256, 
256, 4) 217 | self.crp1 = self._make_crp(256, 256, 4) 218 | 219 | self.conv_adapt4 = conv1x1(256, 256, bias=False) 220 | self.conv_adapt3 = conv1x1(256, 256, bias=False) 221 | self.conv_adapt2 = conv1x1(256, 256, bias=False) 222 | 223 | self.segm = conv3x3(256, num_classes, bias=True) 224 | self.relu = nn.ReLU6(inplace=True) 225 | 226 | def _make_crp(self, in_planes, out_planes, stages): 227 | layers = [CRPBlock(in_planes, out_planes, stages)] 228 | return nn.Sequential(*layers) 229 | 230 | def forward(self, x): 231 | l3, l4, l5, l6, l7 = x 232 | 233 | l7 = self.crp4(l7) 234 | l7 = self.conv_adapt4(l7) 235 | l7 = nn.Upsample(size=l6.size()[2:], 236 | mode="bilinear", 237 | align_corners=True)(l7) 238 | 239 | l6 = self.conv6(l6) 240 | l5 = self.conv5(l5) 241 | l5 = self.relu(l5 + l6 + l7) 242 | l5 = self.crp3(l5) 243 | l5 = self.conv_adapt3(l5) 244 | l5 = nn.Upsample(size=l4.size()[2:], 245 | mode="bilinear", 246 | align_corners=True)(l5) 247 | 248 | l4 = self.conv4(l4) 249 | l4 = self.relu(l5 + l4) 250 | l4 = self.crp2(l4) 251 | l4 = self.conv_adapt2(l4) 252 | l4 = nn.Upsample(size=l3.size()[2:], 253 | mode="bilinear", 254 | align_corners=True)(l4) 255 | 256 | l3 = self.conv3(l3) 257 | l3 = self.relu(l3 + l4) 258 | l3 = self.crp1(l3) 259 | 260 | out_segm = self.segm(l3) 261 | return out_segm 262 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def albumentations2torchvision(transforms): 7 | """Wrap albumentations transformation so that they can be used in torchvision dataset""" 8 | from albumentations import Compose 9 | 10 | def wrapper_func(image, target): 11 | keys = ["image", "mask"] 12 | np_dtypes = [np.float32, np.uint8] 13 | torch_dtypes = [torch.float32, torch.long] 14 | sample_dict = { 15 | key: np.array(value, dtype=dtype) 16 | for key, value, dtype in zip(keys, [image, target], np_dtypes) 17 | } 18 | output = Compose(transforms)(**sample_dict) 19 | return [output[key].to(dtype) for key, dtype in zip(keys, torch_dtypes)] 20 | 21 | return wrapper_func 22 | 23 | 24 | def albumentations_transforms( 25 | crop_size, 26 | shorter_side, 27 | low_scale, 28 | high_scale, 29 | img_mean, 30 | img_std, 31 | img_scale, 32 | ignore_label, 33 | num_stages, 34 | dataset_type, 35 | ): 36 | from albumentations import ( 37 | Normalize, 38 | HorizontalFlip, 39 | RandomCrop, 40 | PadIfNeeded, 41 | RandomScale, 42 | LongestMaxSize, 43 | SmallestMaxSize, 44 | OneOf, 45 | ) 46 | from albumentations.pytorch import ToTensorV2 as ToTensor 47 | from densetorch.data import albumentations2densetorch 48 | 49 | if dataset_type == "densetorch": 50 | wrapper = albumentations2densetorch 51 | elif dataset_type == "torchvision": 52 | wrapper = albumentations2torchvision 53 | else: 54 | raise ValueError(f"Unknown dataset type: {dataset_type}") 55 | 56 | common_transformations = [ 57 | Normalize(max_pixel_value=1.0 / img_scale, mean=img_mean, std=img_std), 58 | ToTensor(), 59 | ] 60 | train_transforms = [] 61 | for stage in range(num_stages): 62 | train_transforms.append( 63 | wrapper( 64 | [ 65 | OneOf( 66 | [ 67 | RandomScale( 68 | scale_limit=(low_scale[stage], high_scale[stage]) 69 | ), 70 | LongestMaxSize(max_size=shorter_side[stage]), 71 | SmallestMaxSize(max_size=shorter_side[stage]), 72 | ] 73 | ), 74 | PadIfNeeded( 75 | min_height=crop_size[stage], 76 | min_width=crop_size[stage], 77 | border_mode=cv2.BORDER_CONSTANT, 
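# (added note) constant padding below: images are padded with the dataset mean converted back to raw pixel units (img_mean / img_scale), segmentation masks with ignore_label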
78 | value=np.array(img_mean) / img_scale, 79 | mask_value=ignore_label, 80 | ), 81 | HorizontalFlip(p=0.5,), 82 | RandomCrop(height=crop_size[stage], width=crop_size[stage],), 83 | ] 84 | + common_transformations 85 | ) 86 | ) 87 | val_transforms = wrapper(common_transformations) 88 | return train_transforms, val_transforms 89 | 90 | 91 | def densetorch_transforms( 92 | crop_size, 93 | shorter_side, 94 | low_scale, 95 | high_scale, 96 | img_mean, 97 | img_std, 98 | img_scale, 99 | ignore_label, 100 | num_stages, 101 | dataset_type, 102 | ): 103 | from torchvision.transforms import Compose 104 | from densetorch.data import ( 105 | Pad, 106 | RandomCrop, 107 | RandomMirror, 108 | ResizeAndScale, 109 | ToTensor, 110 | Normalise, 111 | densetorch2torchvision, 112 | ) 113 | 114 | if dataset_type == "densetorch": 115 | wrapper = Compose 116 | elif dataset_type == "torchvision": 117 | wrapper = densetorch2torchvision 118 | else: 119 | raise ValueError(f"Unknown dataset type: {dataset_type}") 120 | 121 | common_transformations = [ 122 | Normalise(scale=img_scale, mean=img_mean, std=img_std), 123 | ToTensor(), 124 | ] 125 | train_transforms = [] 126 | for stage in range(num_stages): 127 | train_transforms.append( 128 | wrapper( 129 | [ 130 | ResizeAndScale( 131 | shorter_side[stage], low_scale[stage], high_scale[stage] 132 | ), 133 | Pad(crop_size[stage], img_mean, ignore_label), 134 | RandomMirror(), 135 | RandomCrop(crop_size[stage]), 136 | ] 137 | + common_transformations 138 | ) 139 | ) 140 | val_transforms = wrapper(common_transformations) 141 | return train_transforms, val_transforms 142 | 143 | 144 | def get_transforms( 145 | crop_size, 146 | shorter_side, 147 | low_scale, 148 | high_scale, 149 | img_mean, 150 | img_std, 151 | img_scale, 152 | ignore_label, 153 | num_stages, 154 | augmentations_type, 155 | dataset_type, 156 | ): 157 | """ 158 | Args: 159 | 160 | crop_size (int) : square crop to apply during the training. 161 | shorter_side (int) : parameter of the shorter_side resize transformation. 162 | low_scale (float) : lowest scale ratio for augmentations. 163 | high_scale (float) : highest scale ratio for augmentations. 164 | img_mean (list of float) : image mean. 165 | img_std (list of float) : image standard deviation 166 | img_scale (list of float) : image scale. 167 | ignore_label (int) : label to pad segmentation masks with. 168 | num_stages (int): how many train_transforms to create. 169 | augmentations_type (str): whether to use densetorch augmentations or albumentations. 170 | dataset_type (str): whether to use densetorch or torchvision dataset; 171 | needed to correctly wrap transformations. 
172 | 173 | Returns: 174 | train_transforms, val_transforms 175 | 176 | """ 177 | if augmentations_type == "densetorch": 178 | func = densetorch_transforms 179 | elif augmentations_type == "albumentations": 180 | func = albumentations_transforms 181 | else: 182 | raise ValueError(f"Unknown augmentations type {augmentations_type}") 183 | return func( 184 | crop_size=crop_size, 185 | shorter_side=shorter_side, 186 | low_scale=low_scale, 187 | high_scale=high_scale, 188 | img_mean=img_mean, 189 | img_std=img_std, 190 | img_scale=img_scale, 191 | ignore_label=ignore_label, 192 | num_stages=num_stages, 193 | dataset_type=dataset_type, 194 | ) 195 | 196 | 197 | def densetorch_dataset( 198 | train_dir, 199 | val_dir, 200 | train_list_path, 201 | val_list_path, 202 | train_transforms, 203 | val_transforms, 204 | masks_names, 205 | stage_names, 206 | train_download, 207 | val_download, 208 | ): 209 | from densetorch.data import MMDataset as Dataset 210 | 211 | def line_to_paths_fn(x): 212 | rgb, segm = x.decode("utf-8").strip("\n").split("\t")[:2] 213 | return [rgb, segm] 214 | 215 | train_sets = [ 216 | Dataset( 217 | data_file=train_list_path[i], 218 | data_dir=train_dir[i], 219 | line_to_paths_fn=line_to_paths_fn, 220 | masks_names=masks_names, 221 | transform=train_transforms[i], 222 | ) 223 | for i in range(len(train_transforms)) 224 | ] 225 | val_set = Dataset( 226 | data_file=val_list_path, 227 | data_dir=val_dir, 228 | line_to_paths_fn=line_to_paths_fn, 229 | masks_names=masks_names, 230 | transform=val_transforms, 231 | ) 232 | return train_sets, val_set 233 | 234 | 235 | def torchvision_dataset( 236 | train_dir, 237 | val_dir, 238 | train_list_path, 239 | val_list_path, 240 | train_transforms, 241 | val_transforms, 242 | masks_names, 243 | stage_names, 244 | train_download, 245 | val_download, 246 | ): 247 | from torchvision.datasets.voc import VOCSegmentation 248 | from torchvision.datasets import SBDataset 249 | from functools import partial 250 | 251 | train_sets = [] 252 | for i, stage in enumerate(stage_names): 253 | if stage.lower() == "voc": 254 | Dataset = partial(VOCSegmentation, image_set="train", year="2012",) 255 | elif stage.lower() == "sbd": 256 | Dataset = partial(SBDataset, mode="segmentation", image_set="train_noval") 257 | train_sets.append( 258 | Dataset( 259 | root=train_dir[i], 260 | transforms=train_transforms[i], 261 | download=train_download[i], 262 | ) 263 | ) 264 | 265 | val_set = VOCSegmentation( 266 | root=val_dir, 267 | image_set="val", 268 | year="2012", 269 | download=val_download, 270 | transforms=val_transforms, 271 | ) 272 | 273 | return train_sets, val_set 274 | 275 | 276 | def get_datasets( 277 | train_dir, 278 | val_dir, 279 | train_list_path, 280 | val_list_path, 281 | train_transforms, 282 | val_transforms, 283 | masks_names, 284 | dataset_type, 285 | stage_names, 286 | train_download, 287 | val_download, 288 | ): 289 | if dataset_type == "densetorch": 290 | func = densetorch_dataset 291 | elif dataset_type == "torchvision": 292 | func = torchvision_dataset 293 | else: 294 | raise ValueError(f"Unknown dataset type {dataset_type}") 295 | return func( 296 | train_dir, 297 | val_dir, 298 | train_list_path, 299 | val_list_path, 300 | train_transforms, 301 | val_transforms, 302 | masks_names, 303 | stage_names, 304 | train_download, 305 | val_download, 306 | ) 307 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import sys 3 | import time 4 | import argparse 5 | 6 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'lib')) 7 | 8 | import numpy as np 9 | 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils.data 13 | import torch.distributed as dist 14 | 15 | from net.MobileNetRefine.LWANet import lwa 16 | from net.MobileNetRefine.mobilenet import mbv2 17 | from net.MobileNetRefine.resnet import rf_lw50 18 | # from net.TernausNet.ternausnet import UNet16 as TUnet 19 | from net.TernausNet.tunet import UNet16 as TUnet 20 | from net.Ours.base import TemporalNet 21 | # from net.BiseNet.r18 import BiSeNet 22 | from net.unet.unet import UNet 23 | from net.Ours.DenseST import DenseST 24 | from net.Ours.SpNet import spnet 25 | from net.Ours.GlobalDenseST import GDST 26 | ####data 27 | 28 | from dataset.Endovis2017 import endovis2017 29 | from utils.losses import BCELoss 30 | # from utils.EndoLoss import LossMulti 31 | from utils.EndoLoss import LossMulti 32 | from utils.metrics import compute_dice, compute_iou 33 | from utils.summary import create_summary, create_logger, create_saver, DisablePrint 34 | from utils.LoadModel import load_model 35 | # from net.BiseNet.seg_opr.loss_opr import SigmoidFocalLoss, ProbOhemCrossEntropy2d 36 | 37 | # Training settings 38 | parser = argparse.ArgumentParser(description='real-time segmentation') 39 | 40 | parser.add_argument('--local_rank', type=int, default=0) 41 | parser.add_argument('--dist', action='store_true') 42 | 43 | parser.add_argument('--root_dir', 44 | type=str, 45 | default='/raid/wjc/logs/RealtimeSegmentation') 46 | parser.add_argument('--dataset', type=str) 47 | parser.add_argument('--data_tag', type=str, choices=['part', 'type']) 48 | parser.add_argument('--log_name', type=str) 49 | 50 | parser.add_argument('--arch', 51 | type=str, 52 | choices=[ 53 | 'mb_rf', 'lwa', 'bise', 'unet', 'tunet', 'tpnet', 54 | 'spnet', 'densest', 'gdst', 'res50_rf' 55 | ]) 56 | parser.add_argument('--load_model', type=str, default=None) 57 | 58 | parser.add_argument('--folds', type=str) 59 | parser.add_argument('--lr', type=float, default=1e-4) 60 | parser.add_argument('--batch_size', type=int, default=32) 61 | parser.add_argument('--num_epochs', type=int, default=200) 62 | parser.add_argument('--loss', type=str) 63 | 64 | parser.add_argument('--gpus', type=str) 65 | 66 | parser.add_argument('--log_interval', type=int, default=10) 67 | parser.add_argument('--val_interval', type=int, default=1) 68 | parser.add_argument('--num_workers', type=int, default=3) 69 | 70 | parser.add_argument('--lstm', 71 | type=str, 72 | choices=['convlstm', 'btnlstm', 'grouplstm', 'kdlstm']) 73 | parser.add_argument('--t', type=int) 74 | 75 | parser.add_argument('--freeze_name', type=str) 76 | parser.add_argument('--spatial_layer', type=int) 77 | parser.add_argument('--global_n', type=int) 78 | parser.add_argument('--need_pretrain', type=int) 79 | parser.add_argument('--pre_name', type=str) 80 | parser.add_argument('--pretrain_ep', type=int, default=20) 81 | parser.add_argument('--decay', type=int, default=2) 82 | parser.add_argument('--fusion_type', type=str) 83 | 84 | parser.add_argument('--reset', action='store_true') 85 | parser.add_argument('--reset_epoch', type=int) 86 | 87 | cfg = parser.parse_args() 88 | # os.chdir(cfg.root_dir) 89 | cfg.folds = list(map(int, cfg.folds.split(','))) 90 | # loss_functions = {'dice': DiceLoss(ignore_index=4), 'bce': BCELoss()} 91 | loss_functions = {'bce': BCELoss()} 92 | # rate = 1 if cfg.arch=='bise' else 1 93 | 
rate = 1 94 | 95 | 96 | def main(): 97 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus 98 | # torch.manual_seed(317) 99 | torch.backends.cudnn.benchmark = True # disable this if OOM at beginning of training 100 | num_gpus = torch.cuda.device_count() 101 | 102 | if cfg.dist: 103 | cfg.device = torch.device('cuda:%d' % cfg.local_rank) 104 | torch.cuda.set_device(cfg.local_rank) 105 | dist.init_process_group(backend='nccl', 106 | init_method='env://', 107 | world_size=num_gpus, 108 | rank=cfg.local_rank) 109 | else: 110 | cfg.device = torch.device('cuda') 111 | 112 | for fold in cfg.folds: 113 | cfg.log_dir = os.path.join(cfg.root_dir, cfg.dataset, cfg.data_tag, 114 | cfg.log_name, 'logs', 'fold{}'.format(fold)) 115 | cfg.ckpt_dir = os.path.join(cfg.root_dir, cfg.dataset, cfg.data_tag, 116 | cfg.log_name, 'ckpt', 117 | 'fold{}'.format(fold)) 118 | os.makedirs(cfg.log_dir, exist_ok=True) 119 | os.makedirs(cfg.ckpt_dir, exist_ok=True) 120 | saver = create_saver(cfg.local_rank, save_dir=cfg.ckpt_dir) 121 | logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir) 122 | summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir) 123 | print = logger.info 124 | print(cfg) 125 | print('Setting up data...') 126 | 127 | if cfg.dataset == 'endovis2017': 128 | train_dataset = endovis2017('train', 129 | t=cfg.t, 130 | fold=fold, 131 | rate=rate, 132 | tag=cfg.data_tag, 133 | global_n=cfg.global_n) 134 | val_dataset = endovis2017('val', 135 | t=cfg.t, 136 | fold=fold, 137 | rate=rate, 138 | tag=cfg.data_tag, 139 | global_n=cfg.global_n) 140 | classes = train_dataset.class_num 141 | 142 | val_loader = torch.utils.data.DataLoader(val_dataset, 143 | batch_size=1, 144 | shuffle=False, 145 | num_workers=cfg.num_workers, 146 | pin_memory=True, 147 | drop_last=False) 148 | compute_loss = loss_functions[cfg.loss] 149 | # build model 150 | if 'mb_rf' in cfg.arch: 151 | model = mbv2(classes + 1, imagenet=True, rate=1) 152 | elif 'res50_rf' in cfg.arch: 153 | model = rf_lw50(classes + 1, imagenet=True) 154 | elif 'lwa' in cfg.arch: 155 | model = lwa(classes + 1, imagenet=True) 156 | elif 'tunet' in cfg.arch: 157 | model = TUnet(in_channels=64, 158 | num_classes=classes + 1, 159 | pretrained=True) 160 | # compute_loss = LossMulti(jaccard_weight=1,num_classes=classes+1) 161 | elif 'unet' in cfg.arch: 162 | model = UNet(3, classes + 1) 163 | elif 'tpnet' in cfg.arch: 164 | assert (cfg.t > 1) 165 | model = TemporalNet(classes + 1, 166 | batch_size=cfg.batch_size, 167 | tag=cfg.lstm, 168 | group=1) 169 | elif 'spnet' in cfg.arch: 170 | assert (cfg.global_n > 1) 171 | model = spnet(classes + 1, 172 | imagenet=True, 173 | global_n=cfg.global_n, 174 | spatial_layer=cfg.spatial_layer) 175 | elif 'densest' in cfg.arch: 176 | assert (cfg.t > 1) 177 | model = DenseST(classes + 1, tag=cfg.lstm) 178 | elif 'gdst' in cfg.arch: 179 | assert (cfg.t > 1 and cfg.global_n > 0 180 | and cfg.fusion_type is not None) 181 | model = GDST(classes + 1, 182 | batch_size=cfg.batch_size, 183 | tag=cfg.lstm, 184 | group=1, 185 | t=cfg.t, 186 | global_n=cfg.global_n, 187 | fusion_type=cfg.fusion_type) 188 | else: 189 | raise NotImplementedError 190 | 191 | optimizer = torch.optim.Adam(model.parameters(), cfg.lr) 192 | 193 | torch.cuda.empty_cache() 194 | 195 | def train(epoch): 196 | print('\n Epoch: %d' % epoch) 197 | model.train() 198 | tic = time.perf_counter() 199 | for batch_idx, batch in enumerate(train_loader): 200 | for k in batch: 201 | if not k == 'path': 202 | batch[k] = batch[k].to(device=cfg.device, 203 | 
non_blocking=True).float() 204 | outputs = model(batch['image']) 205 | if 'bise' in cfg.arch: 206 | loss = compute_loss(outputs, batch['label']) 207 | elif cfg.arch in ['unet', 'tunet']: 208 | loss = compute_loss(outputs, batch['label']) 209 | elif cfg.arch in [ 210 | 'mb_rf', 'lwa', 'tpnet', 'densest', 'spnet', 'gdst', 211 | 'res50_rf' 212 | ]: 213 | outputs = F.interpolate(outputs, scale_factor=4 // rate) 214 | loss = compute_loss(outputs, batch['label']) 215 | else: 216 | raise NotImplementedError 217 | optimizer.zero_grad() 218 | loss.backward() 219 | optimizer.step() 220 | 221 | if batch_idx % cfg.log_interval == 0: 222 | duration = time.perf_counter() - tic 223 | tic = time.perf_counter() 224 | print( 225 | '[%d/%d-%d/%d]' % 226 | (epoch, cfg.num_epochs, batch_idx, len(train_loader)) + 227 | 'Dice_loss:{:.4f} Time:{:.4f}'.format( 228 | loss.item(), duration)) 229 | 230 | step = len(train_loader) * epoch + batch_idx 231 | summary_writer.add_scalar('loss/AVG', loss.item(), step) 232 | return 233 | 234 | def val_map(epoch): 235 | print('\n Val@Epoch: %d' % epoch) 236 | model.eval() 237 | torch.cuda.empty_cache() 238 | dices = [] 239 | ious = [] 240 | metrics = np.zeros((2, classes)) 241 | with torch.no_grad(): 242 | 243 | for inputs in val_loader: 244 | inputs['image'] = inputs['image'].to(cfg.device).float() 245 | 246 | tic = time.perf_counter() 247 | output = model(inputs['image']) 248 | if 'bise' in cfg.arch: 249 | output = F.softmax(output, dim=1).cpu().numpy() 250 | elif cfg.arch in [ 251 | 'mb_rf', 'lwa', 'tpnet', 'densest', 'spnet', 252 | 'gdst', 'res50_rf' 253 | ]: 254 | output = F.interpolate(output, scale_factor=4 // rate) 255 | output = F.softmax(output, dim=1) 256 | output = F.one_hot(torch.argmax(output, dim=1), 257 | num_classes=classes + 1).permute( 258 | 0, 3, 1, 2) 259 | output = output.cpu().numpy() 260 | elif cfg.arch in ['unet', 'tunet']: 261 | output = F.softmax(output, dim=1).cpu().numpy() 262 | 263 | 264 | # output = output.cpu().numpy() 265 | else: 266 | 267 | raise NotImplementedError 268 | duration = time.perf_counter() - tic 269 | dice = compute_dice(output, 270 | inputs['label'].numpy(), 271 | return_all=True) 272 | iou = compute_iou(output, 273 | inputs['label'].numpy(), 274 | return_all=True) 275 | dices.append(dice) 276 | ious.append(iou) 277 | dices = np.array(dices) 278 | ious = np.array(ious) 279 | for i in range(classes): 280 | metrics[0, i] = np.mean(dices[:, i][dices[:, i] > -1]) 281 | metrics[1, i] = np.mean(ious[:, i][ious[:, i] > -1]) 282 | print(metrics) 283 | dc, jc = [ 284 | np.mean(metrics[i][metrics[i] > -1]) for i in range(2) 285 | ] 286 | print('Dice:{:.4f} IoU:{:.4f} Time:{:.4f}'.format( 287 | dc, jc, duration)) 288 | summary_writer.add_scalar('Dice/Fold{}'.format(fold), dc, epoch) 289 | summary_writer.add_scalar('IoU/Fold{}'.format(fold), jc, epoch) 290 | return dc 291 | 292 | print('Starting training...') 293 | best = 0 294 | best_ep = 0 295 | model = model.to(cfg.device) 296 | 297 | if cfg.arch in ['densest', 'gdst', 'bgdst']: 298 | mem_path = os.path.join(cfg.root_dir, cfg.dataset, cfg.data_tag, 299 | 'spnet', 'ckpt', 'fold{}'.format(fold), 300 | 'checkpoint.t7') 301 | cfg.load_model = os.path.join(cfg.root_dir, cfg.dataset, 302 | cfg.data_tag, cfg.pre_name, 'ckpt', 303 | 'fold{}'.format(fold), 304 | 'checkpoint.t7') 305 | model = load_model(model, mem_path, False) 306 | model = load_model(model, cfg.load_model, False) 307 | model.encoder = load_model(model.encoder, cfg.load_model, False) 308 | model.decoder = load_model(model.decoder, 
cfg.load_model, False) 309 | 310 | if cfg.arch in ['tpnet', 'spnet', 'densest']: 311 | if cfg.need_pretrain: 312 | if cfg.freeze_name is None: 313 | if cfg.arch == 'tpnet': 314 | cfg.freeze_name = ['lstm'] 315 | elif cfg.arch == 'spnet': 316 | cfg.freeze_name = ['memory'] 317 | else: 318 | cfg.freeze_name = ['non_local'] 319 | else: 320 | cfg.freeze_name = cfg.freeze_name.split(',') 321 | train_loader = torch.utils.data.DataLoader( 322 | train_dataset, 323 | batch_size=cfg.batch_size * 2, 324 | shuffle=True, 325 | num_workers=cfg.num_workers, 326 | pin_memory=True, 327 | drop_last=True) 328 | for name, param in model.named_parameters(): 329 | if not name.split('.')[0] in cfg.freeze_name: 330 | param.requires_grad = False 331 | else: 332 | print('{} NOT Freeze'.format(name)) 333 | 334 | cfg.load_model = os.path.join(cfg.root_dir, cfg.dataset, 335 | cfg.data_tag, cfg.pre_name, 336 | 'ckpt', 'fold{}'.format(fold), 337 | 'checkpoint.t7') 338 | assert os.path.exists(cfg.load_model) 339 | if cfg.arch in ['tpnet', 'spnet']: 340 | model.encoder = load_model(model.encoder, cfg.load_model) 341 | model.decoder = load_model(model.decoder, cfg.load_model) 342 | print('Pretrain for {} epochs and save the best weight'.format( 343 | cfg.pretrain_ep)) 344 | for epoch in range(1, cfg.pretrain_ep + 1): 345 | train(epoch) 346 | save_map = val_map(epoch) 347 | if save_map > best: 348 | best = save_map 349 | print(saver.save(model.state_dict(), 'stage1')) 350 | 351 | print( 352 | 'Finished Pretraining, reduce lr to a half and load the best weight' 353 | ) 354 | optimizer = torch.optim.Adam(model.parameters(), 355 | cfg.lr / cfg.decay) 356 | cfg.load_model = os.path.join(cfg.ckpt_dir, 'stage1.t7') 357 | model = load_model(model, cfg.load_model) 358 | for name, param in model.named_parameters(): 359 | param.requires_grad = True 360 | best_ep = cfg.pretrain_ep 361 | best = 0 362 | 363 | else: 364 | cfg.load_model = os.path.join(cfg.ckpt_dir, 'stage1.t7') 365 | assert os.path.exists(cfg.load_model) 366 | model = load_model(model, cfg.load_model) 367 | optimizer = torch.optim.Adam(model.parameters(), 368 | cfg.lr / cfg.decay) 369 | best_ep = cfg.pretrain_ep 370 | best = 0 371 | 372 | train_loader = torch.utils.data.DataLoader(train_dataset, 373 | batch_size=cfg.batch_size, 374 | shuffle=True, 375 | num_workers=cfg.num_workers, 376 | pin_memory=True, 377 | drop_last=True) 378 | if cfg.reset: 379 | cfg.reset_path = os.path.join(cfg.root_dir, cfg.dataset, 380 | cfg.data_tag, 'res50_rf', 'ckpt', 381 | 'fold{}'.format(fold), 382 | 'checkpoint.t7') 383 | model = load_model(model, cfg.reset_path) 384 | best = val_map(cfg.reset_epoch) 385 | best_ep = cfg.reset_epoch 386 | 387 | for epoch in range(best_ep + 1, cfg.num_epochs + 1): 388 | train(epoch) 389 | if cfg.val_interval > 0 and epoch % cfg.val_interval == 0: 390 | save_map = val_map(epoch) 391 | if save_map > best: 392 | best = save_map 393 | best_ep = epoch 394 | print(saver.save(model.state_dict(), 'checkpoint')) 395 | else: 396 | if epoch - best_ep > 30: 397 | break 398 | print(saver.save(model.state_dict(), 'latestcheckpoint')) 399 | summary_writer.close() 400 | 401 | if __name__ == '__main__': 402 | with DisablePrint(local_rank=cfg.local_rank): 403 | main() 404 | --------------------------------------------------------------------------------
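The memory read in net/Ours/Module.py (KeyValue, MemoryCore, Memory) is easiest to follow with concrete shapes: keys are C//4-dimensional, values C//2-dimensional, attention is computed between all T*H*W memory positions and the H*W query positions, and the retrieved memory value is concatenated onto the query value. Below is a minimal shape-check sketch, assuming the repository root is on PYTHONPATH; the tensor sizes (B=2, T=4, C=256, H=16, W=20) are dummies chosen only for illustration, not values used in training.

import torch
from net.Ours.Module import Memory

# dummy sizes for a quick shape check only
B, T, C, H, W = 2, 4, 256, 16, 20
memory_feats = torch.randn(B, T, C, H, W)   # features of T past frames
query_feats = torch.randn(B, C, H, W)       # features of the current frame

reader = Memory(c=C)                        # KeyValue maps C -> C//4 keys and C//2 values
reader.eval()                               # disable the dropout inside MemoryCore
mem_out, attn = reader(memory_feats, query_feats)

print(mem_out.shape)  # (B, C, H, W): C//2 retrieved memory concatenated with the C//2 query value
print(attn.shape)     # (B, T*H*W, H*W): per query position, a softmax over all memory positions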