├── gitignore
├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── MSRF_net.iml
├── images
│   └── MSRF-NET.jpg
├── metrics.py
├── README.md
├── test.py
├── demo.py
├── utils.py
├── dataset.py
├── train.py
├── losses.py
└── msrf.py

/gitignore:
--------------------------------------------------------------------------------
# Ignore log folders
logs/**
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
--------------------------------------------------------------------------------
/images/MSRF-NET.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amlarraz/MSRF-Net_PyTorch/HEAD/images/MSRF-NET.jpg
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/MSRF_net.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

from losses import one_hot


def calculate_dice(pred, msk, eps=1e-6):
    # compute softmax over the classes axis; the background channel (index 0) is excluded
    input_soft = F.softmax(pred, dim=1)[:, 1:]

    # create the labels one-hot tensor, excluding the background channel as well
    target_one_hot = one_hot(msk, num_classes=pred.shape[1],
                             device=pred.device, dtype=pred.dtype)[:, 1:]

    # compute the actual dice score
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    cardinality = torch.sum(input_soft + target_one_hot, dims)

    dice_score = 2. * intersection / (cardinality + eps)

    return torch.mean(dice_score)
--------------------------------------------------------------------------------
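A quick usage sketch for `calculate_dice` (the tensors below are illustrative stand-ins; the shapes follow the function body: raw logits of shape `(N, C, H, W)`, integer labels of shape `(N, H, W)`, background channel excluded from the score):

```python
import torch
from metrics import calculate_dice

pred = torch.randn(2, 5, 64, 64)        # logits for 5 classes (incl. background)
msk = torch.randint(0, 5, (2, 64, 64))  # integer ground-truth labels
print(calculate_dice(pred, msk))        # scalar tensor: mean Dice over classes 1..4
```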
/README.md:
--------------------------------------------------------------------------------
# MSRF-Net_PyTorch

Unofficial PyTorch implementation of [MSRF-Net](https://arxiv.org/pdf/2105.07451.pdf)

`------------ IN PROGRESS ------------`

- [x] Write the model code based on the [official TF code](https://github.com/NoviceMAn-prog/MSRF-Net).
- [x] Write the training/evaluation code.
- [x] Improve the training/evaluation code by adding extra logging to TensorBoard.
- [x] Write the test code.
- [x] Write the inference (demo) code.
- [ ] Train and test.

## Implementation details

- PyTorch 1.9.0 was used with CUDA 11.1.
- The hyperparameter `init_feat` was added. It controls the number of channels in the first encoder stage of the U-Net; in the original code it was 32. I recommend using a power of two because of the reduction ratio in the [Squeeze-and-Excitation blocks](https://arxiv.org/abs/1709.01507) (see the sketch below).
- The Shape Stream is not copied exactly from the official code; it was adapted from the [original Shape Stream repo](https://github.com/leftthomas/GatedSCNN).
- Added image visualization during training to TensorBoard. This helps you check qualitative performance while the model trains.
- ~~During training, the DICE coefficient (in the loss and as a metric) is computed without the background.~~
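A minimal instantiation sketch (it mirrors the `__main__` block of `msrf.py`; the single-channel 128×128 input is just an example):

```python
import torch
from msrf import MSRF

model = MSRF(in_ch=1, n_classes=3, init_feat=32)  # init_feat: channels of the first stage
x = torch.randn(2, 1, 128, 128)                   # grayscale image batch
canny = torch.randn(2, 1, 128, 128)               # edge map fed to the Shape Stream
pred_3, canny_gate, pred_1, pred_2 = model(x, canny)
```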

## Model architecture

![MSRF-NET diagram](./images/MSRF-NET.jpg)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

from dataset import DataSet
from msrf import MSRF
from losses import CombinedLoss
from metrics import calculate_dice


#### IN PROGRESS - IT IS NOT FUNCTIONAL YET #########################
data_dir = '/media/poto/Gordo1/SegThor'
checkpoint = './logs/SegThor-22_12_2021-15h3m46s/ep-8-val_loss-3.6380-val_dice-0.0406.pt'
n_classes = 5
resize = (256, 256)
batch_size = 3
init_feat = 32  # In the original code it was 32
device = torch.device('cuda:0')


dataset_test = DataSet(data_dir, n_classes, mode='test', augmentation=False, resize=resize)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4,
                             pin_memory=torch.cuda.is_available())

model = MSRF(in_ch=1, n_classes=n_classes, init_feat=init_feat)
# Load the trained weights; save_checkpoint stores them under the 'model' key
state = torch.load(checkpoint, map_location=device)
model.load_state_dict(state['model'])
model.to(device)
model.eval()

class_weights = dataset_test.class_weights().to(device)
criterion = CombinedLoss(class_weights)

tq = tqdm(total=len(dataloader_test)*batch_size, position=0, leave=True)
tq.set_description('Testing:')
metrics = {'test_loss': [], 'test_dice': []}
for i, (img, canny, msk, canny_label) in enumerate(dataloader_test):
    img, canny, msk, canny_label = img.to(device), canny.to(device), msk.to(device), canny_label.to(device)
    with torch.no_grad():
        pred_3, pred_canny, pred_1, pred_2 = model(img, canny)
        loss = criterion(pred_3, pred_canny, pred_1, pred_2, msk, canny_label)
    metrics['test_loss'].append(loss.item())
    dice = calculate_dice(pred_3, msk)
    metrics['test_dice'].append(dice.item())
    tq.update(batch_size)
tq.close()

print('Checkpoint: {}'.format(checkpoint))
print('Test loss: {}, test dice: {}'.format(np.mean(metrics['test_loss']), np.mean(metrics['test_dice'])))
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F

from dataset import normalize
from msrf import MSRF


#### IN PROGRESS - IT IS NOT FUNCTIONAL YET #########################
data_dir = '/media/poto/Gordo1/SegThor/Images'
save_dir = '/media/poto/Gordo1/SegThor/Inferences'
checkpoint = './logs/SegThor-22_12_2021-15h3m46s/ep-8-val_loss-3.6380-val_dice-0.0406.pt'
n_classes = 5
threshold = None  # None, or a float used to threshold the class probabilities
resize = (256, 256)
init_feat = 32  # In the original code it was 32
device = torch.device('cuda:0')


if not os.path.isdir(save_dir):
    os.mkdir(save_dir)

model = MSRF(in_ch=1, n_classes=n_classes, init_feat=init_feat)
# Load the trained weights; save_checkpoint stores them under the 'model' key
state = torch.load(checkpoint, map_location=device)
model.load_state_dict(state['model'])
model.to(device)
model.eval()

image_list = os.listdir(data_dir)
tq = tqdm(total=len(image_list), position=0, leave=True)
tq.set_description('Inference:')
for img_name in image_list:
    img = cv2.imread(os.path.join(data_dir, img_name), 0)
    if resize is not None:
        img = cv2.resize(img, resize, interpolation=cv2.INTER_CUBIC)
    canny = cv2.Canny(img, 10, 100)
    canny = np.asarray(canny, np.float32)
    canny /= 255.0
    img = normalize(img)
    img, canny = torch.FloatTensor(img).unsqueeze(0).unsqueeze(0).to(device), torch.FloatTensor(canny).unsqueeze(0).unsqueeze(0).to(device)
    with torch.no_grad():
        pred_3, pred_canny, pred_1, pred_2 = model(img, canny)
    pred_3 = F.softmax(pred_3, dim=1)[0]
    if threshold is not None:
        final_pred = torch.zeros_like(pred_3[0])
        for n_class in range(1, pred_3.shape[0]):
            final_pred[pred_3[n_class] >= threshold] = n_class
    else:
        final_pred = torch.argmax(pred_3, dim=0)
    cv2.imwrite(os.path.join(save_dir, img_name), final_pred.detach().cpu().numpy()*(255//n_classes))
    tq.update(1)
tq.close()
--------------------------------------------------------------------------------
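The saved predictions are class indices scaled by `255 // n_classes` so the classes are visible as distinct gray levels. A sketch of recovering the integer labels from a saved image (the filename is hypothetical, and the scaling factor assumes `n_classes = 5` as above):

```python
import cv2

label_img = cv2.imread('prediction.png', 0)  # hypothetical output of demo.py
labels = label_img // (255 // 5)             # undo the gray-level scaling
```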
/utils.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from time import localtime
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


def prepare_writer(dataset_name):

    log_name = '{}-{}_{}_{}-{}h{}m{}s'.format(dataset_name,
                                              localtime().tm_mday, localtime().tm_mon,
                                              localtime().tm_year, localtime().tm_hour,
                                              localtime().tm_min, localtime().tm_sec)
    if not os.path.exists('./logs'):
        os.mkdir('./logs')
    if not os.path.exists(os.path.join('./logs', log_name)):
        os.mkdir(os.path.join('./logs', log_name))

    writer = SummaryWriter(os.path.join('./logs', log_name))

    return writer, os.path.join('./logs', log_name)


def imgs2tb(img, msk, pred, canny, pred_canny, writer, n_img, epoch):
    n_classes = pred.shape[1]
    # It shows only the 1st img in the batch
    # Prepare data: undo the normalization (img = img*std + mean)
    img = img[0, 0].detach().cpu().numpy()
    img *= np.ones(img.shape) * 0.151  # 8-bit std
    img += np.ones(img.shape) * 0.175  # 8-bit mean
    msk = msk[0].detach().cpu().numpy()*(255//n_classes)  # to show in bright colors
    pred = torch.argmax(F.softmax(pred[0], dim=0), dim=0).detach().cpu().numpy()*(255//n_classes)  # No threshold -> argmax

    canny = canny[0, 0].detach().cpu().numpy()*50
    pred_canny = pred_canny[0, 0].detach().cpu().numpy()*50

    final_img = np.concatenate([img, msk, pred, canny, pred_canny], axis=1)
    writer.add_image('Image {}'.format(n_img), final_img, epoch, dataformats='HW')

    return None


def save_checkpoint(model, optimizer, save_dir, epoch, val_metrics):
    file_name = 'ep-{}'.format(epoch + 1)
    for key in val_metrics.keys():
        if len(key.split('_')) < 3:
            file_name += '-{}-{:.4f}'.format(key, val_metrics[key])

    save_states = {'model': model.state_dict(),
                   'optimizer': optimizer.state_dict(),
                   'epoch': epoch}

    torch.save(save_states, os.path.join(save_dir, file_name + '.pt'))

    return None


def load_checkpoint(model, optimizer, checkpoint_path, model_name):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print('Checkpoint for model {} and optimizer loaded from {}, epoch: {}'
          .format(model_name, checkpoint_path, checkpoint['epoch']))

    return model, optimizer
--------------------------------------------------------------------------------
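For example, `save_checkpoint(model, optimizer, logdir, 7, {'val_loss': 3.6380, 'val_dice': 0.0406})` writes `<logdir>/ep-8-val_loss-3.6380-val_dice-0.0406.pt`; this is the naming scheme behind the checkpoint paths hard-coded in `test.py` and `demo.py`.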
/dataset.py:
--------------------------------------------------------------------------------
import os
import cv2
import random
import numpy as np
import albumentations as albu
from collections import defaultdict
import torch
from torch.utils.data import Dataset

random.seed(42)


def normalize(img):
    if img.dtype == np.uint8:
        mean = 0.175  # Mean / max_pixel_value
        std = 0.151  # Std / max_pixel_value
        max_pixel_value = 255.0

    elif img.dtype == np.uint16:
        mean = 0.0575716
        std = 0.12446098
        max_pixel_value = 65535.0

    img = img.astype(np.float32) / max_pixel_value
    img -= np.ones(img.shape) * mean
    img /= np.ones(img.shape) * std

    return img


class DataSet(Dataset):
    def __init__(self, data_dir, n_classes, mode='train', augmentation=True, resize=None):
        """data_dir must be organized as:
           - Images: folder that contains all the images (.png) in the dataset.
           - Masks: folder that contains all the masks (.png) in the dataset.
        """
        self.data_dir = data_dir
        self.n_classes = n_classes
        self.mode = mode
        self.augmentation = augmentation
        self.resize = resize
        percents = {'train': 0.75, 'val': 0.15, 'test': 0.1}
        assert mode in percents.keys(), 'Mode is {} and it must be one of: train, val, test'.format(self.mode)
        # Sort the listing so the train/val/test split is deterministic across runs
        total_imgs = sorted(os.listdir(os.path.join(data_dir, 'Images')))
        if self.mode == 'train':
            self.img_names = total_imgs[:int(percents['train']*len(total_imgs))]
        elif self.mode == 'val':
            self.img_names = total_imgs[int(percents['train']*len(total_imgs)):int((percents['train']+percents['val'])*len(total_imgs))]
        elif self.mode == 'test':
            self.img_names = total_imgs[int((percents['train']+percents['val'])*len(total_imgs)):]

        if self.augmentation:
            self.augs = albu.OneOf([albu.ElasticTransform(p=0.5, alpha=120, sigma=280 * 0.05, alpha_affine=120 * 0.03),
                                    albu.GridDistortion(p=0.5, border_mode=cv2.BORDER_CONSTANT, distort_limit=0.2),
                                    albu.Rotate(p=0.5, limit=(-5, 5), interpolation=1, border_mode=cv2.BORDER_CONSTANT),
                                    ],)

    def __len__(self):
        return len(self.img_names)

    def class_weights(self):
        counts = defaultdict(int)
        for img_name in self.img_names:
            msk = cv2.imread(os.path.join(self.data_dir, 'Masks', img_name), -1)
            for c in range(self.n_classes):
                counts[c] += np.sum(msk == c)
        counts = dict(sorted(counts.items()))
        weights = [1 - (x/sum(list(counts.values()))) for x in counts.values()]

        return torch.FloatTensor(weights)

    def __getitem__(self, idx):
        img = cv2.imread(os.path.join(self.data_dir, 'Images', self.img_names[idx]), 0)
        msk = cv2.imread(os.path.join(self.data_dir, 'Masks', self.img_names[idx]), 0)
        if self.resize is not None:
            img = cv2.resize(img, self.resize, interpolation=cv2.INTER_CUBIC)
            msk = cv2.resize(msk, self.resize, interpolation=cv2.INTER_NEAREST)

        if self.augmentation:
            augmented = self.augs(image=img, mask=msk)
            img = augmented['image']
            msk = augmented['mask']

        canny = cv2.Canny(img, 10, 100)
        canny = np.asarray(canny, np.float32)
        canny /= 255.0

        img = normalize(img)

        return torch.FloatTensor(img).unsqueeze(0), torch.FloatTensor(canny).unsqueeze(0), torch.LongTensor(msk), torch.FloatTensor(canny)


if __name__ == "__main__":
    dataset = DataSet('/media/poto/Gordo1/SegThor', 2, 'train', True)
    img, canny, msk, canny_label = dataset[0]
    print(img.shape, img.min(), img.max())
    print(canny.shape, canny.min(), canny.max())
    print(msk.shape)
--------------------------------------------------------------------------------
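As a worked example of the 0.75/0.15/0.1 split: with 1000 images, the first 750 files of the sorted listing go to train, the next 150 to val, and the last 100 to test.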
/train.py:
--------------------------------------------------------------------------------
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from dataset import DataSet

from msrf import MSRF
from utils import prepare_writer, save_checkpoint, imgs2tb
from losses import CombinedLoss
from metrics import calculate_dice

# TRAIN PARAMS
dataset_name = 'SegThor'
data_dir = '/media/poto/Gordo1/SegThor'
n_classes = 5
n_img_to_tb = 5
resize = (256, 256)  # None or 2-tuple

n_epochs = 100
batch_size = 3
lr = 1e-4
accumulation_steps = 6
weight_decay = 0.01
device = torch.device('cuda:0')

init_feat = 32  # In the original code it was 32

# DATASET
dataset_train = DataSet(data_dir, n_classes, mode='train', augmentation=True, resize=resize)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4,
                              pin_memory=torch.cuda.is_available())

dataset_val = DataSet(data_dir, n_classes, mode='val', augmentation=False, resize=resize)
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=4,
                            pin_memory=torch.cuda.is_available())

# MODEL, OPTIM, LR_SCHED, LOSS, LOG
model = MSRF(in_ch=1, n_classes=n_classes, init_feat=init_feat)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8,
                             weight_decay=weight_decay)
class_weights = dataset_train.class_weights().to(device)
criterion = CombinedLoss(class_weights)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                          patience=10, verbose=False)
writer, logdir = prepare_writer(dataset_name)

print('Logdir: {}'.format(logdir))
# TRAIN LOOP
model.to(device)
for epoch in range(1, n_epochs+1):
    model.train()
    metrics = {'train_loss': [], 'train_dice': [], 'val_loss': [], 'val_dice': []}
    tq = tqdm(total=len(dataloader_train)*batch_size, position=0, leave=True)
    tq.set_description('Train epoch: {}'.format(epoch))

    for i, (img, canny, msk, canny_label) in enumerate(dataloader_train):
        img, canny, msk, canny_label = img.to(device), canny.to(device), msk.to(device), canny_label.to(device)

        pred_3, pred_canny, pred_1, pred_2 = model(img, canny)
        # Forward + Backward + Optimize
        loss = criterion(pred_3, pred_canny, pred_1, pred_2, msk, canny_label)
        metrics['train_loss'].append(loss.item())  # log the unscaled loss
        loss = loss/accumulation_steps
        loss.backward()
        # accumulative gradient
        if (i + 1) % accumulation_steps == 0:  # Wait for several backward steps
            optimizer.step()       # Now we can do an optimizer step
            optimizer.zero_grad()  # Reset gradient tensors

        dice = calculate_dice(pred_3, msk)
        metrics['train_dice'].append(dice.item())
        tq.update(batch_size)
    tq.close()

    print('Epoch {}: train loss: {}, train dice: {}'.format(epoch, np.mean(metrics['train_loss']), np.mean(metrics['train_dice'])))
    writer.add_scalar('train loss', np.mean(metrics['train_loss']), epoch)
    writer.add_scalar('train dice', np.mean(metrics['train_dice']), epoch)

    model.eval()
    tq = tqdm(total=len(dataloader_val) * batch_size, position=0, leave=True)
    tq.set_description('Val epoch: {}'.format(epoch))
    k = 0
    for i, (img, canny, msk, canny_label) in enumerate(dataloader_val):
        img, canny, msk, canny_label = img.to(device), canny.to(device), msk.to(device), canny_label.to(device)
        with torch.no_grad():
            pred_3, pred_canny, pred_1, pred_2 = model(img, canny)
            loss = criterion(pred_3, pred_canny, pred_1, pred_2, msk, canny_label)
        metrics['val_loss'].append(loss.item())
        dice = calculate_dice(pred_3, msk)
        metrics['val_dice'].append(dice.item())
        tq.update(batch_size)

        if k < n_img_to_tb:
            imgs2tb(img, msk, pred_3, canny, pred_canny, writer, k, epoch+1)
            k += 1
    tq.close()

    print('Epoch {}: val loss: {}, val dice: {}'.format(epoch, np.mean(metrics['val_loss']), np.mean(metrics['val_dice'])))
    writer.add_scalar('val loss', np.mean(metrics['val_loss']), epoch)
    writer.add_scalar('val dice', np.mean(metrics['val_dice']), epoch)

    # ReduceLROnPlateau needs the monitored metric every epoch
    lr_scheduler.step(np.mean(metrics['val_loss']))

    save_checkpoint(model, optimizer, logdir, epoch, {'val_loss': np.mean(metrics['val_loss']), 'val_dice': np.mean(metrics['val_dice'])})
--------------------------------------------------------------------------------
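With `batch_size = 3` and `accumulation_steps = 6`, the optimizer sees gradients accumulated over an effective batch of 3 × 6 = 18 images; this is why the loss is divided by `accumulation_steps` before `backward()`.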
/losses.py:
--------------------------------------------------------------------------------
"""
CODE FORKED FROM:
https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/losses/
"""
import torch
from torch import nn
import torch.nn.functional as F


def one_hot(labels, num_classes, device, dtype, eps=1e-6):
    r"""Converts an integer label tensor of shape :math:`(N, H, W)` to a
    one-hot tensor of shape :math:`(N, C, H, W)`.

    Args:
        labels (torch.Tensor): tensor with labels of shape :math:`(N, H, W)`,
            where N is the batch size. Each value is an integer
            representing the correct classification.
        num_classes (int): number of classes in labels.
        device (Optional[torch.device]): the desired device of returned tensor.
            Default: if None, uses the current device for the default tensor type
            (see torch.set_default_tensor_type()). device will be the CPU for CPU
            tensor types and the current CUDA device for CUDA tensor types.
        dtype (Optional[torch.dtype]): the desired data type of returned
            tensor. Default: if None, infers data type from values.

    Returns:
        torch.Tensor: the labels in one hot tensor.

    Examples::
        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
        >>> tgm.losses.one_hot(labels, num_classes=3)
        tensor([[[[1., 0.],
                  [0., 1.]],
                 [[0., 1.],
                  [0., 0.]],
                 [[0., 0.],
                  [1., 0.]]]])
    """
    if not torch.is_tensor(labels):
        raise TypeError("Input labels type is not a torch.Tensor. Got {}"
                        .format(type(labels)))
    if not len(labels.shape) == 3:
        raise ValueError("Invalid depth shape, we expect BxHxW. Got: {}"
                         .format(labels.shape))
    if not labels.dtype == torch.int64:
        raise ValueError(
            "labels must be of the same dtype torch.int64. Got: {}".format(
                labels.dtype))
    if num_classes < 1:
        raise ValueError("The number of classes must be bigger than one."
                         " Got: {}".format(num_classes))
    batch_size, height, width = labels.shape
    one_hot = torch.zeros(batch_size, num_classes, height, width,
                          device=device, dtype=dtype)
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps

class DiceLoss(nn.Module):
    r"""Criterion that computes the Sørensen-Dice coefficient loss.

    According to [1], we compute the Sørensen-Dice coefficient as follows:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

    where:
        - :math:`X` expects to be the scores of each class.
        - :math:`Y` expects to be the one-hot tensor with the class labels.

    The loss is finally computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Shape:
        - Input: :math:`(N, C, H, W)` where C = number of classes.
        - Target: :math:`(N, H, W)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.

    Examples:
        >>> N = 5  # num_classes
        >>> loss = tgm.losses.DiceLoss()
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    def __init__(self, weights) -> None:
        super(DiceLoss, self).__init__()
        self.eps: float = 1e-6
        self.weights = weights

    def forward(
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                             .format(input.shape))
        if not input.shape[-2:] == target.shape[-2:]:
            raise ValueError("input and target shapes must be the same. Got: {} and {}"
                             .format(input.shape, target.shape))
        if not input.device == target.device:
            raise ValueError(
                "input and target must be in the same device. Got: {} and {}".format(
                    input.device, target.device))
        # compute softmax over the classes axis, weighting each class channel
        input_soft = F.softmax(input, dim=1)*self.weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  #[:, 1:]

        # create the labels one hot tensor
        target_one_hot = one_hot(target, num_classes=input.shape[1],
                                 device=input.device, dtype=input.dtype)  #[:, 1:]

        # compute the actual dice score
        dims = (1, 2, 3)
        intersection = torch.sum(input_soft * target_one_hot, dims)
        cardinality = torch.sum(input_soft + target_one_hot, dims)

        dice_score = 2. * intersection / (cardinality + self.eps)
        return torch.mean(1. - dice_score)


class CombinedLoss(nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(class_weights)
        self.dice_loss = DiceLoss(class_weights)
        # The shape stream output is already sigmoided, so BCE is applied to probabilities
        self.bce_loss = nn.BCELoss()

    def forward(self, pred_3, pred_canny, pred_1, pred_2, msk, canny_label):
        loss_pred_3 = self.ce_loss(pred_3, msk) + self.dice_loss(pred_3, msk)
        loss_pred_1 = self.ce_loss(pred_1, msk) + self.dice_loss(pred_1, msk)
        loss_pred_2 = self.ce_loss(pred_2, msk) + self.dice_loss(pred_2, msk)
        loss_canny = self.bce_loss(pred_canny, canny_label.unsqueeze(1))
        loss = loss_pred_3 + loss_pred_1 + loss_pred_2 + loss_canny

        return loss
--------------------------------------------------------------------------------
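A minimal shape sketch for `CombinedLoss` (sizes mirror what `train.py` feeds it; the uniform class weights and random tensors are stand-ins):

```python
import torch
from losses import CombinedLoss

n_classes, H, W = 5, 64, 64
criterion = CombinedLoss(torch.ones(n_classes))             # uniform weights for the sketch
preds = [torch.randn(2, n_classes, H, W) for _ in range(3)] # pred_3, pred_1, pred_2 logits
pred_canny = torch.rand(2, 1, H, W)                         # shape-stream output in [0, 1]
msk = torch.randint(0, n_classes, (2, H, W))                # integer masks
canny_label = torch.rand(2, H, W)                           # edge targets in [0, 1]
loss = criterion(preds[0], pred_canny, preds[1], preds[2], msk, canny_label)
```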
/msrf.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models.resnet import BasicBlock


# BLOCKS to construct the model
class DSDF_block(nn.Module):  # OK
    def __init__(self, in_ch_x, in_ch_y, nf1=128, nf2=256, gc=64, bias=True):
        super().__init__()

        self.nx1 = nn.Sequential(nn.Conv2d(in_ch_x, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

        self.ny1 = nn.Sequential(nn.Conv2d(in_ch_y, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

        self.nx1c = nn.Sequential(nn.Conv2d(in_ch_x, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias),  # ks 3 -> 4, stride 1 -> 2
                                  nn.LeakyReLU(negative_slope=0.25))

        self.ny1t = nn.Sequential(nn.ConvTranspose2d(in_ch_y, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias),  # ks 3 -> 4
                                  nn.LeakyReLU(negative_slope=0.25))

        self.nx2 = nn.Sequential(nn.Conv2d(in_ch_x+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

        self.ny2 = nn.Sequential(nn.Conv2d(in_ch_y+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

        self.nx2c = nn.Sequential(nn.Conv2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias),  # ks 3 -> 4, stride 1 -> 2
                                  nn.LeakyReLU(negative_slope=0.25))

        self.ny2t = nn.Sequential(nn.ConvTranspose2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias),  # ks 3 -> 4
                                  nn.LeakyReLU(negative_slope=0.25))

        self.nx3 = nn.Sequential(nn.Conv2d(in_ch_x+gc+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

        self.ny3 = nn.Sequential(nn.Conv2d(in_ch_y+gc+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

        self.nx3c = nn.Sequential(nn.Conv2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias),  # ks 3 -> 4, stride 1 -> 2
                                  nn.LeakyReLU(negative_slope=0.25))

        self.ny3t = nn.Sequential(nn.ConvTranspose2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias),  # ks 3 -> 4
                                  nn.LeakyReLU(negative_slope=0.25))

        self.nx4 = nn.Sequential(nn.Conv2d(in_ch_x+gc+gc+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

        self.ny4 = nn.Sequential(nn.Conv2d(in_ch_y+gc+gc+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

        self.nx4c = nn.Sequential(nn.Conv2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias),  # ks 3 -> 4, stride 1 -> 2
                                  nn.LeakyReLU(negative_slope=0.25))

        self.ny4t = nn.Sequential(nn.ConvTranspose2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias),  # ks 3 -> 4
                                  nn.LeakyReLU(negative_slope=0.25))

        self.nx5 = nn.Sequential(nn.Conv2d(in_ch_x+gc+gc+gc+gc+gc, nf1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

        self.ny5 = nn.Sequential(nn.Conv2d(in_ch_y+gc+gc+gc+gc+gc, nf2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
                                 nn.LeakyReLU(negative_slope=0.25))

    def forward(self, x, y):

        x1 = self.nx1(x)
        y1 = self.ny1(y)

        x1c = self.nx1c(x)
        y1t = self.ny1t(y)

        x2_input = torch.cat([x, x1, y1t], dim=1)
        x2 = self.nx2(x2_input)

        y2_input = torch.cat([y, y1, x1c], dim=1)
        y2 = self.ny2(y2_input)

        x2c = self.nx2c(x1)
        y2t = self.ny2t(y1)

        x3_input = torch.cat([x, x1, x2, y2t], dim=1)
        x3 = self.nx3(x3_input)

        y3_input = torch.cat([y, y1, y2, x2c], dim=1)
        y3 = self.ny3(y3_input)

        x3c = self.nx3c(x3)
        y3t = self.ny3t(y3)

        x4_input = torch.cat([x, x1, x2, x3, y3t], dim=1)
        x4 = self.nx4(x4_input)

        y4_input = torch.cat([y, y1, y2, y3, x3c], dim=1)
        y4 = self.ny4(y4_input)

        x4c = self.nx4c(x4)
        y4t = self.ny4t(y4)

        x5_input = torch.cat([x, x1, x2, x3, x4, y4t], dim=1)
        x5 = self.nx5(x5_input)

        y5_input = torch.cat([y, y1, y2, y3, y4, x4c], dim=1)
        y5 = self.ny5(y5_input)

        x5 *= 0.4
        y5 *= 0.4

        return x5+x, y5+y
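A quick shape check for `DSDF_block` (a hedged sketch: `x` is the higher-resolution stream and `y` the stream at half resolution, as in the MSRF subnetwork; `nf1`/`nf2` must match the input channels so the residual additions `x5+x`, `y5+y` are valid):

```python
import torch
from msrf import DSDF_block

block = DSDF_block(in_ch_x=32, in_ch_y=64, nf1=32, nf2=64, gc=16)
x = torch.randn(2, 32, 64, 64)   # high-resolution stream
y = torch.randn(2, 64, 32, 32)   # low-resolution stream (half spatial size)
x_out, y_out = block(x, y)
print(x_out.shape, y_out.shape)  # same shapes as the inputs
```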

class ATTENTION_block(nn.Module):  # OK
    def __init__(self, in_ch_x, in_ch_g, med_ch):
        super().__init__()
        self.theta = nn.Conv2d(in_ch_x, med_ch, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True)
        self.phi = nn.Conv2d(in_ch_g, med_ch, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
        self.block = nn.Sequential(nn.ReLU(),
                                   nn.Conv2d(med_ch, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
                                   nn.Sigmoid(),
                                   nn.ConvTranspose2d(1, 1, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True))
        self.batchnorm = nn.BatchNorm2d(in_ch_x)

    def forward(self, x, g):
        theta = self.theta(x) + self.phi(g)
        out = self.batchnorm(self.block(theta) * x)
        return out


class UP_block(nn.Module):  # OK
    def __init__(self, input_1_ch, input_2_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(input_2_ch, input_1_ch, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True)

    def forward(self, input_1, input_2):
        x = torch.cat([self.up(input_2), input_1], dim=1)
        return x


class SE_block(nn.Module):  # OK
    def __init__(self, in_ch, ratio=16):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(in_ch, in_ch//ratio),
                                   nn.ReLU(),
                                   nn.Linear(in_ch//ratio, in_ch),
                                   nn.Sigmoid())

    def forward(self, x):
        y = x.mean((-2, -1))
        y = self.block(y).unsqueeze(-1).unsqueeze(-1)
        return x*y


class SPATIALATT_block(nn.Module):  # OK
    def __init__(self, in_ch, med_ch):
        super().__init__()
        self.block = nn.Sequential(nn.Conv2d(in_ch, med_ch, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
                                   nn.BatchNorm2d(med_ch),
                                   nn.ReLU(),
                                   nn.Conv2d(med_ch, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
                                   nn.Sigmoid())

    def forward(self, x):
        x = self.block(x)

        return x


class RES_block(nn.Module):  # OK
    def __init__(self, in_ch):
        super().__init__()
        self.block = nn.Sequential(nn.Conv2d(in_ch, in_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
                                   nn.BatchNorm2d(in_ch),
                                   nn.ReLU(),
                                   nn.Conv2d(in_ch, in_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
                                   nn.BatchNorm2d(in_ch))
        self.act = nn.ReLU()

    def forward(self, x):
        res = self.block(x)
        out = self.act(res+x)

        return out


class DUALATT_block(nn.Module):  # OK
    def __init__(self, skip_in_ch, prev_in_ch, out_ch):
        super().__init__()
        self.prev_block = nn.Sequential(nn.ConvTranspose2d(prev_in_ch, out_ch, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
                                        nn.BatchNorm2d(out_ch),
                                        nn.ReLU())
        self.block = nn.Sequential(nn.Conv2d(skip_in_ch+out_ch, out_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
                                   nn.BatchNorm2d(out_ch),
                                   nn.ReLU())
        self.se_block = SE_block(out_ch, ratio=16)
        self.spatial_att = SPATIALATT_block(out_ch, out_ch)

    def forward(self, skip, prev):

        prev = self.prev_block(prev)
        x = torch.cat([skip, prev], dim=1)
        inpt_layer = self.block(x)
        se_out = self.se_block(inpt_layer)
        sab = self.spatial_att(inpt_layer) + 1

        return sab*se_out


class GSC_block(nn.Module):
    def __init__(self, in_ch_x, in_ch_y):
        super().__init__()
        self.block = nn.Sequential(nn.BatchNorm2d(in_ch_x+in_ch_y),
                                   nn.Conv2d(in_ch_x+in_ch_y, in_ch_x+in_ch_y+1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),  # in_ch -> out_ch
                                   nn.ReLU(),
                                   nn.Conv2d(in_ch_x+in_ch_y+1, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
                                   nn.BatchNorm2d(1),
                                   nn.Sigmoid())

    def forward(self, x, y):
        inpt = torch.cat([x, y], dim=1)
        inpt = self.block(inpt)

        return inpt
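A shape sketch for `ATTENTION_block` (hedged; it mirrors the block's use in the decoder, where `x` is a skip connection and `g` is the coarser gating signal at half the resolution):

```python
import torch
from msrf import ATTENTION_block

att = ATTENTION_block(in_ch_x=128, in_ch_g=256, med_ch=256)
x = torch.randn(2, 128, 32, 32)  # skip features
g = torch.randn(2, 256, 16, 16)  # gating signal at half resolution
print(att(x, g).shape)           # torch.Size([2, 128, 32, 32])
```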

## SHAPE-STREAM
# https://github.com/leftthomas/GatedSCNN/blob/master/model.py

class GatedConv(nn.Conv2d):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels, 1, bias=False)
        self.attention = nn.Sequential(
            nn.BatchNorm2d(in_channels + 1),
            nn.Conv2d(in_channels + 1, in_channels + 1, 1),
            nn.ReLU(),
            nn.Conv2d(in_channels + 1, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

    def forward(self, feat, gate):
        attention = self.attention(torch.cat((feat, gate), dim=1))
        out = F.conv2d(feat * (attention + 1), self.weight)
        return out


class ShapeStream(nn.Module):
    def __init__(self, init_feat):
        super().__init__()
        self.res2_conv = nn.Conv2d(init_feat * 2, 1, 1)
        self.res3_conv = nn.Conv2d(init_feat * 4, 1, 1)
        self.res4_conv = nn.Conv2d(init_feat * 8, 1, 1)
        self.res1 = BasicBlock(init_feat, init_feat, 1)
        self.res2 = BasicBlock(32, 32, 1)
        self.res3 = BasicBlock(16, 16, 1)
        self.res1_pre = nn.Conv2d(init_feat, 32, 1)
        self.res2_pre = nn.Conv2d(32, 16, 1)
        self.res3_pre = nn.Conv2d(16, 8, 1)
        self.gate1 = GatedConv(32, 32)
        self.gate2 = GatedConv(16, 16)
        self.gate3 = GatedConv(8, 8)
        self.gate = nn.Conv2d(8, 1, 1, bias=False)
        self.fuse = nn.Conv2d(2, 1, 1, bias=False)

    def forward(self, x, res2, res3, res4, grad):
        size = grad.size()[-2:]
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        res2 = F.interpolate(self.res2_conv(res2), size, mode='bilinear', align_corners=True)
        res3 = F.interpolate(self.res3_conv(res3), size, mode='bilinear', align_corners=True)
        res4 = F.interpolate(self.res4_conv(res4), size, mode='bilinear', align_corners=True)
        gate1 = self.gate1(self.res1_pre(self.res1(x)), res2)
        gate2 = self.gate2(self.res2_pre(self.res2(gate1)), res3)
        gate3 = self.gate3(self.res3_pre(self.res3(gate2)), res4)
        gate = torch.sigmoid(self.gate(gate3))
        feat = torch.sigmoid(self.fuse(torch.cat((gate, grad), dim=1)))
        return gate, feat
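A hedged shape sketch for `ShapeStream` (the inputs mirror `MSRF.forward`: the four MSRF stages at decreasing resolution plus the Canny map at full resolution):

```python
import torch
from msrf import ShapeStream

f = 32
stream = ShapeStream(init_feat=f)
x13 = torch.randn(2, f, 128, 128)    # stage 1 features
x23 = torch.randn(2, f * 2, 64, 64)  # stage 2 features
x33 = torch.randn(2, f * 4, 32, 32)  # stage 3 features
x43 = torch.randn(2, f * 8, 16, 16)  # stage 4 features
canny = torch.randn(2, 1, 128, 128)  # edge map
gate, feat = stream(x13, x23, x33, x43, canny)
print(gate.shape, feat.shape)        # both torch.Size([2, 1, 128, 128])
```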

# MODEL (NOT CHECKED)
class MSRF(nn.Module):
    def __init__(self, in_ch, n_classes, init_feat=32):
        super().__init__()

        # ENCODER ----------------------------
        self.n11 = nn.Sequential(nn.Conv2d(in_ch, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                 nn.ReLU(),
                                 nn.Conv2d(init_feat, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                 nn.ReLU(),
                                 nn.BatchNorm2d(init_feat),
                                 SE_block(init_feat, ratio=init_feat//2)
                                 )

        self.n21 = nn.Sequential(nn.MaxPool2d(kernel_size=(2, 2)),
                                 nn.Dropout(0.2),
                                 nn.Conv2d(init_feat, init_feat*2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                 nn.ReLU(),
                                 nn.Conv2d(init_feat*2, init_feat*2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                 nn.ReLU(),
                                 nn.BatchNorm2d(init_feat*2),
                                 SE_block(init_feat*2, ratio=init_feat//2))

        self.n31 = nn.Sequential(nn.MaxPool2d(kernel_size=(2, 2)),
                                 nn.Dropout(0.2),
                                 nn.Conv2d(init_feat*2, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                 nn.ReLU(),
                                 nn.Conv2d(init_feat*4, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                 nn.ReLU(),
                                 nn.BatchNorm2d(init_feat*4),
                                 SE_block(init_feat*4, ratio=init_feat//2))

        self.n41 = nn.Sequential(nn.MaxPool2d(kernel_size=(2, 2)),
                                 nn.Dropout(0.2),
                                 nn.Conv2d(init_feat*4, init_feat*8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                 nn.ReLU(),
                                 nn.Conv2d(init_feat*8, init_feat*8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                 nn.ReLU(),
                                 nn.BatchNorm2d(init_feat*8))

        # MSRF-subnetwork ----------------------------
        self.dsfs_1 = DSDF_block(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
        self.dsfs_2 = DSDF_block(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
        self.dsfs_3 = DSDF_block(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
        self.dsfs_4 = DSDF_block(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
        self.dsfs_5 = DSDF_block(init_feat*2, init_feat*4, nf1=init_feat*2, nf2=init_feat*4, gc=init_feat*2//2)
        self.dsfs_6 = DSDF_block(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
        self.dsfs_7 = DSDF_block(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
        self.dsfs_8 = DSDF_block(init_feat*2, init_feat*4, nf1=init_feat*2, nf2=init_feat*4, gc=init_feat*2//2)
        self.dsfs_9 = DSDF_block(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
        self.dsfs_10 = DSDF_block(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)

        # SHAPE STREAM ------------IN PROGRESS-------------------
        self.shape_stream = ShapeStream(init_feat)

        # DECODER
        # Stage 1:
        self.att_1 = ATTENTION_block(init_feat*4, init_feat*8, init_feat*8)
        self.up_1 = UP_block(init_feat*4, init_feat*8)
        self.dualatt_1 = DUALATT_block(init_feat*4, init_feat*8, init_feat*4)
        self.n34_t = nn.Conv2d(init_feat * 4 + init_feat * 8, init_feat * 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.dec_block_1 = nn.Sequential(nn.BatchNorm2d(init_feat*4),
                                         nn.ReLU(),
                                         nn.Conv2d(init_feat*4, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                         nn.BatchNorm2d(init_feat*4),
                                         nn.ReLU(),
                                         nn.Conv2d(init_feat*4, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                                         )
        self.head_dec_1 = nn.Sequential(nn.Conv2d(init_feat*4, n_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
                                        nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True))

        # Stage 2:
        self.att_2 = ATTENTION_block(init_feat * 2, init_feat * 4, init_feat * 2)
        self.up_2 = UP_block(init_feat * 2, init_feat * 4)
        self.dualatt_2 = DUALATT_block(init_feat * 2, init_feat * 4, init_feat * 2)
        self.n24_t = nn.Conv2d(init_feat * 2 + init_feat * 4, init_feat * 2, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.dec_block_2 = nn.Sequential(nn.BatchNorm2d(init_feat * 2),
                                         nn.ReLU(),
                                         nn.Conv2d(init_feat * 2, init_feat * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                         nn.BatchNorm2d(init_feat * 2),
                                         nn.ReLU(),
                                         nn.Conv2d(init_feat*2, init_feat * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                                         )
        self.head_dec_2 = nn.Sequential(nn.Conv2d(init_feat * 2, n_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
                                        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))

        # Stage 3:
        self.up_3 = nn.ConvTranspose2d(init_feat * 2, init_feat, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        self.n14_input = nn.Sequential(nn.Conv2d(init_feat + init_feat + 1, init_feat, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
                                       nn.ReLU())
        self.dec_block_3 = nn.Sequential(nn.Conv2d(init_feat, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                         nn.ReLU(),
                                         nn.BatchNorm2d(init_feat))

        self.head_dec_3 = nn.Sequential(nn.Conv2d(init_feat, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                        nn.ReLU(),
                                        nn.Conv2d(init_feat, n_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)))

    def forward(self, x, canny):
        # ENCODER:
        x11 = self.n11(x)
        x21 = self.n21(x11)
        x31 = self.n31(x21)
        x41 = self.n41(x31)

        # MSRF-subnetwork
        x12, x22 = self.dsfs_1(x11, x21)
        x32, x42 = self.dsfs_2(x31, x41)
        x12, x22 = self.dsfs_3(x12, x22)
        x32, x42 = self.dsfs_4(x32, x42)
        x22, x32 = self.dsfs_5(x22, x32)
        x13, x23 = self.dsfs_6(x12, x22)
        x33, x43 = self.dsfs_7(x32, x42)
        x23, x33 = self.dsfs_8(x23, x33)
        x13, x23 = self.dsfs_9(x13, x23)
        x33, x43 = self.dsfs_10(x33, x43)

        x13 = (x13*0.4) + x11
        x23 = (x23*0.4) + x21
        x33 = (x33*0.4) + x31
        x43 = (x43*0.4) + x41

        # SHAPE STREAM
        # https://github.com/leftthomas/GatedSCNN (https://arxiv.org/pdf/1907.05740.pdf)
        canny_gate, canny_feat = self.shape_stream(x13, x23, x33, x43, canny)

        # DECODER
        # Stage 1:
        x34_preinput = self.att_1(x33, x43)

        x34 = self.up_1(x34_preinput, x43)
        x34_t = self.dualatt_1(x33, x43)
        x34_t = torch.cat([x34, x34_t], dim=1)
        x34_t = self.n34_t(x34_t)
        x34 = self.dec_block_1(x34_t) + x34_t

        pred_1 = self.head_dec_1(x34)

        # Stage 2:
        x24_preinput = self.att_2(x23, x34)
        x24 = self.up_2(x24_preinput, x34)
        x24_t = self.dualatt_2(x23, x34)
        x24_t = torch.cat([x24, x24_t], dim=1)
        x24_t = self.n24_t(x24_t)
        x24 = self.dec_block_2(x24_t) + x24_t

        pred_2 = self.head_dec_2(x24)

        # Stage 3:
        x14_preinput = self.up_3(x24)
        x14_input = torch.cat([x14_preinput, x13, canny_feat], dim=1)
        x14_input = self.n14_input(x14_input)
        x14 = self.dec_block_3(x14_input)
        x14 = x14 + x14_input
        pred_3 = self.head_dec_3(x14)

        return pred_3, canny_gate, pred_1, pred_2


if __name__ == "__main__":
    model = MSRF(1, 3, init_feat=32)
    x = torch.randn((2, 1, 128, 128))
    canny = torch.randn((2, 1, 128, 128))
    out = model(x, canny)
    for o in out:
        print(o.shape)
--------------------------------------------------------------------------------
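For the 128×128 example in the `__main__` block, all three segmentation heads should return logits at the full input resolution — `pred_3` and the deep-supervision outputs `pred_1`/`pred_2` are `(2, 3, 128, 128)` — while the canny gate is `(2, 1, 128, 128)`: stage 1 is computed at 1/4 resolution and upsampled ×4, stage 2 at 1/2 resolution and upsampled ×2, and stage 3 and the shape stream already operate at the Canny map's size.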