├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── MSRF_net.iml
├── images
│   └── MSRF-NET.jpg
├── metrics.py
├── README.md
├── test.py
├── demo.py
├── utils.py
├── dataset.py
├── train.py
├── losses.py
└── msrf.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore log folders
2 | logs/**
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/images/MSRF-NET.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amlarraz/MSRF-Net_PyTorch/HEAD/images/MSRF-NET.jpg
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from losses import one_hot
5 |
6 |
7 | def calculate_dice(pred, msk, eps=1e-6):
8 | """Mean Dice over the non-background classes. pred: logits (N, C, H, W); msk: integer labels (N, H, W)."""
9 | # compute softmax over the classes axis
10 | input_soft = F.softmax(pred, dim=1)[:, 1:]
11 |
12 | # create the labels one hot tensor
13 | target_one_hot = one_hot(msk, num_classes=pred.shape[1],
14 | device=pred.device, dtype=pred.dtype)[:, 1:]
15 |
16 | # compute the actual dice score
17 | dims = (1, 2, 3)
18 | intersection = torch.sum(input_soft * target_one_hot, dims)
19 | cardinality = torch.sum(input_soft + target_one_hot, dims)
20 |
21 | dice_score = 2. * intersection / (cardinality + eps)
22 |
23 | return torch.mean(dice_score)
--------------------------------------------------------------------------------
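
A quick sanity check of `calculate_dice` with random tensors (shapes are illustrative; this mirrors how the metric is called from train.py and test.py). Note that the background channel is dropped (`[:, 1:]`) before the score is computed:

```python
import torch

from metrics import calculate_dice

pred = torch.randn(2, 5, 64, 64)        # logits: (batch, n_classes, H, W)
msk = torch.randint(0, 5, (2, 64, 64))  # integer labels: (batch, H, W), dtype int64
dice = calculate_dice(pred, msk)        # scalar tensor: mean Dice over classes 1..4
print(dice.item())
```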
/README.md:
--------------------------------------------------------------------------------
1 | # MSRF-Net_PyTorch
2 |
3 | Unofficial code of [MSRF-Net](https://arxiv.org/pdf/2105.07451.pdf) developed in PyTorch
4 |
5 | `------------ IN PROGRESS ------------`
6 |
7 | - [x] Write the model code based on [official TF code](https://github.com/NoviceMAn-prog/MSRF-Net).
8 | - [x] Write the training/evaluation code.
9 | - [x] Improve the training/evaluation code adding some stuff to tensorboard.
10 | - [x] Write the test code.
11 | - [x] Write the inferencing (demo) code.
12 | - [ ] Train and test.
13 |
14 | ## Implementation details
15 |
16 | - PyTorch 1.9.0 was used with CUDA 11.1.
17 | - The hyperparameter `init_feat` was added. It controls the number of initial channels in the UNet encoder; in the original code it was 32. I recommend using a power of two because of the reduction ratio in the [Squeeze-and-Excitation blocks](https://arxiv.org/abs/1709.01507) (see the usage sketch below).
18 | - The Shape Stream isn't copied exactly from the official code; it was adapted from the [original Shape Stream repo](https://github.com/leftthomas/GatedSCNN).
19 | - Added image visualization during training to TensorBoard, which helps you check performance while training.
20 | - ~~During training DICE coefficient (in loss and as a metric) is computed without the BG.~~
21 |
22 |
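## Quick usage

A minimal sketch (shapes and values are illustrative) of how the model is instantiated and called; it mirrors the `__main__` check at the bottom of `msrf.py`:

```python
import torch

from msrf import MSRF

model = MSRF(in_ch=1, n_classes=5, init_feat=32)  # init_feat: initial UNet channels
img = torch.randn(2, 1, 128, 128)    # grayscale input batch
canny = torch.randn(2, 1, 128, 128)  # Canny edge map fed to the Shape Stream
pred_3, pred_canny, pred_1, pred_2 = model(img, canny)
# pred_3 holds the final segmentation logits; pred_1 and pred_2 are auxiliary decoder heads
```
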
23 | ## Model architecture
24 |
25 | 
26 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 | from torch.utils.data import DataLoader
5 |
6 | from dataset import DataSet
7 | from msrf import MSRF
8 | from losses import CombinedLoss
9 | from metrics import calculate_dice
10 |
11 |
12 | #### IN PROGRESS - IT IS NOT FUNCTIONAL YET #########################
13 | data_dir = '/media/poto/Gordo1/SegThor'
14 | checkpoint = './logs/SegThor-22_12_2021-15h3m46s/ep-8-val_loss-3.6380-val_dice-0.0406.pt'
15 | n_classes = 5
16 | resize = (256, 256)
17 | batch_size = 3
18 | init_feat = 32 # In the original code it was 32
19 | device = torch.device('cuda:0')
20 |
21 |
22 | dataset_test = DataSet(data_dir, n_classes, mode='test', augmentation=False, resize=resize)
23 | dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4,
24 | pin_memory=torch.cuda.is_available())
25 |
26 | model = MSRF(in_ch=1, n_classes=n_classes, init_feat=init_feat)
27 | model.load_state_dict(torch.load(checkpoint, map_location=device)['model'])  # restore the trained weights
28 | model.to(device)
29 | model.eval()
30 | class_weights = dataset_test.class_weights().to(device)
31 | criterion = CombinedLoss(class_weights)
32 |
33 | tq = tqdm(total=len(dataloader_test)*batch_size, position=0, leave=True)
34 | tq.set_description('Testing:')
35 | metrics = {'test_loss': [], 'test_dice': []}
36 | for i, (img, canny, msk, canny_label) in enumerate(dataloader_test):
37 | img, canny, msk, canny_label = img.to(device), canny.to(device), msk.to(device), canny_label.to(device)
38 | with torch.no_grad():
39 | pred_3, pred_canny, pred_1, pred_2 = model(img, canny)
40 | loss = criterion(pred_3, pred_canny, pred_1, pred_2, msk, canny_label)
41 | metrics['test_loss'].append(loss.item())
42 | dice = calculate_dice(pred_3, msk)
43 | metrics['test_dice'].append(dice.item())
44 | tq.update(batch_size)
45 | tq.close()
46 | print('Checkpoint: {}'.format(checkpoint))
47 | print('Test loss: {}, test dice: {}'.format(np.mean(metrics['test_loss']), np.mean(metrics['test_dice'])))
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | import torch.nn.functional as F
7 |
8 | from dataset import normalize
9 | from msrf import MSRF
10 | from metrics import calculate_dice
11 |
12 |
13 | #### IN PROGRESS - IT IS NOT FUNCTIONAL YET #########################
14 | data_dir = '/media/poto/Gordo1/SegThor/Images'
15 | save_dir = '/media/poto/Gordo1/SegThor/Inferences'
16 | checkpoint = './logs/SegThor-22_12_2021-15h3m46s/ep-8-val_loss-3.6380-val_dice-0.0406.pt'
17 | n_classes = 5
18 | threshold = None # None, or a float in (0, 1) to threshold the class probabilities
19 | resize = (256, 256)
20 | init_feat = 32 # In the original code it was 32
21 | device = torch.device('cuda:0')
22 |
23 |
24 | if not os.path.isdir(save_dir):
25 | os.mkdir(save_dir)
26 |
27 | model = MSRF(in_ch=1, n_classes=n_classes, init_feat=init_feat)
28 | model.load_state_dict(torch.load(checkpoint, map_location=device)['model'])  # restore the trained weights
29 | model.to(device)
30 | model.eval()
31 | image_list = os.listdir(data_dir)
32 | tq = tqdm(total=len(image_list), position=0, leave=True)
33 | tq.set_description('Inferencing:')
34 | for img_name in image_list:
35 | img = cv2.imread(os.path.join(data_dir, img_name), 0)
36 | if resize is not None:
37 | img = cv2.resize(img, resize, interpolation=cv2.INTER_CUBIC)
38 | canny = cv2.Canny(img, 10, 100)
39 | canny = np.asarray(canny, np.float32)
40 | canny /= 255.0
41 | img = normalize(img)
42 | img, canny = torch.FloatTensor(img).unsqueeze(0).unsqueeze(0).to(device), torch.FloatTensor(canny).unsqueeze(0).unsqueeze(0).to(device)
43 | with torch.no_grad():
44 | pred_3, pred_canny, pred_1, pred_2 = model(img, canny)
45 | pred_3 = F.softmax(pred_3, dim=1)[0]
46 | if threshold is not None:
47 | final_pred = torch.zeros_like(pred_3[0])
48 | for n_class in range(1, pred_3.shape[0]):
49 | final_pred[pred_3[n_class] >= threshold] = n_class
50 | else:
51 | final_pred = torch.argmax(pred_3, dim=0)
52 | cv2.imwrite(os.path.join(save_dir, img_name), final_pred.detach().cpu().numpy()*(255//n_classes))
53 | tq.update(1)
54 | tq.close()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import numpy as np
4 | from time import localtime
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 |
10 | def prepare_writer(dataset_name):
11 |
12 | log_name = '{}-{}_{}_{}-{}h{}m{}s'.format(dataset_name,
13 | localtime().tm_mday, localtime().tm_mon,
14 | localtime().tm_year, localtime().tm_hour,
15 | localtime().tm_min, localtime().tm_sec)
16 | if not os.path.exists('./logs'):
17 | os.mkdir('./logs')
18 | if not os.path.exists(os.path.join('./logs', log_name)):
19 | os.mkdir(os.path.join('./logs', log_name))
20 |
21 | writer = SummaryWriter(os.path.join('./logs', log_name))
22 |
23 | return writer, os.path.join('./logs', log_name)
24 |
25 |
26 | def imgs2tb(img, msk, pred, canny, pred_canny, writer, n_img, epoch):
27 | n_classes = pred.shape[1]
28 | # It shows only the 1st img in the batch
29 | # Prepare data
30 | img = img[0, 0].detach().cpu().numpy()
31 | img *= 0.151 # undo normalization: 8-bit std (matches dataset.normalize)
32 | img += 0.175 # undo normalization: 8-bit mean
33 | msk = msk[0].detach().cpu().numpy()*(255//n_classes) # to show in bright colors
34 | pred = torch.argmax(F.softmax(pred[0], dim=0), dim=0).detach().cpu().numpy()*(255//n_classes) # No threshold->argmax
35 |
36 | canny = canny[0, 0].detach().cpu().numpy()*50
37 | pred_canny = pred_canny[0, 0].detach().cpu().numpy()*50
38 |
39 | final_img = np.concatenate([img, msk, pred, canny, pred_canny], axis=1)
40 | writer.add_image('Image {}'.format(n_img), final_img, epoch, dataformats='HW')
41 |
42 | return None
43 |
44 |
45 | def save_checkpoint(model, optimizer, save_dir, epoch, val_metrics):
46 | file_name = 'ep-{}'.format(epoch + 1)
47 | for key in val_metrics.keys():
48 | if len(key.split('_')) < 3:
49 | file_name += '-{}-{:.4f}'.format(key, val_metrics[key])
50 |
51 | save_states = {'model': model.state_dict(),
52 | 'optimizer': optimizer.state_dict(),
53 | 'epoch': epoch}
54 |
55 | torch.save(save_states, os.path.join(save_dir, file_name + '.pt'))
56 |
57 | return None
58 |
59 |
60 | def load_checkpoint(model, optimizer, checkpoint_path, model_name):
61 | checkpoint = torch.load(checkpoint_path)
62 | model.load_state_dict(checkpoint['model'])
63 | optimizer.load_state_dict(checkpoint['optimizer'])
64 | print('Checkpoint for model {} and optimizer loaded from {}, epoch: {}'
65 | .format(model_name, checkpoint_path, checkpoint['epoch']))
66 |
67 | return model, optimizer
--------------------------------------------------------------------------------
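
A round-trip sketch of the checkpoint helpers above (the model and optimizer are small stand-ins; the file name follows save_checkpoint's `ep-{epoch+1}-{metric}-{value}` pattern):

```python
import os
import torch

from utils import save_checkpoint, load_checkpoint

os.makedirs('./logs', exist_ok=True)
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

save_checkpoint(model, optimizer, './logs', epoch=0,
                val_metrics={'val_loss': 3.6380, 'val_dice': 0.0406})
# -> writes ./logs/ep-1-val_loss-3.6380-val_dice-0.0406.pt
model, optimizer = load_checkpoint(model, optimizer,
                                   './logs/ep-1-val_loss-3.6380-val_dice-0.0406.pt',
                                   model_name='demo')
```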
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import numpy as np
5 | import albumentations as albu
6 | from collections import defaultdict
7 | import torch
8 | from torch.utils.data import Dataset
9 |
10 | random.seed(42)
11 |
12 |
13 | def normalize(img):
14 | if img.dtype == np.uint8:
15 | mean = 0.175 # Mean / max_pixel_value
16 | std = 0.151 # Std / max_pixel_value
17 | max_pixel_value = 255.0
18 |
19 | elif img.dtype == np.uint16:
20 | mean = 0.0575716
21 | std = 0.12446098
22 | max_pixel_value = 65535.0
23 |
24 | img = img.astype(np.float32) / max_pixel_value
25 | img -= np.ones(img.shape) * mean
26 | img /= np.ones(img.shape) * std
27 |
28 | return img
29 |
30 |
31 | class DataSet(Dataset):
32 | def __init__(self, data_dir, n_classes, mode='train', augmentation=True, resize=None):
33 | """ Data_dir must be organized in:
34 | - Images: Folder that contains all the images (.png) in the dataset.
35 | - Masks: Folder that contains all the masks (.png) in the dataset
36 | """
37 | self.data_dir = data_dir
38 | self.n_classes = n_classes
39 | self.mode = mode
40 | self.augmentation = augmentation
41 | self.resize = resize
42 | percents = {'train': 0.75, 'val': 0.15, 'test': 0.1}
43 | assert mode in percents.keys(), 'Mode is {} and it must be one of: train, val, test'.format(self.mode)
44 | total_imgs = sorted(os.listdir(os.path.join(data_dir, 'Images')))  # sort for reproducible splits
45 | if self.mode == 'train':
46 | self.img_names = total_imgs[:int(percents['train']*len(total_imgs))]
47 | elif self.mode == 'val':
48 | self.img_names = total_imgs[int(percents['train']*len(total_imgs)):int((percents['train']+percents['val'])*len(total_imgs))]
49 | elif self.mode == 'test':
50 | self.img_names = total_imgs[int((percents['train']+percents['val'])*len(total_imgs)):]
51 |
52 | if self.augmentation:
53 | self.augs = albu.OneOf([albu.ElasticTransform(p=0.5, alpha=120, sigma=280 * 0.05, alpha_affine=120 * 0.03),
54 | albu.GridDistortion(p=0.5, border_mode=cv2.BORDER_CONSTANT, distort_limit=0.2),
55 | albu.Rotate(p=0.5, limit=(-5, 5), interpolation=1, border_mode=cv2.BORDER_CONSTANT),
56 | ],)
57 | def __len__(self):
58 | return len(self.img_names)
59 |
60 |
61 |
62 | def class_weights(self):
63 | counts = defaultdict(lambda : 0)
64 | for img_name in self.img_names:
65 | msk = cv2.imread(os.path.join(self.data_dir, 'Masks', img_name), -1)
66 | for c in range(self.n_classes):
67 | counts[c] += np.sum(msk == c)
68 | counts = dict(sorted(counts.items()))
69 | weights = [1 - (x/sum(list(counts.values()))) for x in counts.values()]
70 |
71 | return torch.FloatTensor(weights)
72 |
73 | def __getitem__(self, idx):
74 | img = cv2.imread(os.path.join(self.data_dir, 'Images', self.img_names[idx]), 0)
75 | msk = cv2.imread(os.path.join(self.data_dir, 'Masks', self.img_names[idx]), 0)
76 | if self.resize is not None:
77 | img = cv2.resize(img, self.resize, interpolation=cv2.INTER_CUBIC)
78 | msk = cv2.resize(msk, self.resize, interpolation=cv2.INTER_NEAREST)
79 |
80 | if self.augmentation:
81 | augmented = self.augs(image=img, mask=msk)
82 | img = augmented['image']
83 | msk = augmented['mask']
84 |
85 | canny = cv2.Canny(img, 10, 100)
86 | canny = np.asarray(canny, np.float32)
87 | canny /= 255.0
88 |
89 | img = normalize(img)
90 |
91 | return torch.FloatTensor(img).unsqueeze(0), torch.FloatTensor(canny).unsqueeze(0), torch.LongTensor(msk), torch.FloatTensor(canny)
92 |
93 |
94 | if __name__ == "__main__":
95 | dataset = DataSet('/media/poto/Gordo1/SegThor', 2, 'train', True)
96 | img, canny, msk, canny_label = dataset[0]
97 | print(img.shape, img.min(), img.max())
98 | print(canny.shape, canny.min(), canny.max())
99 | print(msk.shape)
100 |
--------------------------------------------------------------------------------
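
A sketch of how class_weights feeds the loss (the path is illustrative and assumes the Images/Masks layout described in the DataSet docstring; this is the same pattern used in train.py):

```python
import torch

from dataset import DataSet
from losses import CombinedLoss

dataset = DataSet('/media/poto/Gordo1/SegThor', n_classes=5, mode='train',
                  augmentation=False, resize=(256, 256))
weights = dataset.class_weights()  # per class: 1 - pixel frequency, so rare classes weigh more
criterion = CombinedLoss(weights)
```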
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from dataset import DataSet
6 |
7 | from msrf import MSRF
8 | from utils import prepare_writer, save_checkpoint, imgs2tb
9 | from losses import CombinedLoss
10 | from metrics import calculate_dice
11 |
12 | # TRAIN PARAMS
13 | dataset_name = 'SegThor'
14 | data_dir = '/media/poto/Gordo1/SegThor'
15 | n_classes = 5
16 | n_img_to_tb = 5
17 | resize = (256, 256) # None or 2-tuple
18 |
19 | n_epochs = 100
20 | batch_size = 3
21 | lr = 1e-4
22 | accumulation_steps = 6
23 | weight_decay = 0.01
24 | device = torch.device('cuda:0')
25 |
26 | init_feat = 32 # In the original code it was 32
27 |
28 | # DATASET
29 | dataset_train = DataSet(data_dir, n_classes, mode='train', augmentation=True, resize=resize)
30 | dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4,
31 | pin_memory=torch.cuda.is_available())
32 |
33 | dataset_val = DataSet(data_dir, n_classes, mode='val', augmentation=False, resize=resize)
34 | dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=4,
35 | pin_memory=torch.cuda.is_available())
36 |
37 | # MODEL, OPTIM, LR_SCHED, LOSS, LOG
38 | model = MSRF(in_ch=1, n_classes=n_classes, init_feat=init_feat)
39 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8)
40 | class_weights = dataset_train.class_weights().to(device)
41 | criterion = CombinedLoss(class_weights)
42 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
43 | patience=10, verbose=False)
44 | writer, logdir = prepare_writer(dataset_name)
45 |
46 | print('Logdir: {}'.format(logdir))
47 | # TRAIN LOOP
48 | model.to(device)
49 | for epoch in range(1, n_epochs+1):
50 | model.train()
51 | metrics = {'train_loss': [], 'train_dice': [], 'val_loss': [], 'val_dice': []}
52 | tq = tqdm(total=len(dataloader_train)*batch_size, position=0, leave=True)
53 | tq.set_description('Train epoch: {}'.format(epoch))
54 |
55 | for i, (img, canny, msk, canny_label) in enumerate(dataloader_train):
56 | img, canny, msk, canny_label = img.to(device), canny.to(device), msk.to(device), canny_label.to(device)
57 |
58 | pred_3, pred_canny, pred_1, pred_2 = model(img, canny)
59 | # Loss + backward; the optimizer steps once every accumulation_steps batches
60 | loss = criterion(pred_3, pred_canny, pred_1, pred_2, msk, canny_label)
61 | loss = loss/accumulation_steps
62 | loss.backward()
63 | # accumulative gradient
64 | if (i + 1) % accumulation_steps == 0: # Wait for several backward steps
65 | optimizer.step() # Now we can do an optimizer step
66 | model.zero_grad() # Reset gradients tensors
67 |
68 | metrics['train_loss'].append(loss.item() * accumulation_steps)  # undo the accumulation scaling so train and val losses are comparable
69 | dice = calculate_dice(pred_3, msk)
70 | metrics['train_dice'].append(dice.item())
71 | tq.update(batch_size)
72 |
73 | print('Epoch {}: train loss: {}, train dice: {}'.format(epoch, np.mean(metrics['train_loss']), np.mean(metrics['train_dice'])))
74 | writer.add_scalar('train loss', np.mean(metrics['train_loss']), epoch)
75 | writer.add_scalar('train dice', np.mean(metrics['train_dice']), epoch)
76 |
77 | model.eval()
78 | tq = tqdm(total=len(dataloader_val) * batch_size, position=0, leave=True)
79 | tq.set_description('Val epoch: {}'.format(epoch))
80 | k = 0
81 | for i, (img, canny, msk, canny_label) in enumerate(dataloader_val):
82 | img, canny, msk, canny_label = img.to(device), canny.to(device), msk.to(device), canny_label.to(device)
83 | with torch.no_grad():
84 | pred_3, pred_canny, pred_1, pred_2 = model(img, canny)
85 | # Forward pass only: no gradients during validation
86 | loss = criterion(pred_3, pred_canny, pred_1, pred_2, msk, canny_label)
87 | metrics['val_loss'].append(loss.item())
88 | dice = calculate_dice(pred_3, msk)
89 | metrics['val_dice'].append(dice.item())
90 | tq.update(batch_size)
91 |
92 | if k < n_img_to_tb:
93 | imgs2tb(img, msk, pred_3, canny, pred_canny, writer, k, epoch+1)
94 | k += 1
95 |
96 | print('Epoch {}: val loss: {}, val dice: {}'.format(epoch, np.mean(metrics['val_loss']), np.mean(metrics['val_dice'])))
97 | writer.add_scalar('val loss', np.mean(metrics['val_loss']), epoch)
98 | writer.add_scalar('val dice', np.mean(metrics['val_dice']), epoch)
99 | lr_scheduler.step(np.mean(metrics['val_loss']))  # ReduceLROnPlateau monitors the validation loss
100 | save_checkpoint(model, optimizer, logdir, epoch, {'val_loss': np.mean(metrics['val_loss']), 'val_dice': np.mean(metrics['val_dice'])})
--------------------------------------------------------------------------------
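
With batch_size = 3 and accumulation_steps = 6, the optimizer step above sees gradients accumulated over an effective batch of 18 samples; dividing the loss by accumulation_steps keeps the accumulated gradient on the same scale as one large batch. The pattern in isolation (model, data, and criterion are stand-ins):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()
accumulation_steps = 6

for i in range(18):
    x, y = torch.randn(3, 4), torch.randn(3, 2)
    loss = criterion(model(x), y) / accumulation_steps  # rescale so the accumulated gradient averages
    loss.backward()                                     # gradients add up in the .grad buffers
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # apply the accumulated update
        optimizer.zero_grad()  # reset gradients for the next effective batch
```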
/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | CODE FORKED FROM:
3 | https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/losses/
4 | """
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.loss import _Loss, _WeightedLoss
9 | import numpy as np
10 | from torch.autograd import Variable
11 |
12 |
13 | def one_hot(labels, num_classes, device, dtype, eps= 1e-6):
14 | r"""Converts an integer label 2D tensor to a one-hot 3D tensor.
15 |
16 | Args:
17 | labels (torch.Tensor) : tensor with labels of shape :math:`(N, H, W)`,
18 | where N is the batch size. Each value is an integer
19 | representing correct classification.
20 | num_classes (int): number of classes in labels.
21 | device (Optional[torch.device]): the desired device of returned tensor.
22 | Default: if None, uses the current device for the default tensor type
23 | (see torch.set_default_tensor_type()). device will be the CPU for CPU
24 | tensor types and the current CUDA device for CUDA tensor types.
25 | dtype (Optional[torch.dtype]): the desired data type of returned
26 | tensor. Default: if None, infers data type from values.
27 |
28 | Returns:
29 | torch.Tensor: the labels in one hot tensor.
30 |
31 | Examples::
32 | >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
33 | >>> tgm.losses.one_hot(labels, num_classes=3)
34 | tensor([[[[1., 0.],
35 | [0., 1.]],
36 | [[0., 1.],
37 | [0., 0.]],
38 | [[0., 0.],
39 | [1., 0.]]]])
40 | """
41 | if not torch.is_tensor(labels):
42 | raise TypeError("Input labels type is not a torch.Tensor. Got {}"
43 | .format(type(labels)))
44 | if not len(labels.shape) == 3:
45 | raise ValueError("Invalid depth shape, we expect BxHxW. Got: {}"
46 | .format(labels.shape))
47 | if not labels.dtype == torch.int64:
48 | raise ValueError(
49 | "labels must be of the same dtype torch.int64. Got: {}" .format(
50 | labels.dtype))
51 | if num_classes < 1:
52 | raise ValueError("The number of classes must be at least one."
53 | " Got: {}".format(num_classes))
54 | batch_size, height, width = labels.shape
55 | one_hot = torch.zeros(batch_size, num_classes, height, width,
56 | device=device, dtype=dtype)
57 | return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps
58 |
59 |
60 | class DiceLoss(nn.Module):
61 | r"""Criterion that computes Sørensen-Dice Coefficient loss.
62 |
63 | According to [1], we compute the Sørensen-Dice Coefficient as follows:
64 |
65 | .. math::
66 |
67 | \text{Dice}(x, class) = \frac{2 |X| \cap |Y|}{|X| + |Y|}
68 |
69 | where:
70 | - :math:`X` expects to be the scores of each class.
71 | - :math:`Y` expects to be the one-hot tensor with the class labels.
72 |
73 | the loss, is finally computed as:
74 |
75 | .. math::
76 |
77 | \text{loss}(x, class) = 1 - \text{Dice}(x, class)
78 |
79 | [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
80 |
81 | Shape:
82 | - Input: :math:`(N, C, H, W)` where C = number of classes.
83 | - Target: :math:`(N, H, W)` where each value is
84 | :math:`0 ≤ targets[i] ≤ C−1`.
85 |
86 | Examples:
87 | >>> N = 5 # num_classes
88 | >>> loss = tgm.losses.DiceLoss()
89 | >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
90 | >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
91 | >>> output = loss(input, target)
92 | >>> output.backward()
93 | """
94 |
95 | def __init__(self, weights) -> None:
96 | super(DiceLoss, self).__init__()
97 | self.eps: float = 1e-6
98 | self.weights = weights
99 |
100 | def forward(
101 | self,
102 | input: torch.Tensor,
103 | target: torch.Tensor) -> torch.Tensor:
104 | if not torch.is_tensor(input):
105 | raise TypeError("Input type is not a torch.Tensor. Got {}"
106 | .format(type(input)))
107 | if not len(input.shape) == 4:
108 | raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
109 | .format(input.shape))
110 | if not input.shape[-2:] == target.shape[-2:]:
111 | raise ValueError("input and target shapes must be the same. Got: {} and {}"
112 | .format(input.shape, target.shape))
113 | if not input.device == target.device:
114 | raise ValueError(
115 | "input and target must be on the same device. Got: {} and {}".format(
116 | input.device, target.device))
117 | # compute softmax over the classes axis
118 | input_soft = F.softmax(input, dim=1)*self.weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) #[:,1:]
119 |
120 | # create the labels one hot tensor
121 | target_one_hot = one_hot(target, num_classes=input.shape[1],
122 | device=input.device, dtype=input.dtype)#[:, 1:]
123 |
124 | # compute the actual dice score
125 | dims = (1, 2, 3)
126 | intersection = torch.sum(input_soft * target_one_hot, dims)
127 | cardinality = torch.sum(input_soft + target_one_hot, dims)
128 |
129 | dice_score = 2. * intersection / (cardinality + self.eps)
130 | return torch.mean(1. - dice_score)
131 |
132 | class CombinedLoss(nn.Module):
133 | def __init__(self, class_weights):
134 | super().__init__()
135 | self.ce_loss = nn.CrossEntropyLoss(class_weights)
136 | self.dice_loss = DiceLoss(class_weights)
137 | self.bce_loss = nn.BCEWithLogitsLoss()
138 |
139 | def forward(self, pred_3, pred_canny, pred_1, pred_2, msk, canny_label):
140 | loss_pred_3 = self.ce_loss(pred_3, msk) + self.dice_loss(pred_3, msk)
141 | loss_pred_1 = self.ce_loss(pred_1, msk) + self.dice_loss(pred_1, msk)
142 | loss_pred_2 = self.ce_loss(pred_2, msk) + self.dice_loss(pred_2, msk)
143 | loss_canny = self.bce_loss(pred_canny, canny_label.unsqueeze(1))
144 | loss = loss_pred_3 + loss_pred_1 + loss_pred_2 + loss_canny
145 |
146 | return loss
147 |
148 |
--------------------------------------------------------------------------------
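
A quick check of CombinedLoss with random tensors (shapes are illustrative): the three segmentation heads each get cross-entropy plus Dice, while the shape-stream head gets BCE-with-logits against the Canny edge map:

```python
import torch

from losses import CombinedLoss

n_classes = 5
criterion = CombinedLoss(torch.ones(n_classes))  # uniform class weights for the demo

pred_1 = torch.randn(2, n_classes, 64, 64)      # stage-1 decoder logits
pred_2 = torch.randn(2, n_classes, 64, 64)      # stage-2 decoder logits
pred_3 = torch.randn(2, n_classes, 64, 64)      # final decoder logits
pred_canny = torch.randn(2, 1, 64, 64)          # shape-stream output
msk = torch.randint(0, n_classes, (2, 64, 64))  # integer labels
canny_label = torch.rand(2, 64, 64)             # edge targets in [0, 1]

loss = criterion(pred_3, pred_canny, pred_1, pred_2, msk, canny_label)
print(loss.item())
```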
/msrf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torchvision.models.resnet import BasicBlock
5 |
6 |
7 | # BLOCKS to construct the model
8 | class DSDF_block(nn.Module): #OK
9 | def __init__(self, in_ch_x, in_ch_y, nf1=128, nf2=256, gc=64, bias=True):
10 | super().__init__()
11 |
12 | self.nx1 = nn.Sequential(nn.Conv2d(in_ch_x, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
13 | nn.LeakyReLU(negative_slope=0.25))
14 |
15 | self.ny1 = nn.Sequential(nn.Conv2d(in_ch_y, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
16 | nn.LeakyReLU(negative_slope=0.25))
17 |
18 | self.nx1c = nn.Sequential(nn.Conv2d(in_ch_x, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias), # ks 3 -> 4, stride 1 -> 2
19 | nn.LeakyReLU(negative_slope=0.25))
20 |
21 | self.ny1t = nn.Sequential(nn.ConvTranspose2d(in_ch_y, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias), # ks 3 -> 4
22 | nn.LeakyReLU(negative_slope=0.25))
23 |
24 | self.nx2 = nn.Sequential(nn.Conv2d(in_ch_x+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
25 | nn.LeakyReLU(negative_slope=0.25))
26 |
27 | self.ny2 = nn.Sequential(nn.Conv2d(in_ch_y+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
28 | nn.LeakyReLU(negative_slope=0.25))
29 |
30 | self.nx2c = nn.Sequential(nn.Conv2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias), # ks 3 -> 4, stride 1 -> 2
31 | nn.LeakyReLU(negative_slope=0.25))
32 |
33 | self.ny2t = nn.Sequential(nn.ConvTranspose2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias), # ks 3 -> 4
34 | nn.LeakyReLU(negative_slope=0.25))
35 |
36 | self.nx3 = nn.Sequential(nn.Conv2d(in_ch_x+gc+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
37 | nn.LeakyReLU(negative_slope=0.25))
38 |
39 | self.ny3 = nn.Sequential(nn.Conv2d(in_ch_y+gc+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
40 | nn.LeakyReLU(negative_slope=0.25))
41 |
42 | self.nx3c = nn.Sequential(nn.Conv2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias), # ks 3 -> 4, stride 1 -> 2
43 | nn.LeakyReLU(negative_slope=0.25))
44 |
45 | self.ny3t = nn.Sequential(nn.ConvTranspose2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias), # ks 3 -> 4
46 | nn.LeakyReLU(negative_slope=0.25))
47 |
48 | self.nx4 = nn.Sequential(nn.Conv2d(in_ch_x+gc+gc+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
49 | nn.LeakyReLU(negative_slope=0.25))
50 |
51 | self.ny4 = nn.Sequential(nn.Conv2d(in_ch_y+gc+gc+gc+gc, gc, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
52 | nn.LeakyReLU(negative_slope=0.25))
53 |
54 | self.nx4c = nn.Sequential(nn.Conv2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias), # ks 3 -> 4, stride 1 -> 2
55 | nn.LeakyReLU(negative_slope=0.25))
56 |
57 | self.ny4t = nn.Sequential(nn.ConvTranspose2d(gc, gc, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=bias), # ks 3 -> 4
58 | nn.LeakyReLU(negative_slope=0.25))
59 |
60 | self.nx5 = nn.Sequential(nn.Conv2d(in_ch_x+gc+gc+gc+gc+gc, nf1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
61 | nn.LeakyReLU(negative_slope=0.25))
62 |
63 | self.ny5 = nn.Sequential(nn.Conv2d(in_ch_y+gc+gc+gc+gc+gc, nf2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=bias),
64 | nn.LeakyReLU(negative_slope=0.25))
65 |
66 | def forward(self, x, y):
67 |
68 | x1 = self.nx1(x)
69 | y1 = self.ny1(y)
70 |
71 | x1c = self.nx1c(x)
72 | y1t = self.ny1t(y)
73 |
74 | x2_input = torch.cat([x, x1, y1t], dim=1)
75 | x2 = self.nx2(x2_input)
76 |
77 | y2_input = torch.cat([y, y1, x1c], dim=1)
78 | y2 = self.ny2(y2_input)
79 |
80 | x2c = self.nx2c(x1)
81 | y2t = self.ny2t(y1)
82 |
83 | x3_input = torch.cat([x, x1, x2, y2t], dim=1)
84 | x3 = self.nx3(x3_input)
85 |
86 | y3_input = torch.cat([y, y1, y2, x2c], dim=1)
87 | y3 = self.ny3(y3_input)
88 |
89 | x3c = self.nx3c(x3)
90 | y3t = self.ny3t(y3)
91 |
92 | x4_input = torch.cat([x, x1, x2, x3, y3t], dim=1)
93 | x4 = self.nx4(x4_input)
94 |
95 | y4_input = torch.cat([y, y1, y2, y3, x3c], dim=1)
96 | y4 = self.ny4(y4_input)
97 |
98 | x4c = self.nx4c(x4)
99 | y4t = self.ny4t(y4)
100 |
101 | x5_input = torch.cat([x, x1, x2, x3, x4, y4t], dim=1)
102 | x5 = self.nx5(x5_input)
103 |
104 | y5_input = torch.cat([y, y1, y2, y3, y4, x4c], dim=1)
105 | y5 = self.ny5(y5_input)
106 |
107 | x5 *= 0.4
108 | y5 *= 0.4
109 |
110 | return x5+x, y5+y
111 |
112 |
113 | class ATTENTION_block(nn.Module): #OK
114 | def __init__(self, in_ch_x, in_ch_g, med_ch):
115 | super().__init__()
116 | self.theta = nn.Conv2d(in_ch_x, med_ch, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True)
117 | self.phi = nn.Conv2d(in_ch_g, med_ch, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
118 | self.block = nn.Sequential(nn.ReLU(),
119 | nn.Conv2d(med_ch, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
120 | nn.Sigmoid(),
121 | nn.ConvTranspose2d(1, 1, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True))
122 | self.batchnorm = nn.BatchNorm2d(in_ch_x)
123 |
124 | def forward(self, x, g):
125 | theta = self.theta(x) + self.phi(g)
126 | out = self.batchnorm(self.block(theta) * x)
127 | return out
128 |
129 |
130 | class UP_block(nn.Module): #OK
131 | def __init__(self, input_1_ch, input_2_ch):
132 | super().__init__()
133 | self.up = nn.ConvTranspose2d(input_2_ch, input_1_ch, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True)
134 |
135 | def forward(self, input_1, input_2):
136 | x = torch.cat([self.up(input_2), input_1], dim=1)
137 | return x
138 |
139 |
140 | class SE_block(nn.Module): #OK
141 | def __init__(self, in_ch, ratio=16):
142 | super().__init__()
143 | self.block = nn.Sequential(nn.Linear(in_ch, in_ch//ratio),
144 | nn.ReLU(),
145 | nn.Linear(in_ch//ratio, in_ch),
146 | nn.Sigmoid())
147 | def forward(self, x):
148 | y = x.mean((-2, -1))
149 | y = self.block(y).unsqueeze(-1).unsqueeze(-1)
150 | return x*y
151 |
152 |
153 | class SPATIALATT_block(nn.Module): #OK
154 | def __init__(self, in_ch, med_ch):
155 | super().__init__()
156 | self.block = nn.Sequential(nn.Conv2d(in_ch, med_ch, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
157 | nn.BatchNorm2d(med_ch),
158 | nn.ReLU(),
159 | nn.Conv2d(med_ch, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
160 | nn.Sigmoid())
161 | def forward(self, x):
162 | x = self.block(x)
163 |
164 | return x
165 |
166 |
167 | class RES_block(nn.Module): #OK
168 | def __init__(self, in_ch):
169 | super().__init__()
170 | self.block = nn.Sequential(nn.Conv2d(in_ch, in_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
171 | nn.BatchNorm2d(in_ch),
172 | nn.ReLU(),
173 | nn.Conv2d(in_ch, in_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1,1), bias=True),
174 | nn.BatchNorm2d(in_ch))
175 | self.act = nn.ReLU()
176 |
177 | def forward(self, x):
178 | res = self.block(x)
179 | out = self.act(res+x)
180 |
181 | return out
182 |
183 |
184 | class DUALATT_block(nn.Module): #OK
185 | def __init__(self, skip_in_ch, prev_in_ch, out_ch):
186 | super().__init__()
187 | self.prev_block = nn.Sequential(nn.ConvTranspose2d(prev_in_ch, out_ch, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
188 | nn.BatchNorm2d(out_ch),
189 | nn.ReLU())
190 | self.block = nn.Sequential(nn.Conv2d(skip_in_ch+out_ch, out_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
191 | nn.BatchNorm2d(out_ch),
192 | nn.ReLU())
193 | self.se_block = SE_block(out_ch, ratio=16)
194 | self.spatial_att = SPATIALATT_block(out_ch, out_ch)
195 |
196 | def forward(self, skip, prev):
197 |
198 | prev = self.prev_block(prev)
199 | x = torch.cat([skip, prev], dim=1)
200 | inpt_layer = self.block(x)
201 | se_out = self.se_block(inpt_layer)
202 | sab = self.spatial_att(inpt_layer) + 1
203 |
204 | return sab*se_out
205 |
206 |
207 | class GSC_block(nn.Module):
208 | def __init__(self, in_ch_x, in_ch_y):
209 | super().__init__()
210 | self.block = nn.Sequential(nn.BatchNorm2d(in_ch_x+in_ch_y),
211 | nn.Conv2d(in_ch_x+in_ch_y, in_ch_x+in_ch_y+1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)), #in_ch->out_ch
212 | nn.ReLU(),
213 | nn.Conv2d(in_ch_x+in_ch_y+1, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
214 | nn.BatchNorm2d(1),
215 | nn.Sigmoid())
216 |
217 | def forward(self, x, y):
218 | inpt = torch.cat([x, y], dim=1)
219 | inpt = self.block(inpt)
220 |
221 | return inpt
222 |
223 | ## SHAPE-STREAM
224 | # https://github.com/leftthomas/GatedSCNN/blob/master/model.py
225 |
226 | class GatedConv(nn.Conv2d):
227 | def __init__(self, in_channels, out_channels):
228 | super().__init__(in_channels, out_channels, 1, bias=False)
229 | self.attention = nn.Sequential(
230 | nn.BatchNorm2d(in_channels + 1),
231 | nn.Conv2d(in_channels + 1, in_channels + 1, 1),
232 | nn.ReLU(),
233 | nn.Conv2d(in_channels + 1, 1, 1),
234 | nn.BatchNorm2d(1),
235 | nn.Sigmoid()
236 | )
237 |
238 | def forward(self, feat, gate):
239 | attention = self.attention(torch.cat((feat, gate), dim=1))
240 | out = F.conv2d(feat * (attention + 1), self.weight)
241 | return out
242 |
243 | class ShapeStream(nn.Module):
244 | def __init__(self, init_feat):
245 | super().__init__()
246 | self.res2_conv = nn.Conv2d(init_feat * 2, 1, 1)
247 | self.res3_conv = nn.Conv2d(init_feat * 4, 1, 1)
248 | self.res4_conv = nn.Conv2d(init_feat * 8, 1, 1)
249 | self.res1 = BasicBlock(init_feat, init_feat, 1)
250 | self.res2 = BasicBlock(32, 32, 1)
251 | self.res3 = BasicBlock(16, 16, 1)
252 | self.res1_pre = nn.Conv2d(init_feat, 32, 1)
253 | self.res2_pre = nn.Conv2d(32, 16, 1)
254 | self.res3_pre = nn.Conv2d(16, 8, 1)
255 | self.gate1 = GatedConv(32, 32)
256 | self.gate2 = GatedConv(16, 16)
257 | self.gate3 = GatedConv(8, 8)
258 | self.gate = nn.Conv2d(8, 1, 1, bias=False)
259 | self.fuse = nn.Conv2d(2, 1, 1, bias=False)
260 |
261 | def forward(self, x, res2, res3, res4, grad):
262 | size = grad.size()[-2:]
263 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
264 | res2 = F.interpolate(self.res2_conv(res2), size, mode='bilinear', align_corners=True)
265 | res3 = F.interpolate(self.res3_conv(res3), size, mode='bilinear', align_corners=True)
266 | res4 = F.interpolate(self.res4_conv(res4), size, mode='bilinear', align_corners=True)
267 | gate1 = self.gate1(self.res1_pre(self.res1(x)), res2)
268 | gate2 = self.gate2(self.res2_pre(self.res2(gate1)), res3)
269 | gate3 = self.gate3(self.res3_pre(self.res3(gate2)), res4)
270 | gate = torch.sigmoid(self.gate(gate3))
271 | #gate = torch.sigmoid(self.gate(gate2))
272 | feat = torch.sigmoid(self.fuse(torch.cat((gate, grad), dim=1)))
273 | return gate, feat
274 |
275 |
276 |
277 | # MODEL (NOT CHECKED)
278 | class MSRF(nn.Module):
279 | def __init__(self, in_ch, n_classes, init_feat=32):
280 | super().__init__()
281 |
282 | # ENCODER ----------------------------
283 | self.n11 = nn.Sequential(nn.Conv2d(in_ch, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
284 | nn.ReLU(),
285 | nn.Conv2d(init_feat, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
286 | nn.ReLU(),
287 | nn.BatchNorm2d(init_feat),
288 | SE_block(init_feat, ratio=init_feat//2)
289 | )
290 |
291 | self.n21 = nn.Sequential(nn.MaxPool2d(kernel_size=(2, 2)),
292 | nn.Dropout(0.2),
293 | nn.Conv2d(init_feat, init_feat*2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
294 | nn.ReLU(),
295 | nn.Conv2d(init_feat*2, init_feat*2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
296 | nn.ReLU(),
297 | nn.BatchNorm2d(init_feat*2),
298 | SE_block(init_feat*2, ratio=init_feat//2))
299 |
300 | self.n31 = nn.Sequential(nn.MaxPool2d(kernel_size=(2, 2)),
301 | nn.Dropout(0.2),
302 | nn.Conv2d(init_feat*2, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
303 | nn.ReLU(),
304 | nn.Conv2d(init_feat*4, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
305 | nn.ReLU(),
306 | nn.BatchNorm2d(init_feat*4),
307 | SE_block(init_feat*4, ratio=init_feat//2))
308 |
309 | self.n41 = nn.Sequential(nn.MaxPool2d(kernel_size=(2, 2)),
310 | nn.Dropout(0.2),
311 | nn.Conv2d(init_feat*4, init_feat*8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
312 | nn.ReLU(),
313 | nn.Conv2d(init_feat*8, init_feat*8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
314 | nn.ReLU(),
315 | nn.BatchNorm2d(init_feat*8))
316 | # MSRF-subnetwork ----------------------------
317 | self.dsfs_1 = DSDF_block(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
318 | self.dsfs_2 = DSDF_block(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
319 | self.dsfs_3 = DSDF_block(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
320 | self.dsfs_4 = DSDF_block(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
321 | self.dsfs_5 = DSDF_block(init_feat*2, init_feat*4, nf1=init_feat*2, nf2=init_feat*4, gc=init_feat*2//2)
322 | self.dsfs_6 = DSDF_block(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
323 | self.dsfs_7 = DSDF_block(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
324 | self.dsfs_8 = DSDF_block(init_feat*2, init_feat*4, nf1=init_feat*2, nf2=init_feat*4, gc=init_feat*2//2)
325 | self.dsfs_9 = DSDF_block(init_feat, init_feat*2, nf1=init_feat, nf2=init_feat*2, gc=init_feat//2)
326 | self.dsfs_10 = DSDF_block(init_feat*4, init_feat*8, nf1=init_feat*4, nf2=init_feat*8, gc=init_feat*4//2)
327 |
328 | # SHAPE STREAM ------------IN PROGRESS-------------------
329 | self.shape_stream = ShapeStream(init_feat)
330 |
331 | # DECODER
332 | # Stage 1:
333 | self.att_1 = ATTENTION_block(init_feat*4, init_feat*8, init_feat*8)
334 | self.up_1 = UP_block(init_feat*4, init_feat*8)
335 | self.dualatt_1 = DUALATT_block(init_feat*4, init_feat*8, init_feat*4)
336 | self.n34_t = nn.Conv2d(init_feat * 4 + init_feat * 8, init_feat * 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
337 | self.dec_block_1 = nn.Sequential(nn.BatchNorm2d(init_feat*4),
338 | nn.ReLU(),
339 | nn.Conv2d(init_feat*4, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
340 | nn.BatchNorm2d(init_feat*4),
341 | nn.ReLU(),
342 | nn.Conv2d(init_feat*4, init_feat*4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
343 | )
344 | self.head_dec_1 = nn.Sequential(nn.Conv2d(init_feat*4, n_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
345 | nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True))
346 |
347 | # Stage 2:
348 | self.att_2 = ATTENTION_block(init_feat * 2, init_feat * 4, init_feat * 2)
349 | self.up_2 = UP_block(init_feat * 2, init_feat * 4)
350 | self.dualatt_2 = DUALATT_block(init_feat * 2, init_feat * 4, init_feat * 2)
351 | self.n24_t = nn.Conv2d(init_feat * 2 + init_feat * 4, init_feat * 2, kernel_size=(1, 1), stride=(1, 1), padding=(0,0))
352 | self.dec_block_2 = nn.Sequential(nn.BatchNorm2d(init_feat * 2),
353 | nn.ReLU(),
354 | nn.Conv2d(init_feat * 2, init_feat * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
355 | nn.BatchNorm2d(init_feat * 2),
356 | nn.ReLU(),
357 | nn.Conv2d(init_feat*2, init_feat * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
358 | )
359 | self.head_dec_2 = nn.Sequential(nn.Conv2d(init_feat * 2, n_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
360 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
361 |
362 | # Stage 3:
363 | self.up_3 = nn.ConvTranspose2d(init_feat * 2, init_feat, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
364 | self.n14_input = nn.Sequential(nn.Conv2d(init_feat + init_feat + 1, init_feat, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
365 | nn.ReLU())
366 | self.dec_block_3 = nn.Sequential(nn.Conv2d(init_feat, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
367 | nn.ReLU(),
368 | nn.BatchNorm2d(init_feat))
369 |
370 | self.head_dec_3 = nn.Sequential(nn.Conv2d(init_feat, init_feat, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
371 | nn.ReLU(),
372 | nn.Conv2d(init_feat, n_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)))
373 |
374 | def forward(self, x, canny):
375 | # ENCODER:
376 | x11 = self.n11(x)
377 | x21 = self.n21(x11)
378 | x31 = self.n31(x21)
379 | x41 = self.n41(x31)
380 |
381 | # MSRF-subnetwork
382 | x12, x22 = self.dsfs_1(x11, x21)
383 | x32, x42 = self.dsfs_2(x31, x41)
384 | x12, x22 = self.dsfs_3(x12, x22)
385 | x32, x42 = self.dsfs_4(x32, x42)
386 | x22, x32 = self.dsfs_5(x22, x32)
387 | x13, x23 = self.dsfs_6(x12, x22)
388 | x33, x43 = self.dsfs_7(x32, x42)
389 | x23, x33 = self.dsfs_8(x23, x33)
390 | x13, x23 = self.dsfs_9(x13, x23)
391 | x33, x43 = self.dsfs_10(x33, x43)
392 |
393 | x13 = (x13*0.4) + x11
394 | x23 = (x23*0.4) + x21
395 | x33 = (x33*0.4) + x31
396 | x43 = (x43*0.4) + x41
397 |
398 | # SHAPE STREAM
399 | # https://github.com/leftthomas/GatedSCNN (https://arxiv.org/pdf/1907.05740.pdf)
400 | canny_gate, canny_feat = self.shape_stream(x13, x23, x33, x43, canny)
401 |
402 | # DECODER
403 | # Stage 1:
404 | x34_preinput = self.att_1(x33, x43)
405 |
406 | x34 = self.up_1(x34_preinput, x43)
407 | x34_t = self.dualatt_1(x33, x43)
408 | x34_t = torch.cat([x34, x34_t], dim=1)
409 | x34_t = self.n34_t(x34_t)
410 | x34 = self.dec_block_1(x34_t) + x34_t
411 |
412 | pred_1 = self.head_dec_1(x34)
413 |
414 | # Stage 2:
415 | x24_preinput = self.att_2(x23, x34)
416 | x24 = self.up_2(x24_preinput, x34)
417 | x24_t = self.dualatt_2(x23, x34)
418 | x24_t = torch.cat([x24, x24_t], dim=1)
419 | x24_t = self.n24_t(x24_t)
420 | x24 = self.dec_block_2(x24_t) + x24_t
421 |
422 | pred_2 = self.head_dec_2(x24)
423 |
424 | # Stage 3:
425 | x14_preinput = self.up_3(x24)
426 | x14_input = torch.cat([x14_preinput, x13, canny_feat], dim=1)
427 | x14_input = self.n14_input(x14_input)
428 | x14 = self.dec_block_3(x14_input)
429 | x14 = x14 + x14_input
430 | pred_3 = self.head_dec_3(x14)
431 |
432 | return pred_3, canny_gate, pred_1, pred_2
433 |
434 | if __name__ == "__main__":
435 | model = MSRF(1, 3, init_feat=32)
436 | x = torch.randn((2, 1, 128, 128))
437 | canny = torch.randn((2, 1, 128, 128))
438 | out = model(x, canny)
439 | for o in out:
440 | print(o.shape)
--------------------------------------------------------------------------------