├── .gitattributes
├── Data
│   └── __init__.py
├── README.md
├── eval.py
├── eval_seg.py
├── hd5_maker
│   ├── __init__.py
│   ├── bing.py
│   └── viah.py
├── loader
│   ├── __init__.py
│   ├── bing_loader.py
│   ├── transforms.py
│   └── viah_loader.py
├── models
│   ├── __init__.py
│   ├── hardnet.py
│   ├── model.py
│   ├── model_seg.py
│   └── vanilla.py
├── out
│   └── __init__.py
├── pics
│   ├── 16.jpg
│   └── 8.jpg
├── requirements.txt
├── results
│   └── __init__.py
├── train.py
├── train_seg.py
└── utils
    ├── __init__.py
    ├── loss.py
    ├── snake_loss.py
    ├── utils_TB.py
    ├── utils_args.py
    ├── utils_eval.py
    ├── utils_lr.py
    ├── utils_train.py
    ├── utils_tri.py
    └── utils_vis.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /Data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/Data/__init__.py -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Deep Active Contour Model (ACM) for Building Segmentation
2 |
3 |
4 |
5 | A deep snake algorithm for 2D images, based on the [ICLR 2020 paper (revisiting the results)](https://arxiv.org/abs/1912.00367).
6 | The architecture is based on [HarDNet85](https://arxiv.org/abs/1909.00948).
7 | Data and weights: https://drive.google.com/drive/folders/1fBSjPse3d8geV_iI3-PXV3x2qmLoUnzL?usp=sharing
8 |
9 | ### Get Started
10 | **To train a segmentation model:**
11 | ```
12 | python train_seg.py -bs 50 -WD 0.00005 -D_rate 3000 -task bing -opt sgd -lr 0.02 -nW 8
13 | ```
14 | **To train an ACM model:**
15 | ```
16 | python train.py -bs 25 -WD 0.00005 -D_rate 30 -it 2
17 | ```
18 | **To evaluate an ACM model:**
19 | ```
20 | python eval.py -task viah -nP 100 -it 3 -a 0.4
21 | ```
22 | **To evaluate a segmentation model:**
23 | ```
24 | python eval_seg.py -task bing -nP 24 -it 2 -a 0.4
25 | ```
26 | ### Results
27 |
28 | | Method | Viah<br>mIoU | Bing<br>mIoU |
29 | | :---: | :---: | :---: |
30 | | DARNet | 88.24 | 75.29 |
31 | | DSAC | 71.10 | 38.74 |
32 | | **ours** | **90.33** | **75.53** |
33 |
34 |
35 |
36 |
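37 | ### Loading the pretrained models (sketch)
38 | The snippet below is a minimal sketch rather than one of the repo scripts: it mirrors the checkpoint-loading logic of `eval.py` and assumes the `ACM.pt`/`SEG.pt` files from the Drive link above were placed under `results/viah/best/`.
39 | ```
40 | import torch
41 | from models.model import DeepACM
42 | from models.model_seg import Segmentation
43 | from utils.utils_args import get_args
44 |
45 | args = get_args()
46 | model, segnet = DeepACM(args), Segmentation(args)
47 | # the checkpoints store whole modules, so copy their state dicts
48 | model.load_state_dict(torch.load('results/viah/best/ACM.pt').state_dict())
49 | segnet.load_state_dict(torch.load('results/viah/best/SEG.pt').state_dict())
50 | model.eval().cuda()
51 | segnet.eval().cuda()
52 | ```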
-------------------------------------------------------------------------------- /eval.py: --------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from tensorboardX import SummaryWriter
3 |
4 | from models.model import *
5 | from models.model_seg import Segmentation
6 | from loader.viah_loader import *
7 | from loader.bing_loader import *
8 | from utils.utils_args import *
9 | from utils.utils_eval import get_dice_ji, vis_ds
10 | from utils.utils_train import *
11 | from utils.utils_tri import *
12 | from utils.loss import *
13 | from utils.snake_loss import Snakeloss
14 | import random
15 |
16 |
17 | def eval_ds(ds, model, segnet, PTrain, faces, args):
18 | model.eval()
19 | IoU_list = []
20 | Dice_list = []
21 |
22 | with torch.no_grad():
23 | for ix, (_x, _y) in enumerate(ds):
24 | _x = _x.float().cuda()
25 | _p = PTrain.float().cuda().clone()
26 | _y = _y.float().cuda()
27 | seg_out = segnet(_x).detach()
28 | _x = norm_input(_x, seg_out, float(args['a']))
29 | n_it = int(args['DeepIt'])
30 | net_out = model(_x, _p, faces, n_it)
31 | Mask = net_out[0][n_it-1]
32 | cDice, cIoU = get_dice_ji(Mask, _y)
33 | IoU_list.append(cIoU)
34 | Dice_list.append(cDice)
35 | IoU = np.mean(IoU_list)
36 | Dice = np.mean(Dice_list)
37 | model.train()
38 | return Dice, IoU
39 |
40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41 | torch.backends.cudnn.benchmark = True
42 | args = get_args()
43 |
44 | segnet = Segmentation(args)
45 | model = DeepACM(args)
46 |
47 |
48 | P_test, faces_test = get_poly(int(args['im_size']), int(args['nP']),
49 | int(args['Radius']), int(args['im_size']) / 2, int(args['im_size']) / 2)
50 | faces_test = faces_test.unsqueeze(dim=0).unsqueeze(dim=0).cuda()
51 | faces = faces_test.repeat(1, 1, 1, 1).cuda()
52 | PTrain = P_test.repeat(1, 1, 1, 1).cuda()
53 | if args['task'] == 'viah':
54 | PATH = r'results/viah/best/'
55 | testset = viah_segmentation(ann='test', args=args)
56 | elif args['task'] == 'bing':
57 | testset = bing_segmentation(ann='test', args=args)
58 | PATH = r'results/bing/best/'
59 | ds_val = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False,
60 | num_workers=1, drop_last=False)
61 |
62 | model1 = torch.load(PATH + 'ACM.pt')
63 | model.load_state_dict(model1.state_dict())
64 | model.eval().to(device)
65 | segnet1 = torch.load(PATH + 'SEG.pt')
66 | segnet.load_state_dict(segnet1.state_dict())
67 | segnet.eval().to(device)
68 | vis_ds(ds_val, model, segnet, PTrain, faces, args, num_of_ex=20)
69 | dice, iou = eval_ds(ds_val, model, segnet, P_test, faces_test, args)
70 | print((dice, iou))
71 |
-------------------------------------------------------------------------------- /eval_seg.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from loader.viah_loader import *
4 | from loader.bing_loader import *
5 | from utils.utils_args import *
6 | from utils.utils_eval import *
7 | from utils.utils_train import *
8 | from utils.utils_tri import *
9 | from utils.utils_vis import *
10 | from utils.loss import *
11 | from models.model_seg import *
12 |
13 |
14 | def eval_ds(ds, model):
15 | TestDice_list = []
16 | TestIoU_list = []
17 | for ix, (_x, _y) in enumerate(ds):
18 | _x = _x.float().cpu()
19 | _y = _y.float().cpu()
20 | Mask = model(_x)
21 | Mask[Mask >= 0.5] = 1
22 | Mask[Mask < 0.5] = 0
23 | (cDice, cIoU) = get_dice_ji(Mask, _y)
24 | TestDice_list.append(cDice)
25 |
TestIoU_list.append(cIoU) 26 | Dice = np.mean(TestDice_list) 27 | IoU = np.mean(TestIoU_list) 28 | print((Dice, IoU)) 29 | 30 | 31 | def main(): 32 | torch.backends.cudnn.benchmark = True 33 | args = get_args() 34 | save_args(args) 35 | 36 | if args['task'] == 'viah': 37 | PATH = r'results/viah/best/' 38 | testset = viah_segmentation(ann='test', args=args) 39 | elif args['task'] == 'bing': 40 | testset = bing_segmentation(ann='test', args=args) 41 | PATH = r'results/bing/best/' 42 | segnet = Segmentation(args) 43 | segnet1 = torch.load(PATH + 'SEG.pt') 44 | segnet.load_state_dict(segnet1.state_dict()) 45 | segnet.cpu().eval() 46 | ds_val = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, 47 | num_workers=1, drop_last=False) 48 | eval_ds(ds_val, segnet) 49 | 50 | if __name__ == '__main__': 51 | main() 52 | 53 | -------------------------------------------------------------------------------- /hd5_maker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/hd5_maker/__init__.py -------------------------------------------------------------------------------- /hd5_maker/bing.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | import numpy as np 4 | from glob import glob 5 | from PIL import Image 6 | import torchvision.transforms as transforms 7 | 8 | 9 | def get_img(cfile): 10 | image = Image.open(cfile) 11 | image = transforms.functional.resize(image, (256, 256), 3) 12 | return np.asarray(image) 13 | 14 | def get_mask(cfile): 15 | image = Image.open(cfile) 16 | image = np.asarray(image).copy() 17 | image[image > 0] = 255 18 | image = image.astype(np.uint8) 19 | image = Image.fromarray(image).convert('1') 20 | mask_one = transforms.functional.resize(image, (256, 256)) 21 | mask = np.asarray(mask_one).astype(np.float) 22 | return mask 23 | 24 | src = '/media/data1/talshah/DAR/single_buildings/' 25 | hf_tri = h5py.File('/media/data1/talshah/DeepACM/Data/full_training_Bing.h5', 'w') 26 | hf_test = h5py.File('/media/data1/talshah/DeepACM/Data/full_test_Bing.h5', 'w') 27 | a = os.listdir(src) 28 | a.sort() 29 | 30 | img_list = glob(src + 'building_*') 31 | mask_list = glob(src + 'building_mask_*') 32 | mask_all_list = glob(src + 'building_mask_all_*') 33 | img_list.sort() 34 | mask_list.sort() 35 | mask_all_list.sort() 36 | 37 | imgs_tri = hf_tri.create_group('imgs') 38 | mask_tri = hf_tri.create_group('mask') 39 | mask_single_tri = hf_tri.create_group('mask_single') 40 | 41 | imgs_test = hf_test.create_group('imgs') 42 | mask_test = hf_test.create_group('mask') 43 | mask_single_test = hf_test.create_group('mask_single') 44 | 45 | for folder in img_list[0:335]: 46 | print('training: ' + folder) 47 | img = get_img(folder) 48 | imgs_tri.create_dataset(folder.split('/')[-1], data=img, dtype=np.uint8) 49 | 50 | for folder in img_list[335:606]: 51 | print('validation: ' + folder) 52 | img = get_img(folder) 53 | imgs_test.create_dataset(folder.split('/')[-1], data=img, dtype=np.uint8) 54 | 55 | for folder in mask_list[0:335]: 56 | print('training: ' + folder) 57 | mask = get_mask(folder) 58 | mask_single_tri.create_dataset(folder.split('/')[-1], data=mask, dtype=np.uint8) 59 | 60 | for folder in mask_list[335:606]: 61 | print('validation: ' + folder) 62 | mask = get_mask(folder) 63 | mask_single_test.create_dataset(folder.split('/')[-1], data=mask, dtype=np.uint8) 64 | 65 | for folder 
in mask_all_list[0:335]: 66 | print('training: ' + folder) 67 | mask = get_mask(folder) 68 | mask_tri.create_dataset(folder.split('/')[-1], data=mask, dtype=np.uint8) 69 | 70 | for folder in mask_all_list[335:606]: 71 | print('validation: ' + folder) 72 | mask = get_mask(folder) 73 | mask_test.create_dataset(folder.split('/')[-1], data=mask, dtype=np.uint8) 74 | 75 | hf_tri.close() 76 | hf_test.close() 77 | 78 | -------------------------------------------------------------------------------- /hd5_maker/viah.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | import cv2 4 | import numpy as np 5 | from skimage.transform import resize 6 | 7 | 8 | def get_img(cfile, shape=(256, 256)): 9 | img = cv2.cvtColor(cv2.imread(cfile, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 10 | img = resize(img, shape)*255 11 | return img 12 | 13 | 14 | def get_mask(cfile, shape=(256, 256)): 15 | GT = cv2.imread(cfile, 0) 16 | GT = resize(GT, shape) 17 | GT[GT >= 0.5] = 1 18 | GT[GT < 0.5] = 0 19 | return GT 20 | 21 | 22 | img_src = '/media/data1/talshah/GenA/buildings_vaihingen/buildings/img/' 23 | mask_src = '/media/data1/talshah/GenA/buildings_vaihingen/buildings/AllBuildingsMask/' 24 | mask_single_src = '/media/data1/talshah/GenA/buildings_vaihingen/buildings/mask_sizeold/' 25 | hf_tri = h5py.File('/media/data1/talshah/DeepACM/Data/full_training_viah.h5', 'w') 26 | hf_test = h5py.File('/media/data1/talshah/DeepACM/Data/full_test_viah.h5', 'w') 27 | a_img = os.listdir(img_src) 28 | a_img.sort() 29 | a_mask = os.listdir(mask_src) 30 | a_mask.sort() 31 | a_mask_single = os.listdir(mask_single_src) 32 | a_mask_single.sort() 33 | 34 | imgs_tri = hf_tri.create_group('imgs') 35 | mask_tri = hf_tri.create_group('mask') 36 | mask_single_tri = hf_tri.create_group('mask_single') 37 | 38 | imgs_test = hf_test.create_group('imgs') 39 | mask_test = hf_test.create_group('mask') 40 | mask_single_test = hf_test.create_group('mask_single') 41 | shape_mask = (256, 256) 42 | shape_img = (256, 256) 43 | 44 | for folder in a_img[:100]: 45 | print('training: ' + folder) 46 | cfile = img_src + folder 47 | img = get_img(cfile, shape_img) 48 | imgs_tri.create_dataset(folder, data=img, dtype=np.uint8) 49 | 50 | for folder in a_img[100:]: 51 | print('validation: ' + folder) 52 | cfile = img_src + folder 53 | img = get_img(cfile, shape_img) 54 | imgs_test.create_dataset(folder, data=img, dtype=np.uint8) 55 | 56 | for folder in a_mask[:100]: 57 | print('training: ' + folder) 58 | cfile = mask_src + folder 59 | mask = get_mask(cfile, shape_mask) 60 | mask_tri.create_dataset(folder, data=mask, dtype=np.uint8) 61 | 62 | for folder in a_mask[100:]: 63 | print('validation: ' + folder) 64 | cfile = mask_src + folder 65 | mask = get_mask(cfile, shape_mask) 66 | mask_test.create_dataset(folder, data=mask, dtype=np.uint8) 67 | 68 | for folder in a_mask_single[:100]: 69 | print('training: ' + folder) 70 | cfile = mask_single_src + folder 71 | mask = get_mask(cfile, shape_mask) 72 | mask_single_tri.create_dataset(folder, data=mask, dtype=np.uint8) 73 | 74 | for folder in a_mask_single[100:]: 75 | print('validation: ' + folder) 76 | cfile = mask_single_src + folder 77 | mask = get_mask(cfile, shape_mask) 78 | mask_single_test.create_dataset(folder, data=mask, dtype=np.uint8) 79 | 80 | hf_tri.close() 81 | hf_test.close() 82 | 83 | -------------------------------------------------------------------------------- /loader/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/loader/__init__.py -------------------------------------------------------------------------------- /loader/bing_loader.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | import torch 4 | from torch.utils.data import Dataset 5 | import h5py 6 | import numpy as np 7 | import loader.transforms as transforms 8 | from PIL import Image 9 | 10 | 11 | class bing_segmentation(Dataset): 12 | def __init__(self, ann='training', args=None): 13 | self.ann = ann 14 | self.MEAN = np.array([101.87901, 100.81404, 110.389275]) 15 | self.STD = np.array([17.022379, 17.664776, 20.302572]) 16 | if ann == 'training': 17 | self.transformations = transforms.Compose([transforms.ToPILImage(), 18 | transforms.ColorJitter(brightness=0.3, 19 | contrast=0.3, 20 | saturation=0.3, 21 | hue=0.01), 22 | transforms.RandomRotation(90), 23 | transforms.RandomVerticalFlip(), 24 | transforms.RandomHorizontalFlip(), 25 | transforms.ToTensor(), 26 | transforms.Normalize(self.MEAN, self.STD)]) 27 | else: 28 | self.transformations = transforms.Compose([transforms.ToPILImage(), 29 | transforms.ToTensor(), 30 | transforms.Normalize(self.MEAN, self.STD)]) 31 | if ann == 'training': 32 | self.data_length = 335 33 | else: 34 | self.data_length = 270 35 | self.args = args 36 | 37 | def __len__(self): 38 | return self.data_length 39 | 40 | def __getitem__(self, item): 41 | if self.ann == 'training': 42 | self.data = h5py.File('Data/full_training_Bing.h5', 'r') 43 | else: 44 | self.data = h5py.File('Data/full_test_Bing.h5', 'r') 45 | self.mask = self.data['mask_single'] 46 | self.imgs = self.data['imgs'] 47 | self.img_list = list(self.imgs) 48 | self.mask_list = list(self.mask) 49 | cimage = self.img_list[item] 50 | img = self.imgs.get(cimage).value 51 | cmask = self.mask_list[item] 52 | mask = self.mask.get(cmask).value 53 | img = img.astype(np.uint8) 54 | mask = mask.astype(np.uint8) 55 | img, mask = self.transformations(img, mask) 56 | return img, mask 57 | -------------------------------------------------------------------------------- /loader/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import sys 5 | import random 6 | from PIL import Image 7 | 8 | try: 9 | import accimage 10 | except ImportError: 11 | accimage = None 12 | import numpy as np 13 | import numbers 14 | import types 15 | import collections 16 | import warnings 17 | 18 | from torchvision.transforms import functional as F 19 | 20 | if sys.version_info < (3, 3): 21 | Sequence = collections.Sequence 22 | Iterable = collections.Iterable 23 | else: 24 | Sequence = collections.abc.Sequence 25 | Iterable = collections.abc.Iterable 26 | 27 | __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad", 28 | "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", 29 | "RandomVerticalFlip", "RandomResizedCrop", "FiveCrop", "TenCrop", 30 | "ColorJitter", "RandomRotation", "RandomAffine", 31 | "RandomPerspective"] 32 | 33 | _pil_interpolation_to_str = { 34 | Image.NEAREST: 'PIL.Image.NEAREST', 35 | Image.BILINEAR: 'PIL.Image.BILINEAR', 36 | Image.BICUBIC: 'PIL.Image.BICUBIC', 37 | Image.LANCZOS: 'PIL.Image.LANCZOS', 38 | 
Image.HAMMING: 'PIL.Image.HAMMING', 39 | Image.BOX: 'PIL.Image.BOX', 40 | } 41 | 42 | 43 | class Compose(object): 44 | def __init__(self, transforms): 45 | self.transforms = transforms 46 | 47 | def __call__(self, img, mask): 48 | for t in self.transforms: 49 | img, mask = t(img, mask) 50 | return img, mask 51 | 52 | 53 | class ToTensor(object): 54 | def __call__(self, img, mask): 55 | # return F.to_tensor(img), F.to_tensor(mask) 56 | img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() 57 | mask = torch.from_numpy(np.array(mask)).long() 58 | return img, mask 59 | 60 | 61 | class ToPILImage(object): 62 | def __init__(self, mode=None): 63 | self.mode = mode 64 | 65 | def __call__(self, img, mask): 66 | return F.to_pil_image(img, self.mode), F.to_pil_image(mask, self.mode) 67 | 68 | 69 | class Normalize(object): 70 | def __init__(self, mean, std, inplace=False): 71 | self.mean = mean 72 | self.std = std 73 | self.inplace = inplace 74 | 75 | def __call__(self, img, mask): 76 | return F.normalize(img, self.mean, self.std, self.inplace), mask 77 | 78 | 79 | class Resize(object): 80 | def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True): 81 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) 82 | self.size = size 83 | self.interpolation = interpolation 84 | self.do_mask = do_mask 85 | 86 | def __call__(self, img, mask): 87 | if self.do_mask: 88 | return F.resize(img, self.size, self.interpolation), F.resize(mask, self.size, Image.NEAREST) 89 | else: 90 | return F.resize(img, self.size, self.interpolation), mask 91 | 92 | 93 | class CenterCrop(object): 94 | def __init__(self, size): 95 | if isinstance(size, numbers.Number): 96 | self.size = (int(size), int(size)) 97 | else: 98 | self.size = size 99 | 100 | def __call__(self, img, mask): 101 | return F.center_crop(img, self.size), F.center_crop(mask, self.size) 102 | 103 | 104 | class Pad(object): 105 | def __init__(self, padding, fill=0, padding_mode='constant'): 106 | assert isinstance(padding, (numbers.Number, tuple)) 107 | assert isinstance(fill, (numbers.Number, str, tuple)) 108 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] 109 | if isinstance(padding, Sequence) and len(padding) not in [2, 4]: 110 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 111 | "{} element tuple".format(len(padding))) 112 | 113 | self.padding = padding 114 | self.fill = fill 115 | self.padding_mode = padding_mode 116 | 117 | def __call__(self, img, mask): 118 | return F.pad(img, self.padding, self.fill, self.padding_mode), \ 119 | F.pad(mask, self.padding, self.fill, self.padding_mode) 120 | 121 | 122 | class Lambda(object): 123 | def __init__(self, lambd): 124 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" 125 | self.lambd = lambd 126 | 127 | def __call__(self, img, mask): 128 | return self.lambd(img), self.lambd(mask) 129 | 130 | 131 | class Lambda_image(object): 132 | def __init__(self, lambd): 133 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" 134 | self.lambd = lambd 135 | 136 | def __call__(self, img, mask): 137 | return self.lambd(img), mask 138 | 139 | 140 | class RandomTransforms(object): 141 | def __init__(self, transforms): 142 | assert isinstance(transforms, (list, tuple)) 143 | self.transforms = transforms 144 | 145 | def __call__(self, *args, **kwargs): 146 | raise NotImplementedError() 147 | 148 | 149 | class RandomApply(RandomTransforms): 150 | def __init__(self, transforms, 
p=0.5): 151 | super(RandomApply, self).__init__(transforms) 152 | self.p = p 153 | 154 | def __call__(self, img, mask): 155 | if self.p < random.random(): 156 | return img, mask 157 | for t in self.transforms: 158 | img, mask = t(img, mask) 159 | return img, mask 160 | 161 | 162 | class RandomOrder(RandomTransforms): 163 | def __call__(self, img, mask): 164 | order = list(range(len(self.transforms))) 165 | random.shuffle(order) 166 | for i in order: 167 | img, mask = self.transforms[i](img, mask) 168 | return img, mask 169 | 170 | 171 | class RandomChoice(RandomTransforms): 172 | def __call__(self, img, mask): 173 | t = random.choice(self.transforms) 174 | return t(img, mask) 175 | 176 | 177 | class RandomCrop(object): 178 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): 179 | if isinstance(size, numbers.Number): 180 | self.size = (int(size), int(size)) 181 | else: 182 | self.size = size 183 | self.padding = padding 184 | self.pad_if_needed = pad_if_needed 185 | self.fill = fill 186 | self.padding_mode = padding_mode 187 | 188 | @staticmethod 189 | def get_params(img, output_size): 190 | w, h = img.size 191 | th, tw = output_size 192 | if w == tw and h == th: 193 | return 0, 0, h, w 194 | 195 | i = random.randint(0, h - th) 196 | j = random.randint(0, w - tw) 197 | return i, j, th, tw 198 | 199 | def __call__(self, img, mask): 200 | if self.padding is not None: 201 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 202 | 203 | # pad the width if needed 204 | if self.pad_if_needed and img.size[0] < self.size[1]: 205 | img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) 206 | # pad the height if needed 207 | if self.pad_if_needed and img.size[1] < self.size[0]: 208 | img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 209 | 210 | i, j, h, w = self.get_params(img, self.size) 211 | 212 | return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w) 213 | 214 | 215 | class RandomHorizontalFlip(object): 216 | def __init__(self, p=0.5): 217 | self.p = p 218 | 219 | def __call__(self, img, mask): 220 | if random.random() < self.p: 221 | return F.hflip(img), F.hflip(mask) 222 | return img, mask 223 | 224 | 225 | class RandomVerticalFlip(object): 226 | def __init__(self, p=0.5): 227 | self.p = p 228 | 229 | def __call__(self, img, mask): 230 | if random.random() < self.p: 231 | return F.vflip(img), F.vflip(mask) 232 | return img, mask 233 | 234 | 235 | class RandomPerspective(object): 236 | def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC): 237 | self.p = p 238 | self.interpolation = interpolation 239 | self.distortion_scale = distortion_scale 240 | 241 | def __call__(self, img, mask): 242 | if not F._is_pil_image(img): 243 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 244 | 245 | if random.random() < self.p: 246 | width, height = img.size 247 | startpoints, endpoints = self.get_params(width, height, self.distortion_scale) 248 | return F.perspective(img, startpoints, endpoints, self.interpolation), \ 249 | F.perspective(mask, startpoints, endpoints, Image.NEAREST) 250 | return img, mask 251 | 252 | @staticmethod 253 | def get_params(width, height, distortion_scale): 254 | half_height = int(height / 2) 255 | half_width = int(width / 2) 256 | topleft = (random.randint(0, int(distortion_scale * half_width)), 257 | random.randint(0, int(distortion_scale * half_height))) 258 | topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), 259 | random.randint(0, int(distortion_scale * half_height))) 260 | botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), 261 | random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) 262 | botleft = (random.randint(0, int(distortion_scale * half_width)), 263 | random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) 264 | startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] 265 | endpoints = [topleft, topright, botright, botleft] 266 | return startpoints, endpoints 267 | 268 | 269 | class RandomResizedCrop(object): 270 | def __init__(self, size, mask_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 271 | if isinstance(size, tuple): 272 | self.size = size 273 | self.mask_size = mask_size 274 | else: 275 | self.size = (size, size) 276 | self.mask_size = (mask_size, mask_size) 277 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 278 | warnings.warn("range should be of kind (min, max)") 279 | 280 | self.interpolation = interpolation 281 | self.scale = scale 282 | self.ratio = ratio 283 | 284 | @staticmethod 285 | def get_params(img, scale, ratio): 286 | area = img.size[0] * img.size[1] 287 | 288 | for attempt in range(10): 289 | target_area = random.uniform(*scale) * area 290 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 291 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 292 | 293 | w = int(round(math.sqrt(target_area * aspect_ratio))) 294 | h = int(round(math.sqrt(target_area / aspect_ratio))) 295 | 296 | if w <= img.size[0] and h <= img.size[1]: 297 | i = random.randint(0, img.size[1] - h) 298 | j = random.randint(0, img.size[0] - w) 299 | return i, j, h, w 300 | 301 | # Fallback to central crop 302 | in_ratio = img.size[0] / img.size[1] 303 | if (in_ratio < min(ratio)): 304 | w = img.size[0] 305 | h = w / min(ratio) 306 | elif (in_ratio > max(ratio)): 307 | h = img.size[1] 308 | w = h * max(ratio) 309 | else: # whole image 310 | w = img.size[0] 311 | h = img.size[1] 312 | i = (img.size[1] - h) // 2 313 | j = (img.size[0] - w) // 2 314 | return i, j, h, w 315 | 316 | def __call__(self, img, mask): 317 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 318 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \ 319 | F.resized_crop(mask, i, j, h, w, self.mask_size, Image.NEAREST) 320 | 321 | 322 | class FiveCrop(object): 323 | def __init__(self, size): 324 | self.size = size 325 | if isinstance(size, numbers.Number): 326 | self.size = (int(size), int(size)) 327 | else: 328 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 
329 | self.size = size
330 |
331 | def __call__(self, img, mask):
332 | return F.five_crop(img, self.size), F.five_crop(mask, self.size)
333 |
334 |
335 | class TenCrop(object):
336 | def __init__(self, size, vertical_flip=False):
337 | self.size = size
338 | if isinstance(size, numbers.Number):
339 | self.size = (int(size), int(size))
340 | else:
341 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
342 | self.size = size
343 | self.vertical_flip = vertical_flip
344 |
345 | def __call__(self, img, mask):
346 | return F.ten_crop(img, self.size, self.vertical_flip), F.ten_crop(mask, self.size, self.vertical_flip)
347 |
348 |
349 | class ColorJitter(object):
350 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
351 | self.brightness = self._check_input(brightness, 'brightness')
352 | self.contrast = self._check_input(contrast, 'contrast')
353 | self.saturation = self._check_input(saturation, 'saturation')
354 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
355 | clip_first_on_zero=False)
356 |
357 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
358 | if isinstance(value, numbers.Number):
359 | if value < 0:
360 | raise ValueError("If {} is a single number, it must be non-negative.".format(name))
361 | value = [center - value, center + value]
362 | if clip_first_on_zero:
363 | value[0] = max(value[0], 0)
364 | elif isinstance(value, (tuple, list)) and len(value) == 2:
365 | if not bound[0] <= value[0] <= value[1] <= bound[1]:
366 | raise ValueError("{} values should be between {}".format(name, bound))
367 | else:
368 | raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
369 |
370 | # if value is 0 or (1., 1.) for brightness/contrast/saturation
371 | # or (0., 0.)
for hue, do nothing 372 | if value[0] == value[1] == center: 373 | value = None 374 | return value 375 | 376 | @staticmethod 377 | def get_params(brightness, contrast, saturation, hue): 378 | transforms = [] 379 | 380 | if brightness is not None: 381 | brightness_factor = random.uniform(brightness[0], brightness[1]) 382 | transforms.append(Lambda_image(lambda img: F.adjust_brightness(img, brightness_factor))) 383 | 384 | if contrast is not None: 385 | contrast_factor = random.uniform(contrast[0], contrast[1]) 386 | transforms.append(Lambda_image(lambda img: F.adjust_contrast(img, contrast_factor))) 387 | 388 | if saturation is not None: 389 | saturation_factor = random.uniform(saturation[0], saturation[1]) 390 | transforms.append(Lambda_image(lambda img: F.adjust_saturation(img, saturation_factor))) 391 | 392 | if hue is not None: 393 | hue_factor = random.uniform(hue[0], hue[1]) 394 | transforms.append(Lambda_image(lambda img: F.adjust_hue(img, hue_factor))) 395 | 396 | random.shuffle(transforms) 397 | transform = Compose(transforms) 398 | 399 | return transform 400 | 401 | def __call__(self, img, mask): 402 | transform = self.get_params(self.brightness, self.contrast, 403 | self.saturation, self.hue) 404 | return transform(img, mask) 405 | 406 | 407 | class RandomRotation(object): 408 | def __init__(self, degrees, resample=False, expand=False, center=None): 409 | if isinstance(degrees, numbers.Number): 410 | if degrees < 0: 411 | raise ValueError("If degrees is a single number, it must be positive.") 412 | self.degrees = (-degrees, degrees) 413 | else: 414 | if len(degrees) != 2: 415 | raise ValueError("If degrees is a sequence, it must be of len 2.") 416 | self.degrees = degrees 417 | 418 | self.resample = resample 419 | self.expand = expand 420 | self.center = center 421 | 422 | @staticmethod 423 | def get_params(degrees): 424 | angle = random.uniform(degrees[0], degrees[1]) 425 | 426 | return angle 427 | 428 | def __call__(self, img, mask): 429 | angle = self.get_params(self.degrees) 430 | 431 | return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), \ 432 | F.rotate(mask, angle, Image.NEAREST, self.expand, self.center) 433 | 434 | 435 | class RandomAffine(object): 436 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): 437 | if isinstance(degrees, numbers.Number): 438 | if degrees < 0: 439 | raise ValueError("If degrees is a single number, it must be positive.") 440 | self.degrees = (-degrees, degrees) 441 | else: 442 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ 443 | "degrees should be a list or tuple and it must be of length 2." 444 | self.degrees = degrees 445 | 446 | if translate is not None: 447 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 448 | "translate should be a list or tuple and it must be of length 2." 449 | for t in translate: 450 | if not (0.0 <= t <= 1.0): 451 | raise ValueError("translation values should be between 0 and 1") 452 | self.translate = translate 453 | 454 | if scale is not None: 455 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 456 | "scale should be a list or tuple and it must be of length 2." 
457 | for s in scale: 458 | if s <= 0: 459 | raise ValueError("scale values should be positive") 460 | self.scale = scale 461 | 462 | if shear is not None: 463 | if isinstance(shear, numbers.Number): 464 | if shear < 0: 465 | raise ValueError("If shear is a single number, it must be positive.") 466 | self.shear = (-shear, shear) 467 | else: 468 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ 469 | "shear should be a list or tuple and it must be of length 2." 470 | self.shear = shear 471 | else: 472 | self.shear = shear 473 | 474 | self.resample = resample 475 | self.fillcolor = fillcolor 476 | 477 | @staticmethod 478 | def get_params(degrees, translate, scale_ranges, shears, img_size): 479 | angle = random.uniform(degrees[0], degrees[1]) 480 | if translate is not None: 481 | max_dx = translate[0] * img_size[0] 482 | max_dy = translate[1] * img_size[1] 483 | translations = (np.round(random.uniform(-max_dx, max_dx)), 484 | np.round(random.uniform(-max_dy, max_dy))) 485 | else: 486 | translations = (0, 0) 487 | 488 | if scale_ranges is not None: 489 | scale = random.uniform(scale_ranges[0], scale_ranges[1]) 490 | else: 491 | scale = 1.0 492 | 493 | if shears is not None: 494 | shear = random.uniform(shears[0], shears[1]) 495 | else: 496 | shear = 0.0 497 | 498 | return angle, translations, scale, shear 499 | 500 | def __call__(self, img, mask): 501 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) 502 | return F.affine(img, *ret, resample=Image.BILINEAR, fillcolor=self.fillcolor), \ 503 | F.affine(mask, *ret, resample=Image.NEAREST, fillcolor=self.fillcolor) 504 | -------------------------------------------------------------------------------- /loader/viah_loader.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | import torch 4 | from torch.utils.data import Dataset 5 | import h5py 6 | import numpy as np 7 | import loader.transforms as transforms 8 | from PIL import Image 9 | 10 | 11 | class viah_segmentation(Dataset): 12 | def __init__(self, ann='training', args=None): 13 | self.ann = ann 14 | self.MEAN = np.array([0.47341759*255, 0.28791303*255, 0.2850705*255]) 15 | self.STD = np.array([0.22645572*255, 0.15276193*255, 0.140702*255]) 16 | if ann == 'training': 17 | self.transformations = transforms.Compose([transforms.ToPILImage(), 18 | transforms.RandomResizedCrop(size=(256, 256), 19 | mask_size=(256, 256), 20 | scale=(0.75, 2)), 21 | transforms.ColorJitter(brightness=0.4, 22 | contrast=0.4, 23 | saturation=0.4, 24 | hue=0.1), 25 | transforms.RandomRotation(25), 26 | transforms.RandomHorizontalFlip(), 27 | transforms.ToTensor(), 28 | transforms.Normalize(self.MEAN, self.STD)]) 29 | else: 30 | self.transformations = transforms.Compose([transforms.ToPILImage(), 31 | transforms.ToTensor(), 32 | transforms.Normalize(self.MEAN, self.STD)]) 33 | if ann == 'training': 34 | self.data_length = 100 35 | else: 36 | self.data_length = 68 37 | self.args = args 38 | 39 | def __len__(self): 40 | return self.data_length 41 | 42 | def __getitem__(self, item): 43 | if self.ann == 'training': 44 | self.data = h5py.File('Data/full_training_viah.h5', 'r') 45 | else: 46 | self.data = h5py.File('Data/full_test_viah.h5', 'r') 47 | self.mask = self.data['mask_single'] 48 | self.imgs = self.data['imgs'] 49 | self.img_list = list(self.imgs) 50 | self.mask_list = list(self.mask) 51 | cimage = self.img_list[item] 52 | img = self.imgs.get(cimage).value 53 | cmask = 
self.mask_list[item] 54 | mask = self.mask.get(cmask).value 55 | img = img.astype(np.uint8) 56 | mask = mask.astype(np.uint8) 57 | img, mask = self.transformations(img, mask) 58 | return img, mask 59 | 60 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/models/__init__.py -------------------------------------------------------------------------------- /models/hardnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AdaptationMismatch(Exception): pass 8 | 9 | 10 | class Flatten(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, x): 15 | return x.view(x.data.size(0), -1) 16 | 17 | 18 | class CombConvLayer(nn.Sequential): 19 | def __init__(self, in_channels, out_channels, kernel=1, stride=1, dropout=0.1, bias=False): 20 | super().__init__() 21 | self.add_module('layer1', ConvLayer(in_channels, out_channels, kernel)) 22 | self.add_module('layer2', DWConvLayer(out_channels, out_channels, stride=stride)) 23 | 24 | def forward(self, x): 25 | return super().forward(x) 26 | 27 | 28 | class DWConvLayer(nn.Sequential): 29 | def __init__(self, in_channels, out_channels, stride=1, bias=False): 30 | super().__init__() 31 | out_ch = out_channels 32 | 33 | groups = in_channels 34 | kernel = 3 35 | # print(kernel, 'x', kernel, 'x', out_channels, 'x', out_channels, 'DepthWise') 36 | 37 | self.add_module('dwconv', nn.Conv2d(groups, groups, kernel_size=3, 38 | stride=stride, padding=1, groups=groups, bias=bias)) 39 | self.add_module('norm', nn.BatchNorm2d(groups)) 40 | 41 | def forward(self, x): 42 | return super().forward(x) 43 | 44 | 45 | class ConvLayer(nn.Sequential): 46 | def __init__(self, in_channels, out_channels, kernel=3, stride=1, dropout=0.1, bias=False): 47 | super().__init__() 48 | out_ch = out_channels 49 | groups = 1 50 | # print(kernel, 'x', kernel, 'x', in_channels, 'x', out_channels) 51 | self.add_module('conv', nn.Conv2d(in_channels, out_ch, kernel_size=kernel, 52 | stride=stride, padding=kernel // 2, groups=groups, bias=bias)) 53 | self.add_module('norm', nn.BatchNorm2d(out_ch)) 54 | self.add_module('relu', nn.ReLU6(True)) 55 | 56 | def forward(self, x): 57 | return super().forward(x) 58 | 59 | 60 | class HarDBlock(nn.Module): 61 | def get_link(self, layer, base_ch, growth_rate, grmul): 62 | if layer == 0: 63 | return base_ch, 0, [] 64 | out_channels = growth_rate 65 | link = [] 66 | for i in range(10): 67 | dv = 2 ** i 68 | if layer % dv == 0: 69 | k = layer - dv 70 | link.append(k) 71 | if i > 0: 72 | out_channels *= grmul 73 | out_channels = int(int(out_channels + 1) / 2) * 2 74 | in_channels = 0 75 | for i in link: 76 | ch, _, _ = self.get_link(i, base_ch, growth_rate, grmul) 77 | in_channels += ch 78 | return out_channels, in_channels, link 79 | 80 | def get_out_ch(self): 81 | return self.out_channels 82 | 83 | def __init__(self, in_channels, growth_rate, grmul, n_layers, keepBase=False, residual_out=False, dwconv=False): 84 | super().__init__() 85 | self.keepBase = keepBase 86 | self.links = [] 87 | layers_ = [] 88 | self.out_channels = 0 # if upsample else in_channels 89 | for i in range(n_layers): 90 | outch, inch, link = self.get_link(i + 1, in_channels, growth_rate, 
grmul) 91 | self.links.append(link) 92 | use_relu = residual_out 93 | if dwconv: 94 | layers_.append(CombConvLayer(inch, outch)) 95 | else: 96 | layers_.append(ConvLayer(inch, outch)) 97 | 98 | if (i % 2 == 0) or (i == n_layers - 1): 99 | self.out_channels += outch 100 | # print("Blk out =",self.out_channels) 101 | self.layers = nn.ModuleList(layers_) 102 | 103 | def forward(self, x): 104 | layers_ = [x] 105 | 106 | for layer in range(len(self.layers)): 107 | link = self.links[layer] 108 | tin = [] 109 | for i in link: 110 | tin.append(layers_[i]) 111 | if len(tin) > 1: 112 | x = torch.cat(tin, 1) 113 | else: 114 | x = tin[0] 115 | out = self.layers[layer](x) 116 | layers_.append(out) 117 | 118 | t = len(layers_) 119 | out_ = [] 120 | for i in range(t): 121 | if (i == 0 and self.keepBase) or \ 122 | (i == t - 1) or (i % 2 == 1): 123 | out_.append(layers_[i]) 124 | out = torch.cat(out_, 1) 125 | return out 126 | 127 | 128 | class HarDNet(nn.Module): 129 | def __init__(self, depth_wise=False, arch=85, pretrained=True, weight_path='', out=1, args=None): 130 | super().__init__() 131 | first_ch = [32, 64] 132 | second_kernel = 3 133 | max_pool = True 134 | grmul = 1.7 135 | drop_rate = 0.1 136 | args['order'] = arch 137 | # HarDNet68 138 | ch_list = [128, 256, 320, 640, 1024] 139 | gr = [14, 16, 20, 40, 160] 140 | n_layers = [8, 16, 16, 16, 4] 141 | downSamp = [1, 0, 1, 1, 0] 142 | 143 | if arch == 85: 144 | # HarDNet85 145 | first_ch = [48, 96] 146 | ch_list = [192, 256, 320, 480, 720, 1280] 147 | gr = [24, 24, 28, 36, 48, 256] 148 | n_layers = [8, 16, 16, 16, 16, 4] 149 | downSamp = [1, 0, 1, 0, 1, 0] 150 | drop_rate = 0.2 151 | elif arch == 39: 152 | # HarDNet39 153 | first_ch = [24, 48] 154 | ch_list = [96, 320, 640, 1024] 155 | grmul = 1.6 156 | gr = [16, 20, 64, 160] 157 | n_layers = [4, 16, 8, 4] 158 | downSamp = [1, 1, 1, 0] 159 | 160 | if depth_wise: 161 | second_kernel = 1 162 | max_pool = False 163 | drop_rate = 0.05 164 | 165 | blks = len(n_layers) 166 | self.base = nn.ModuleList([]) 167 | 168 | # First Layer: Standard Conv3x3, Stride=2 169 | self.base.append( 170 | ConvLayer(in_channels=3, out_channels=first_ch[0], kernel=3, 171 | stride=2, bias=False)) 172 | 173 | # Second Layer 174 | self.base.append(ConvLayer(first_ch[0], first_ch[1], kernel=second_kernel)) 175 | 176 | # Maxpooling or DWConv3x3 downsampling 177 | if max_pool: 178 | self.base.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 179 | else: 180 | self.base.append(DWConvLayer(first_ch[1], first_ch[1], stride=2)) 181 | 182 | # Build all HarDNet blocks 183 | ch = first_ch[1] 184 | for i in range(blks): 185 | blk = HarDBlock(ch, gr[i], grmul, n_layers[i], dwconv=depth_wise) 186 | ch = blk.get_out_ch() 187 | self.base.append(blk) 188 | 189 | if i == blks - 1 and arch == 85: 190 | self.base.append(nn.Dropout(0.1)) 191 | 192 | self.base.append(ConvLayer(ch, ch_list[i], kernel=1)) 193 | ch = ch_list[i] 194 | if downSamp[i] == 1: 195 | if max_pool: 196 | self.base.append(nn.MaxPool2d(kernel_size=2, stride=2)) 197 | else: 198 | self.base.append(DWConvLayer(ch, ch, stride=2)) 199 | 200 | ch = ch_list[blks - 1] 201 | self.base.append( 202 | nn.Sequential( 203 | nn.AdaptiveAvgPool2d((1, 1)), 204 | Flatten(), 205 | nn.Dropout(drop_rate), 206 | nn.Linear(ch, 1000))) 207 | 208 | # print(self.base) 209 | 210 | if pretrained: 211 | if hasattr(torch, 'hub'): 212 | 213 | if arch == 68 and not depth_wise: 214 | checkpoint = 'https://ping-chao.com/hardnet/hardnet68-5d684880.pth' 215 | elif arch == 85 and not depth_wise: 216 | 
checkpoint = 'https://ping-chao.com/hardnet/hardnet85-a28faa00.pth' 217 | elif arch == 68 and depth_wise: 218 | checkpoint = 'https://ping-chao.com/hardnet/hardnet68ds-632474d2.pth' 219 | else: 220 | checkpoint = 'https://ping-chao.com/hardnet/hardnet39ds-0e6c6fa9.pth' 221 | 222 | self.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False)) 223 | else: 224 | postfix = 'ds' if depth_wise else '' 225 | weight_file = '%shardnet%d%s.pth' % (weight_path, arch, postfix) 226 | if not os.path.isfile(weight_file): 227 | print(weight_file, 'is not found') 228 | exit(0) 229 | weights = torch.load(weight_file) 230 | self.load_state_dict(weights) 231 | 232 | postfix = 'DS' if depth_wise else '' 233 | print('ImageNet pretrained weights for HarDNet%d%s is loaded' % (arch, postfix)) 234 | 235 | if int(args['outlayer']) == 1: 236 | if int(args['order']) == 39: 237 | self.features = 96 238 | self.base = self.base[0:5] 239 | if int(args['order']) == 68: 240 | self.features = 128 241 | self.base = self.base[0:5] 242 | if int(args['order']) == 85: 243 | self.features = 192 244 | self.base = self.base[0:5] 245 | elif int(args['outlayer']) == 2: 246 | if int(args['order']) == 39: 247 | self.features = 320 248 | self.base = self.base[0:8] 249 | if int(args['order']) == 68: 250 | self.features = 320 251 | self.base = self.base[0:10] 252 | if int(args['order']) == 85: 253 | self.features = 320 254 | self.base = self.base[0:10] 255 | elif int(args['outlayer']) == 3: 256 | if int(args['order']) == 39: 257 | self.features = 640 258 | self.base = self.base[0:11] 259 | if int(args['order']) == 68: 260 | self.features = 640 261 | self.base = self.base[0:13] 262 | if int(args['order']) == 85: 263 | self.features = 720 264 | self.base = self.base[0:15] 265 | elif int(args['outlayer']) == 4: 266 | if int(args['order']) == 39: 267 | self.features = 1024 268 | self.base = self.base[0:14] 269 | if int(args['order']) == 68: 270 | self.features = 1024 271 | self.base = self.base[0:16] 272 | if int(args['order']) == 85: 273 | self.features = 1280 274 | self.base = self.base[0:19] 275 | if int(args['order']) == 39: 276 | self.full_features = [48, 96, 320, 640, 1024] 277 | self.list = [1, 4, 7, 10, 13] 278 | if int(args['order']) == 68: 279 | self.full_features = [64, 128, 320, 640, 1024] 280 | self.list = [1, 4, 9, 12, 15] 281 | if int(args['order']) == 85: 282 | self.full_features = [96, 192, 320, 720, 1280] 283 | self.list = [1, 4, 9, 14, 18] 284 | 285 | def forward(self, x): 286 | for inx, layer in enumerate(self.base): 287 | x = layer(x) 288 | if inx == self.list[0]: 289 | x2 = x 290 | if inx == len(self.base) - 1: 291 | return x2 292 | elif inx == self.list[1]: 293 | x4 = x 294 | if inx == len(self.base) - 1: 295 | return x2, x4 296 | elif inx == self.list[2]: 297 | x8 = x 298 | if inx == len(self.base) - 1: 299 | return x2, x4, x8 300 | elif inx == self.list[3]: 301 | x16 = x 302 | if inx == len(self.base) - 1: 303 | return x2, x4, x8, x16 304 | elif inx == self.list[4]: 305 | x32 = x 306 | if inx == len(self.base) - 1: 307 | return x2, x4, x8, x16, x32 308 | 309 | 310 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from models.vanilla import * 2 | import neural_renderer as nr 3 | from models.hardnet import * 4 | from utils.utils_train import norm_input 5 | 6 | class Decoder(nn.Module): 7 | def __init__(self, full_features, args): 8 | super(Decoder, self).__init__() 
9 | if int(args['outlayer']) == 2: 10 | self.up1 = UpBlock(full_features[1] + full_features[0], 2, 11 | func='tanh', drop=float(args['drop'])).cuda() 12 | if int(args['outlayer']) == 3: 13 | self.up1 = UpBlock(full_features[2] + full_features[1], full_features[1], 14 | func='relu', drop=float(args['drop'])).cuda() 15 | self.up2 = UpBlock(full_features[1] + full_features[0], 2, 16 | func='tanh', drop=float(args['drop'])).cuda() 17 | if int(args['outlayer']) == 4: 18 | self.up1 = UpBlock(full_features[3] + full_features[2], full_features[2], 19 | func='relu', drop=float(args['drop'])).cuda() 20 | self.up2 = UpBlock(full_features[2] + full_features[1], full_features[1], 21 | func='relu', drop=float(args['drop'])).cuda() 22 | self.up3 = UpBlock(full_features[1] + full_features[0], 2, 23 | func='tanh', drop=float(args['drop'])).cuda() 24 | self.args = args 25 | 26 | def forward(self, x, size): 27 | if int(self.args['outlayer']) == 2: 28 | shift_map = self.up1(x[1], x[0]) 29 | if int(self.args['outlayer']) == 3: 30 | z = self.up1(x[2], x[1]) 31 | shift_map = self.up2(z, x[0]) 32 | if int(self.args['outlayer']) == 4: 33 | z = self.up1(x[3], x[2]) 34 | z = self.up2(z, x[1]) 35 | shift_map = self.up3(z, x[0]) 36 | shift_map = F.interpolate(shift_map, size=size, mode='bilinear', align_corners=True) 37 | return shift_map[:, 0, :, :].unsqueeze(dim=1), shift_map[:, 1, :, :].unsqueeze(dim=1) 38 | 39 | 40 | class DeepACM(nn.Module): 41 | def __init__(self, args): 42 | super(DeepACM, self).__init__() 43 | self.backbone = HarDNet(depth_wise=bool(int(args['depth_wise'])), arch=int(args['order']), args=args) 44 | self.ACMDecoder = Decoder(self.backbone.full_features, args) 45 | self.nP = int(args['nP']) 46 | self.texture_size = 2 47 | self.camera_distance = 1 48 | self.elevation = 0 49 | self.azimuth = 0 50 | self.image_size = int(args['im_size']) 51 | self.renderer = nr.Renderer(camera_mode='look_at', image_size=self.image_size, light_intensity_ambient=1, 52 | light_intensity_directional=1, perspective=False) 53 | 54 | def forward(self, I, P, faces, it): 55 | size = I.size()[2:] 56 | z = self.backbone(I) 57 | Ix, Iy = self.ACMDecoder(z, size) 58 | masks = [] 59 | Ps = [] 60 | for i in range(it): 61 | Pxx = F.grid_sample(Ix, P).transpose(3, 2) 62 | Pyy = F.grid_sample(Iy, P).transpose(3, 2) 63 | Pedge = torch.cat((Pxx, Pyy), -1) 64 | P = Pedge + P 65 | z = torch.ones((P.shape[0], 1, P.shape[2], 1)).cuda() 66 | PP = torch.cat((P, z), 3) 67 | PP = torch.squeeze(PP, dim=1) 68 | PP[:, :, 1] = PP[:, :, 1]*-1 69 | faces = torch.squeeze(faces, dim=1) 70 | self.renderer.eye = nr.get_points_from_angles(self.camera_distance, self.elevation, self.azimuth) 71 | mask = self.renderer(PP, faces, mode='silhouettes').unsqueeze(dim=1) 72 | PP[:, :, 1] = PP[:, :, 1]*-1 73 | Ps.append(PP[:, :, 0:2].unsqueeze(dim=1)) 74 | masks.append(mask) 75 | return masks, Ps, Ix, Iy, I 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /models/model_seg.py: -------------------------------------------------------------------------------- 1 | from models.vanilla import * 2 | import neural_renderer as nr 3 | from models.hardnet import * 4 | 5 | 6 | class final_layer(nn.Module): 7 | def __init__(self): 8 | super(final_layer, self).__init__() 9 | 10 | def forward(self, z, size): 11 | return F.interpolate(z, size=size, mode='bilinear', align_corners=True) 12 | 13 | 14 | class Decoder(nn.Module): 15 | def __init__(self, full_features, args): 16 
| super(Decoder, self).__init__() 17 | if int(args['outlayer']) == 2: 18 | self.up1 = UpBlock(full_features[1] + full_features[0], 1, 19 | func='sigmoid', drop=float(args['drop'])).cuda() 20 | if int(args['outlayer']) == 3: 21 | self.up1 = UpBlock(full_features[2] + full_features[1], full_features[1], 22 | func='relu', drop=float(args['drop'])).cuda() 23 | self.up2 = UpBlock(full_features[1] + full_features[0], 1, 24 | func='sigmoid', drop=float(args['drop'])).cuda() 25 | if int(args['outlayer']) == 4: 26 | self.up1 = UpBlock(full_features[3] + full_features[2], full_features[2], 27 | func='relu', drop=float(args['drop'])).cuda() 28 | self.up2 = UpBlock(full_features[2] + full_features[1], full_features[1], 29 | func='relu', drop=float(args['drop'])).cuda() 30 | self.up3 = UpBlock(full_features[1] + full_features[0], 1, 31 | func='sigmoid', drop=float(args['drop'])).cuda() 32 | self.args = args 33 | self.final = final_layer() 34 | 35 | def forward(self, x, size): 36 | if int(self.args['outlayer']) == 2: 37 | z = self.up1(x[1], x[0]) 38 | if int(self.args['outlayer']) == 3: 39 | z = self.up1(x[2], x[1]) 40 | z = self.up2(z, x[0]) 41 | if int(self.args['outlayer']) == 4: 42 | z = self.up1(x[3], x[2]) 43 | z = self.up2(z, x[1]) 44 | z = self.up3(z, x[0]) 45 | return self.final(z, size) 46 | 47 | 48 | class Segmentation(nn.Module): 49 | def __init__(self, args): 50 | super(Segmentation, self).__init__() 51 | self.backbone = HarDNet(depth_wise=bool(int(args['depth_wise'])), arch=int(args['order']), args=args) 52 | self.decoder = Decoder(self.backbone.full_features, args) 53 | 54 | def forward(self, I): 55 | size = I.size()[2:] 56 | z = self.backbone(I) 57 | return self.decoder(z, size) 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /models/vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size=3, drop = 0): 8 | super(ResidualBlock, self).__init__() 9 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1) 10 | self.conv1_drop = nn.Dropout2d(drop) 11 | self.BN1 = nn.BatchNorm2d(out_channels) 12 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=1) 13 | self.conv2_drop = nn.Dropout2d(drop) 14 | self.BN2 = nn.BatchNorm2d(out_channels) 15 | 16 | def forward(self, x_in): 17 | x = self.conv1_drop(self.conv1(x_in)) 18 | x = F.relu(self.BN1(x)) 19 | x = self.conv2_drop(self.conv2(x)) 20 | x = F.relu(self.BN2(x)) 21 | return x 22 | 23 | 24 | class DownBlock(nn.Module): 25 | def __init__(self, in_channels, out_channels, kernel_size=3, drop=0): 26 | super(DownBlock, self).__init__() 27 | P = int((kernel_size -1 ) /2) 28 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=P) 29 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=P) 30 | self.pool = nn.MaxPool2d((2, 2)) 31 | self.conv1_drop = nn.Dropout2d(drop) 32 | self.conv2_drop = nn.Dropout2d(drop) 33 | self.BN = nn.BatchNorm2d(out_channels) 34 | 35 | def forward(self, x_in): 36 | x1 = self.conv2_drop(self.conv2(self.conv1_drop(self.conv1(x_in)))) 37 | x1_pool = F.relu(self.BN(self.pool(x1))) 38 | return x1, x1_pool 39 | 40 | 41 | class UpBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels, kernel_size=3, 
func=None, drop=0):
43 | super(UpBlock, self).__init__()
44 | P = int((kernel_size - 1) / 2)
45 | self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')
46 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=P)
47 | self.conv1_drop = nn.Dropout2d(drop)
48 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=P)
49 | self.conv2_drop = nn.Dropout2d(drop)
50 | self.BN = nn.BatchNorm2d(out_channels)
51 | self.func = func
52 |
53 | def forward(self, x_in, x_up):
54 | x = self.Upsample(x_in)
55 | x_cat = torch.cat((x, x_up), 1)
56 | x1 = self.conv2_drop(self.conv2(self.conv1_drop(self.conv1(x_cat))))
57 | if self.func == 'tanh':
58 | return torch.tanh(self.BN(x1))
59 | elif self.func == 'relu':
60 | return F.relu(self.BN(x1))
61 | elif self.func == 'sigmoid':
62 | return torch.sigmoid(self.BN(x1))
63 |
64 |
65 | class Encoder(nn.Module):
66 | def __init__(self, AEdim, drop=0):
67 | super(Encoder, self).__init__()
68 | a, b, c = int(AEdim / 4), int(AEdim / 2), AEdim
69 | self.down1 = DownBlock(3, a, drop=drop)
70 | self.down2 = DownBlock(a, b, drop=drop)
71 | self.down3 = DownBlock(b, c, drop=drop)
72 |
73 | def forward(self, x_in):
74 | x1, x1_pool = self.down1(x_in)
75 | x2, x2_pool = self.down2(x1_pool)
76 | x3, x3_pool = self.down3(x2_pool)
77 | return x1, x2, x3, x3_pool
78 |
-------------------------------------------------------------------------------- /out/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/out/__init__.py -------------------------------------------------------------------------------- /pics/16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/pics/16.jpg -------------------------------------------------------------------------------- /pics/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/pics/8.jpg -------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | absl-py=0.9.0=pypi_0
6 | attrs=19.3.0=py_0
7 | backcall=0.2.0=pyh9f0ad1d_0
8 | blas=1.0=mkl
9 | bleach=3.1.5=pyh9f0ad1d_0
10 | blinker=1.4=py38_0
11 | brotlipy=0.7.0=py38h1e0a361_1000
12 | c-ares=1.16.1=h7b6447c_0
13 | ca-certificates=2020.7.22=0
14 | cachetools=4.1.0=pypi_0
15 | certifi=2020.6.20=py38_0
16 | cffi=1.14.0=py38he30daa8_1
17 | chardet=3.0.4=py38h32f6830_1006
18 | click=7.1.2=py_0
19 | cryptography=2.9.2=py38h766eaa4_0
20 | cudatoolkit=10.0.130=0
21 | cycler=0.10.0=pypi_0
22 | cython=0.29.20=pypi_0
23 | decorator=4.4.2=py_0
24 | defusedxml=0.6.0=py_0
25 | entrypoints=0.3=py38h32f6830_1001
26 | freetype=2.9.1=h8a8886c_1
27 | google-auth=1.17.0=pypi_0
28 | google-auth-oauthlib=0.4.1=py_2
29 | grpcio=1.29.0=pypi_0
30 | h5py=2.10.0=pypi_0
31 | idna=2.9=pypi_0
32 | imageio=2.8.0=pypi_0
33 | importlib-metadata=1.7.0=py38h32f6830_0
34 | importlib_metadata=1.7.0=0
35 | intel-openmp=2020.1=217
36 | ipdb=0.13.2=pypi_0
37 | ipykernel=5.3.2=py38h23f93f0_0
38 | ipython=7.15.0=pypi_0
39
| ipython-genutils=0.2.0=pypi_0 40 | ipython_genutils=0.2.0=py_1 41 | jedi=0.17.0=pypi_0 42 | jinja2=2.11.2=pyh9f0ad1d_0 43 | joblib=0.15.1=pypi_0 44 | jpeg=9b=h024ee3a_2 45 | json5=0.9.5=pypi_0 46 | jsonpatch=1.25=pypi_0 47 | jsonpointer=2.0=pypi_0 48 | jsonschema=3.2.0=py38h32f6830_1 49 | jupyter_client=6.1.5=py_0 50 | jupyter_core=4.6.3=py38h32f6830_1 51 | jupyterlab=2.1.5=py_0 52 | jupyterlab_server=1.2.0=py_0 53 | kiwisolver=1.2.0=pypi_0 54 | ld_impl_linux-64=2.33.1=h53a641e_7 55 | libedit=3.1.20181209=hc058e9b_0 56 | libffi=3.3=he6710b0_1 57 | libgcc-ng=9.1.0=hdf63c60_0 58 | libgfortran-ng=7.3.0=hdf63c60_0 59 | libpng=1.6.37=hbc83047_0 60 | libprotobuf=3.12.4=hd408876_0 61 | libsodium=1.0.17=h516909a_0 62 | libstdcxx-ng=9.1.0=hdf63c60_0 63 | libtiff=4.1.0=h2733197_1 64 | lz4-c=1.9.2=he6710b0_0 65 | markdown=3.2.2=py38_0 66 | markupsafe=1.1.1=py38h1e0a361_1 67 | matplotlib=3.2.1=pypi_0 68 | mistune=0.8.4=py38h1e0a361_1001 69 | mkl=2020.1=217 70 | mkl-service=2.3.0=py38he904b0f_0 71 | mkl_fft=1.0.15=py38ha843d7b_0 72 | mkl_random=1.1.1=py38h0573a6f_0 73 | nbconvert=5.6.1=py38h32f6830_1 74 | nbformat=5.0.7=py_0 75 | ncurses=6.2=he6710b0_1 76 | networkx=2.4=pypi_0 77 | neural-renderer-pytorch=1.1.3=pypi_0 78 | ninja=1.9.0=py38hfd86e86_0 79 | notebook=6.0.3=py38h32f6830_1 80 | numpy=1.18.1=py38h4f9e942_0 81 | numpy-base=1.18.1=py38hde5b4d6_1 82 | oauthlib=3.1.0=py_0 83 | olefile=0.46=py_0 84 | opencv-python=4.2.0.34=pypi_0 85 | openssl=1.1.1h=h7b6447c_0 86 | packaging=20.4=pyh9f0ad1d_0 87 | pandoc=2.10=0 88 | pandocfilters=1.4.2=pypi_0 89 | parso=0.7.0=pyh9f0ad1d_0 90 | pexpect=4.8.0=py38h32f6830_1 91 | pickleshare=0.7.5=py38h32f6830_1001 92 | pillow=7.1.2=py38hb39fc2d_0 93 | pip=20.0.2=py38_3 94 | progressbar=2.5=pypi_0 95 | prometheus_client=0.8.0=pyh9f0ad1d_0 96 | prompt-toolkit=3.0.5=py_1 97 | protobuf=3.12.2=pypi_0 98 | ptyprocess=0.6.0=py_1001 99 | pyasn1=0.4.8=py_0 100 | pyasn1-modules=0.2.8=pypi_0 101 | pycocotools=2.0.0=pypi_0 102 | pycparser=2.20=pyh9f0ad1d_2 103 | pygments=2.6.1=py_0 104 | pyjwt=1.7.1=py38_0 105 | pyopengl=3.1.5=pypi_0 106 | pyopenssl=19.1.0=py_1 107 | pyparsing=2.4.7=pyh9f0ad1d_0 108 | pyrsistent=0.16.0=py38h1e0a361_0 109 | pysocks=1.7.1=py38h32f6830_1 110 | python=3.8.3=hcff3b4d_0 111 | python-dateutil=2.8.1=py_0 112 | python_abi=3.8=1_cp38 113 | pytorch=1.4.0=py3.8_cuda10.0.130_cudnn7.6.3_0 114 | pywavelets=1.1.1=pypi_0 115 | pyzmq=19.0.1=pypi_0 116 | readline=8.0=h7b6447c_0 117 | requests=2.23.0=pypi_0 118 | requests-oauthlib=1.3.0=py_0 119 | rsa=4.1=pypi_0 120 | scikit-image=0.17.2=pypi_0 121 | scikit-learn=0.23.1=pypi_0 122 | scipy=1.4.1=pypi_0 123 | send2trash=1.5.0=py_0 124 | setuptools=47.1.1=py38_0 125 | six=1.15.0=py_0 126 | sqlite=3.31.1=h62c20be_1 127 | tensorboard=2.2.2=pypi_0 128 | tensorboard-plugin-wit=1.6.0.post3=pypi_0 129 | tensorboardx=2.1=py_0 130 | terminado=0.8.3=pypi_0 131 | testpath=0.4.4=py_0 132 | threadpoolctl=2.1.0=pypi_0 133 | tifffile=2020.6.3=pypi_0 134 | tk=8.6.8=hbc83047_0 135 | torchfile=0.1.0=pypi_0 136 | torchvision=0.5.0=py38_cu100 137 | tornado=6.0.4=py38h1e0a361_1 138 | tqdm=4.46.1=pypi_0 139 | traitlets=4.3.3=py38h32f6830_1 140 | urllib3=1.25.9=py_0 141 | visdom=0.1.8.9=pypi_0 142 | wcwidth=0.2.4=pypi_0 143 | webencodings=0.5.1=pypi_0 144 | websocket-client=0.57.0=pypi_0 145 | werkzeug=1.0.1=py_0 146 | wheel=0.34.2=py38_0 147 | xlrd=1.2.0=py_0 148 | xz=5.2.5=h7b6447c_0 149 | zeromq=4.3.2=he1b5a44_2 150 | zipp=3.1.0=py_0 151 | zlib=1.2.11=h7b6447c_3 152 | zstd=1.4.4=h0b5b093_3 153 | 
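154 | # Note: neural-renderer-pytorch builds CUDA extensions at install time; as a rule of thumb it needs a CUDA toolkit that matches the pinned pytorch/cudatoolkit 10.0 build above.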
-------------------------------------------------------------------------------- /results/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/results/__init__.py -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from tensorboardX import SummaryWriter 3 | 4 | from models.model import * 5 | from models.model_seg import Segmentation 6 | from loader.viah_loader import * 7 | from loader.bing_loader import * 8 | from utils.utils_args import * 9 | from utils.utils_eval import * 10 | from utils.utils_train import * 11 | from utils.utils_tri import * 12 | from utils.utils_vis import * 13 | from utils.utils_lr import * 14 | from utils.loss import * 15 | from utils.snake_loss import Snakeloss 16 | 17 | 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | torch.backends.cudnn.benchmark = True 20 | args = get_args() 21 | save_args(args) 22 | writer = SummaryWriter() 23 | 24 | PATH = r'results/' + args['task'] 25 | segnet = Segmentation(args) 26 | segnet1 = torch.load(PATH + '/best/SEG.pt') 27 | segnet.load_state_dict(segnet1.state_dict()) 28 | segnet.eval().to(device) 29 | model = DeepACM(args) 30 | model.train().to(device) 31 | 32 | P_test, faces_test = get_poly(int(args['im_size']), int(args['nP']), 33 | int(args['Radius']), int(args['im_size']) / 2, int(args['im_size']) / 2) 34 | faces_test = faces_test.unsqueeze(dim=0).unsqueeze(dim=0).cuda() 35 | faces = faces_test.repeat(int(args['Batch_size']), 1, 1, 1).cuda() 36 | PTrain = P_test.repeat(int(args['Batch_size']), 1, 1, 1).cuda() 37 | criterion = SoftDiceLoss() 38 | snake_loss = Snakeloss(criterion) 39 | optimizer = torch.optim.Adam(model.parameters(), lr=float(args['learning_rate']), weight_decay=float(args['WD'])) 40 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args['D_rate']), gamma=0.3) 41 | if args['task'] == 'viah': 42 | trainset = viah_segmentation(ann='training', args=args) 43 | testset = viah_segmentation(ann='test', args=args) 44 | elif args['task'] == 'bing': 45 | trainset = bing_segmentation(ann='training', args=args) 46 | testset = bing_segmentation(ann='test', args=args) 47 | 48 | ds = torch.utils.data.DataLoader(trainset, batch_size=int(args['Batch_size']), shuffle=True, 49 | num_workers=int(args['nW']), drop_last=True) 50 | ds_tri = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, 51 | num_workers=1, drop_last=False) 52 | ds_val = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, 53 | num_workers=1, drop_last=False) 54 | best = 0 55 | max_iter = int(args['DeepIt']) 56 | for epoch in range(1, int(args['epochs'])): 57 | args['DeepIt'] = int(epoch/20) + 1 58 | if args['DeepIt'] > max_iter: 59 | args['DeepIt'] = max_iter 60 | loss_list = train(ds, model, segnet, optimizer, snake_loss, PTrain, faces, args) 61 | logger_train(epoch, writer, loss_list) 62 | scheduler.step() 63 | if epoch % 5 == 2: 64 | iou = eval_ds(ds_val, model, segnet, P_test, faces_test, args) 65 | writer.add_scalar('IoU', iou, global_step=epoch) 66 | if iou > best: 67 | best = iou 68 | print('best: ' + str(epoch) + ' with ' + str(best)) 69 | torch.save(model, PATH + '/' + 'ACM_best.pt') 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 
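The `args['DeepIt'] = int(epoch/20) + 1` line in train.py above is a small curriculum on the contour evolution: training starts with a single snake iteration and unlocks one more every 20 epochs, up to the `-it` cap. A minimal sketch of the resulting schedule, assuming the default cap of 3 from utils_args.py:
```python
# DeepIt curriculum from train.py: epoch -> number of snake iterations
max_it = 3
for epoch in [1, 19, 20, 39, 40, 60, 100]:
    print(epoch, min(epoch // 20 + 1, max_it))
# prints 1->1, 19->1, 20->2, 39->2, 40->3, 60->3, 100->3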
-------------------------------------------------------------------------------- /train_seg.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from tensorboardX import SummaryWriter 3 | from tqdm import tqdm 4 | 5 | from models.model_seg import * 6 | from loader.viah_loader import * 7 | from loader.bing_loader import * 8 | from utils.utils_args import * 9 | from utils.loss import * 10 | from utils.utils_eval import get_dice_ji 11 | from utils.utils_lr import * 12 | 13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 14 | torch.backends.cudnn.benchmark = True 15 | 16 | 17 | def train(ds, model, optimizer, criterion, criterion2, scheduler, args): 18 | loss_list = [] 19 | for ix, (_x, _y) in tqdm(enumerate(ds)): 20 | _x = _x.float().cuda() 21 | _y = _y.float().cuda().unsqueeze(dim=1) 22 | optimizer.zero_grad() 23 | mask = model(_x) 24 | loss = 0.1*criterion(mask, _y) + 1*criterion2(mask, _y) 25 | loss_list.append(loss.item()) 26 | loss.backward() 27 | optimizer.step() 28 | if args['opt'] == 'sgd': 29 | scheduler.step() 30 | return loss_list 31 | 32 | 33 | def eval_ds(ds, model, writer, epoch, PATH1, best, label, args): 34 | model.eval() 35 | TestDice_list = [] 36 | TestIoU_list = [] 37 | for ix, (_x, _y) in enumerate(ds): 38 | _x = _x.float().cuda() 39 | _y = _y.float().cuda() 40 | Mask = model(_x) 41 | Mask[Mask >= 0.5] = 1 42 | Mask[Mask < 0.5] = 0 43 | (cDice, cIoU) = get_dice_ji(Mask, _y) 44 | TestDice_list.append(cDice) 45 | TestIoU_list.append(cIoU) 46 | Dice = np.mean(TestDice_list) 47 | IoU = np.mean(TestIoU_list) 48 | print((epoch, Dice, IoU)) 49 | if IoU > best and label=='test': 50 | torch.save(model, PATH1 + '/SEG_best.pt') 51 | print('best IOU results: ' + str(IoU)) 52 | writer.add_scalar('Dice_' + label, Dice, global_step=epoch) 53 | writer.add_scalar('IoU_' + label, IoU, global_step=epoch) 54 | model.train() 55 | return best, IoU 56 | 57 | 58 | def main(args, writer): 59 | PATH = r'results/' + args['task'] 60 | model = Segmentation(args) 61 | model.train().to(device) 62 | 63 | criterion = nn.BCELoss() 64 | criterion2 = SoftDiceLoss() 65 | if args['opt'] == 'adam': 66 | optimizer = torch.optim.Adam(model.parameters(), lr=float(args['learning_rate']), 67 | weight_decay=float(args['WD'])) 68 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args['D_rate']), gamma=0.3) 69 | elif args['opt'] == 'sgd': 70 | wd_params, non_wd_params = [], [] 71 | for name, param in model.named_parameters(): 72 | if param.dim() == 1: 73 | non_wd_params.append(param) 74 | elif param.dim() == 2 or param.dim() == 4: 75 | wd_params.append(param) 76 | params_list = [ 77 | {'params': wd_params, }, 78 | {'params': non_wd_params, 'weight_decay': 0}, 79 | ] 80 | warmup_iters = 12 81 | optimizer = torch.optim.SGD(params_list, 82 | lr=float(args['learning_rate']), 83 | weight_decay=float(args['WD']), 84 | momentum=0.9) 85 | max_iter = int(args['D_rate']) 86 | scheduler = WarmupPolyLrScheduler(optimizer, 87 | power=0.9, 88 | max_iter=max_iter, 89 | warmup_iter=warmup_iters, 90 | warmup_ratio=0.001, 91 | warmup='exp', 92 | last_epoch=-1) 93 | 94 | if args['task'] == 'viah': 95 | trainset = viah_segmentation(ann='training', args=args) 96 | testset = viah_segmentation(ann='test', args=args) 97 | elif args['task'] == 'bing': 98 | trainset = bing_segmentation(ann='training', args=args) 99 | testset = bing_segmentation(ann='test', args=args) 100 | 101 | ds = torch.utils.data.DataLoader(trainset, 
batch_size=int(args['Batch_size']), shuffle=True, 102 | num_workers=int(args['nW']), drop_last=True) 103 | ds_val = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, 104 | num_workers=0, drop_last=False) 105 | best = 0 106 | for epoch in range(1, int(args['epochs'])): 107 | loss_list = train(ds, model, optimizer, criterion, criterion2, scheduler, args) 108 | print('************************************************************************') 109 | print('Epoch: ' + str(epoch) + ' Mask mean loss: ' + str(np.mean(loss_list)) + ' Mask max loss: ' + str( 110 | np.max(loss_list)) + ' Mask min loss: ' + str(np.min(loss_list))) 111 | writer.add_scalar('MaskLoss', np.mean(loss_list), global_step=epoch) 112 | print('************************************************************************') 113 | if args['opt'] == 'adam': 114 | scheduler.step() 115 | 116 | if epoch % 3 == 1: 117 | best, _ = eval_ds(ds_val, model, writer, epoch, PATH, best, 'test', args) 118 | 119 | if __name__ == '__main__': 120 | args = get_args() 121 | save_args(args) 122 | writer = SummaryWriter() 123 | main(args, writer) 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/utils/__init__.py -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | get_tp_fp_fn, SoftDiceLoss, and DC_and_CE/TopK_loss are from https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/loss_functions 3 | """ 4 | 5 | import torch 6 | # from ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss 7 | from torch import nn 8 | from torch.autograd import Variable 9 | from torch import einsum 10 | import numpy as np 11 | 12 | 13 | def softmax_helper(x): 14 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py 15 | rpt = [1 for _ in range(len(x.size()))] 16 | rpt[1] = x.size(1) 17 | x_max = x.max(1, keepdim=True)[0].repeat(*rpt) 18 | e_x = torch.exp(x - x_max) 19 | return e_x / e_x.sum(1, keepdim=True).repeat(*rpt) 20 | 21 | 22 | def sum_tensor(inp, axes, keepdim=False): 23 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/tensor_utilities.py 24 | axes = np.unique(axes).astype(int) 25 | if keepdim: 26 | for ax in axes: 27 | inp = inp.sum(int(ax), keepdim=True) 28 | else: 29 | for ax in sorted(axes, reverse=True): 30 | inp = inp.sum(int(ax)) 31 | return inp 32 | 33 | 34 | def ignore_null(gt, shp_x, nl_cls=250): 35 | shp_x = list(shp_x) 36 | nC = shp_x[1] 37 | gt[gt == nl_cls] = nC + 1 38 | shp_x[1] = nC + 1 39 | return gt, shp_x 40 | 41 | 42 | def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False): 43 | """ 44 | net_output must be (b, c, x, y(, z))) 45 | gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z)) 46 | if mask is provided it must have shape (b, 1, x, y(, z))) 47 | :param net_output: 48 | :param gt: 49 | :param axes: 50 | :param mask: mask must be 1 for valid pixels and 0 for invalid pixels 51 | :param square: if True then fp, tp and fn will be squared before summation 52 | :return: 53 | """ 54 | if axes is None: 55 | axes = tuple(range(2, len(net_output.size()))) 
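    # with the default axes (the spatial dims), the tp/fp/fn sums below are
    # reduced per (batch, class) pair, which is what the Dice / IoU / Tversky
    # losses later in this file expect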
56 | 57 | shp_x = net_output.shape 58 | shp_y = gt.shape 59 | 60 | with torch.no_grad(): 61 | if len(shp_x) != len(shp_y): 62 | gt = gt.view((shp_y[0], 1, *shp_y[1:])) 63 | 64 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 65 | # if this is the case then gt is probably already a one hot encoding 66 | y_onehot = gt 67 | else: 68 | gt = gt.long() 69 | gt, shp_x = ignore_null(gt, shp_x) 70 | y_onehot = torch.zeros(shp_x) 71 | if net_output.device.type == "cuda": 72 | y_onehot = y_onehot.cuda(net_output.device.index) 73 | y_onehot.scatter_(1, gt, 1) 74 | y_onehot = y_onehot[:, :-1, :, :] 75 | # input(y_onehot.shape) 76 | tp = net_output * y_onehot 77 | fp = net_output * (1 - y_onehot) 78 | fn = (1 - net_output) * y_onehot 79 | 80 | if mask is not None: 81 | tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) 82 | fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) 83 | fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) 84 | 85 | if square: 86 | tp = tp ** 2 87 | fp = fp ** 2 88 | fn = fn ** 2 89 | 90 | tp = sum_tensor(tp, axes, keepdim=False) 91 | fp = sum_tensor(fp, axes, keepdim=False) 92 | fn = sum_tensor(fn, axes, keepdim=False) 93 | 94 | return tp, fp, fn 95 | 96 | 97 | class GDiceLoss(nn.Module): 98 | def __init__(self, apply_nonlin=None, smooth=1e-5): 99 | """ 100 | Generalized Dice; 101 | Copy from: https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L29 102 | paper: https://arxiv.org/pdf/1707.03237.pdf 103 | tf code: https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py#L279 104 | """ 105 | super(GDiceLoss, self).__init__() 106 | 107 | self.apply_nonlin = apply_nonlin 108 | self.smooth = smooth 109 | 110 | def forward(self, net_output, gt): 111 | shp_x = net_output.shape # (batch size,class_num,x,y,z) 112 | shp_y = gt.shape # (batch size,1,x,y,z) 113 | # one hot code for gt 114 | with torch.no_grad(): 115 | if len(shp_x) != len(shp_y): 116 | gt = gt.view((shp_y[0], 1, *shp_y[1:])) 117 | 118 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 119 | # if this is the case then gt is probably already a one hot encoding 120 | y_onehot = gt 121 | else: 122 | gt = gt.long() 123 | y_onehot = torch.zeros(shp_x) 124 | if net_output.device.type == "cuda": 125 | y_onehot = y_onehot.cuda(net_output.device.index) 126 | y_onehot.scatter_(1, gt, 1) 127 | 128 | if self.apply_nonlin is not None: 129 | softmax_output = self.apply_nonlin(net_output) 130 | 131 | # copy from https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L29 132 | w: torch.Tensor = 1 / (einsum("bcxyz->bc", y_onehot).type(torch.float32) + 1e-10) ** 2 133 | intersection: torch.Tensor = w * einsum("bcxyz, bcxyz->bc", softmax_output, y_onehot) 134 | union: torch.Tensor = w * (einsum("bcxyz->bc", softmax_output) + einsum("bcxyz->bc", y_onehot)) 135 | divided: torch.Tensor = 1 - 2 * (einsum("bc->b", intersection) + self.smooth) / ( 136 | einsum("bc->b", union) + self.smooth) 137 | gdc = divided.mean() 138 | 139 | return gdc 140 | 141 | 142 | def flatten(tensor): 143 | """Flattens a given tensor such that the channel axis is first. 
144 | The shapes are transformed as follows: 145 | (N, C, D, H, W) -> (C, N * D * H * W) 146 | """ 147 | C = tensor.size(1) 148 | # new axis order 149 | axis_order = (1, 0) + tuple(range(2, tensor.dim())) 150 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) 151 | transposed = tensor.permute(axis_order).contiguous() 152 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) 153 | return transposed.view(C, -1) 154 | 155 | 156 | class GDiceLossV2(nn.Module): 157 | def __init__(self, apply_nonlin=None, smooth=1e-5): 158 | """ 159 | Generalized Dice; 160 | Copy from: https://github.com/wolny/pytorch-3dunet/blob/6e5a24b6438f8c631289c10638a17dea14d42051/unet3d/losses.py#L75 161 | paper: https://arxiv.org/pdf/1707.03237.pdf 162 | tf code: https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py#L279 163 | """ 164 | super(GDiceLossV2, self).__init__() 165 | 166 | self.apply_nonlin = apply_nonlin 167 | self.smooth = smooth 168 | 169 | def forward(self, net_output, gt): 170 | shp_x = net_output.shape # (batch size,class_num,x,y,z) 171 | shp_y = gt.shape # (batch size,1,x,y,z) 172 | # one hot code for gt 173 | with torch.no_grad(): 174 | if len(shp_x) != len(shp_y): 175 | gt = gt.view((shp_y[0], 1, *shp_y[1:])) 176 | 177 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 178 | # if this is the case then gt is probably already a one hot encoding 179 | y_onehot = gt 180 | else: 181 | gt = gt.long() 182 | y_onehot = torch.zeros(shp_x) 183 | if net_output.device.type == "cuda": 184 | y_onehot = y_onehot.cuda(net_output.device.index) 185 | y_onehot.scatter_(1, gt, 1) 186 | 187 | if self.apply_nonlin is not None: 188 | softmax_output = self.apply_nonlin(net_output) 189 | 190 | input = flatten(softmax_output) 191 | target = flatten(y_onehot) 192 | target = target.float() 193 | target_sum = target.sum(-1) 194 | class_weights = Variable(1. / (target_sum * target_sum).clamp(min=self.smooth), requires_grad=False) 195 | 196 | intersect = (input * target).sum(-1) * class_weights 197 | intersect = intersect.sum() 198 | 199 | denominator = ((input + target).sum(-1) * class_weights).sum() 200 | 201 | return 1. - 2. 
* intersect / denominator.clamp(min=self.smooth) 202 | 203 | 204 | class SSLoss(nn.Module): 205 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1., 206 | square=False): 207 | """ 208 | Sensitivity-Specifity loss 209 | paper: http://www.rogertam.ca/Brosch_MICCAI_2015.pdf 210 | tf code: https://github.com/NifTK/NiftyNet/blob/df0f86733357fdc92bbc191c8fec0dcf49aa5499/niftynet/layer/loss_segmentation.py#L392 211 | """ 212 | super(SSLoss, self).__init__() 213 | 214 | self.square = square 215 | self.do_bg = do_bg 216 | self.batch_dice = batch_dice 217 | self.apply_nonlin = apply_nonlin 218 | self.smooth = smooth 219 | self.r = 0.1 # weight parameter in SS paper 220 | 221 | def forward(self, net_output, gt, loss_mask=None): 222 | shp_x = net_output.shape 223 | shp_y = gt.shape 224 | # class_num = shp_x[1] 225 | 226 | with torch.no_grad(): 227 | if len(shp_x) != len(shp_y): 228 | gt = gt.view((shp_y[0], 1, *shp_y[1:])) 229 | 230 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 231 | # if this is the case then gt is probably already a one hot encoding 232 | y_onehot = gt 233 | else: 234 | gt = gt.long() 235 | y_onehot = torch.zeros(shp_x) 236 | if net_output.device.type == "cuda": 237 | y_onehot = y_onehot.cuda(net_output.device.index) 238 | y_onehot.scatter_(1, gt, 1) 239 | 240 | if self.batch_dice: 241 | axes = [0] + list(range(2, len(shp_x))) 242 | else: 243 | axes = list(range(2, len(shp_x))) 244 | 245 | if self.apply_nonlin is not None: 246 | softmax_output = self.apply_nonlin(net_output) 247 | 248 | # no object value 249 | bg_onehot = 1 - y_onehot 250 | squared_error = (y_onehot - softmax_output) ** 2 251 | specificity_part = sum_tensor(squared_error * y_onehot, axes) / (sum_tensor(y_onehot, axes) + self.smooth) 252 | sensitivity_part = sum_tensor(squared_error * bg_onehot, axes) / (sum_tensor(bg_onehot, axes) + self.smooth) 253 | 254 | ss = self.r * specificity_part + (1 - self.r) * sensitivity_part 255 | 256 | if not self.do_bg: 257 | if self.batch_dice: 258 | ss = ss[1:] 259 | else: 260 | ss = ss[:, 1:] 261 | ss = ss.mean() 262 | 263 | return ss 264 | 265 | 266 | class SoftDiceLoss(nn.Module): 267 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1e-6, 268 | square=False): 269 | """ 270 | paper: https://arxiv.org/pdf/1606.04797.pdf 271 | """ 272 | super(SoftDiceLoss, self).__init__() 273 | 274 | self.square = square 275 | self.do_bg = do_bg 276 | self.batch_dice = batch_dice 277 | self.apply_nonlin = apply_nonlin 278 | self.smooth = smooth 279 | 280 | def forward(self, x, y, loss_mask=None): 281 | shp_x = x.shape 282 | 283 | if self.batch_dice: 284 | axes = [0] + list(range(2, len(shp_x))) 285 | else: 286 | axes = list(range(2, len(shp_x))) 287 | 288 | if self.apply_nonlin is not None: 289 | x = self.apply_nonlin(x) 290 | 291 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square) 292 | 293 | dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth) 294 | 295 | if not self.do_bg: 296 | if self.batch_dice: 297 | dc = dc[1:] 298 | else: 299 | dc = dc[:, 1:] 300 | dc = dc.mean() 301 | 302 | return 1 - dc 303 | 304 | 305 | class IoULoss(nn.Module): 306 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1., 307 | square=False): 308 | """ 309 | paper: https://link.springer.com/chapter/10.1007/978-3-319-50835-1_22 310 | 311 | """ 312 | super(IoULoss, self).__init__() 313 | 314 | self.square = square 315 | self.do_bg = do_bg 316 | self.batch_dice = batch_dice 317 | self.apply_nonlin = 
apply_nonlin 318 | self.smooth = smooth 319 | 320 | def forward(self, x, y, loss_mask=None): 321 | shp_x = x.shape 322 | 323 | if self.batch_dice: 324 | axes = [0] + list(range(2, len(shp_x))) 325 | else: 326 | axes = list(range(2, len(shp_x))) 327 | 328 | if self.apply_nonlin is not None: 329 | x = self.apply_nonlin(x) 330 | 331 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square) 332 | 333 | iou = (tp + self.smooth) / (tp + fp + fn + self.smooth) 334 | 335 | if not self.do_bg: 336 | if self.batch_dice: 337 | iou = iou[1:] 338 | else: 339 | iou = iou[:, 1:] 340 | iou = iou.mean() 341 | 342 | return -iou 343 | 344 | 345 | class TverskyLoss(nn.Module): 346 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1., 347 | square=False): 348 | """ 349 | paper: https://arxiv.org/pdf/1706.05721.pdf 350 | """ 351 | super(TverskyLoss, self).__init__() 352 | 353 | self.square = square 354 | self.do_bg = do_bg 355 | self.batch_dice = batch_dice 356 | self.apply_nonlin = apply_nonlin 357 | self.smooth = smooth 358 | self.alpha = 0.3 359 | self.beta = 0.7 360 | 361 | def forward(self, x, y, loss_mask=None): 362 | shp_x = x.shape 363 | 364 | if self.batch_dice: 365 | axes = [0] + list(range(2, len(shp_x))) 366 | else: 367 | axes = list(range(2, len(shp_x))) 368 | 369 | if self.apply_nonlin is not None: 370 | x = self.apply_nonlin(x) 371 | 372 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square) 373 | 374 | tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth) 375 | 376 | if not self.do_bg: 377 | if self.batch_dice: 378 | tversky = tversky[1:] 379 | else: 380 | tversky = tversky[:, 1:] 381 | tversky = tversky.mean() 382 | 383 | return -tversky 384 | 385 | 386 | class FocalTversky_loss(nn.Module): 387 | """ 388 | paper: https://arxiv.org/pdf/1810.07842.pdf 389 | author code: https://github.com/nabsabraham/focal-tversky-unet/blob/347d39117c24540400dfe80d106d2fb06d2b99e1/losses.py#L65 390 | """ 391 | 392 | def __init__(self, tversky_kwargs, gamma=0.75): 393 | super(FocalTversky_loss, self).__init__() 394 | self.gamma = gamma 395 | self.tversky = TverskyLoss(**tversky_kwargs) 396 | 397 | def forward(self, net_output, target): 398 | tversky_loss = 1 + self.tversky(net_output, target) # = 1-tversky(net_output, target) 399 | focal_tversky = torch.pow(tversky_loss, self.gamma) 400 | return focal_tversky 401 | 402 | 403 | class AsymLoss(nn.Module): 404 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1., 405 | square=False): 406 | """ 407 | paper: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8573779 408 | """ 409 | super(AsymLoss, self).__init__() 410 | 411 | self.square = square 412 | self.do_bg = do_bg 413 | self.batch_dice = batch_dice 414 | self.apply_nonlin = apply_nonlin 415 | self.smooth = smooth 416 | self.beta = 1.5 417 | 418 | def forward(self, x, y, loss_mask=None): 419 | shp_x = x.shape 420 | 421 | if self.batch_dice: 422 | axes = [0] + list(range(2, len(shp_x))) 423 | else: 424 | axes = list(range(2, len(shp_x))) 425 | 426 | if self.apply_nonlin is not None: 427 | x = self.apply_nonlin(x) 428 | 429 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square) # shape: (batch size, class num) 430 | weight = (self.beta ** 2) / (1 + self.beta ** 2) 431 | asym = (tp + self.smooth) / (tp + weight * fn + (1 - weight) * fp + self.smooth) 432 | 433 | if not self.do_bg: 434 | if self.batch_dice: 435 | asym = asym[1:] 436 | else: 437 | asym = asym[:, 1:] 438 | asym = asym.mean() 439 | 440 | return 
-asym 441 | 442 | 443 | class DC_and_CE_loss(nn.Module): 444 | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"): 445 | super(DC_and_CE_loss, self).__init__() 446 | self.aggregate = aggregate 447 | self.ce = CrossentropyND(**ce_kwargs) 448 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs) 449 | 450 | def forward(self, net_output, target): 451 | dc_loss = self.dc(net_output, target) 452 | ce_loss = self.ce(net_output, target) 453 | if self.aggregate == "sum": 454 | result = ce_loss + dc_loss 455 | else: 456 | raise NotImplementedError("nah son") # reserved for other stuff (later) 457 | return result 458 | 459 | 460 | class PenaltyGDiceLoss(nn.Module): 461 | """ 462 | paper: https://openreview.net/forum?id=H1lTh8unKN 463 | """ 464 | 465 | def __init__(self, gdice_kwargs): 466 | super(PenaltyGDiceLoss, self).__init__() 467 | self.k = 2.5 468 | self.gdc = GDiceLoss(apply_nonlin=softmax_helper, **gdice_kwargs) 469 | 470 | def forward(self, net_output, target): 471 | gdc_loss = self.gdc(net_output, target) 472 | penalty_gdc = gdc_loss / (1 + self.k * (1 - gdc_loss)) 473 | 474 | return penalty_gdc 475 | 476 | 477 | class DC_and_topk_loss(nn.Module): 478 | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"): 479 | super(DC_and_topk_loss, self).__init__() 480 | self.aggregate = aggregate 481 | self.ce = TopKLoss(**ce_kwargs) 482 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs) 483 | 484 | def forward(self, net_output, target): 485 | dc_loss = self.dc(net_output, target) 486 | ce_loss = self.ce(net_output, target) 487 | if self.aggregate == "sum": 488 | result = ce_loss + dc_loss 489 | else: 490 | raise NotImplementedError("nah son") # reserved for other stuff (later?) 491 | return result 492 | 493 | 494 | class ExpLog_loss(nn.Module): 495 | """ 496 | paper: 3D Segmentation with Exponential Logarithmic Loss for Highly Unbalanced Object Sizes 497 | https://arxiv.org/pdf/1809.00076.pdf 498 | """ 499 | 500 | def __init__(self, soft_dice_kwargs, wce_kwargs, gamma=0.3): 501 | super(ExpLog_loss, self).__init__() 502 | self.wce = WeightedCrossEntropyLoss(**wce_kwargs) 503 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs) 504 | self.gamma = gamma 505 | 506 | def forward(self, net_output, target): 507 | dc_loss = -self.dc(net_output, target) # weight=0.8 508 | wce_loss = self.wce(net_output, target) # weight=0.2 509 | # with torch.no_grad(): 510 | # print('dc loss:', dc_loss.cpu().numpy(), 'ce loss:', ce_loss.cpu().numpy()) 511 | # a = torch.pow(-torch.log(torch.clamp(dc_loss, 1e-6)), self.gamma) 512 | # b = torch.pow(-torch.log(torch.clamp(ce_loss, 1e-6)), self.gamma) 513 | # print('ExpLog dc loss:', a.cpu().numpy(), 'ExpLogce loss:', b.cpu().numpy()) 514 | # print('*'*20) 515 | explog_loss = 0.8 * torch.pow(-torch.log(torch.clamp(dc_loss, 1e-6)), self.gamma) + \ 516 | 0.2 * wce_loss 517 | 518 | return explog_loss -------------------------------------------------------------------------------- /utils/snake_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Snakeloss: 5 | def __init__(self, criterion): 6 | self.criterion = criterion 7 | 8 | def ballon_loss(self, mask): 9 | return 1 - torch.mean(mask, dim=[2, 3]).cuda().mean() 10 | 11 | def avg_dis(self, P): 12 | nP = P.shape[2] 13 | even1 = P[:, :, 0:nP:2, :] 14 | odd1 = P[:, :, 1:nP:2, :] 15 | diff1 = torch.sum((even1-odd1)**2, dim=[2, 3]) 16 | even2 = P[:, :, 2:nP:2, :] 17 | odd2 = 
P[:, :, 1:nP-1:2, :]
18 |         diff2 = torch.sum((even2-odd2)**2, dim=[2, 3])
19 |         diff3 = torch.sum((P[:, :, 0, :] - P[:, :, nP-1, :])**2, dim=2)  # (B, 1), matching diff1/diff2
20 |         return (1/3)*diff1**0.5 + (1/3)*diff2**0.5 + (1/3)*diff3**0.5
21 | 
22 |     def curvature_loss(self, P):
23 |         Pf = P.roll(-1, dims=2)
24 |         Pb = P.roll(1, dims=2)
25 |         K = Pf + Pb - 2 * P
26 |         return K.abs().mean()
27 | 
28 |     def snake_loss(self, num_of_it, net_out, gt):
29 |         for it in range(num_of_it):
30 |             if it == 0:
31 |                 loss = self.criterion(net_out[0][it], gt) + \
32 |                        0.1*self.avg_dis(net_out[1][it]).mean()
33 |             else:
34 |                 loss = loss + \
35 |                        self.criterion(net_out[0][it], gt) + \
36 |                        0.1*self.avg_dis(net_out[1][it]).mean()
37 |         return loss
38 | 
--------------------------------------------------------------------------------
/utils/utils_TB.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/utils/utils_TB.py
--------------------------------------------------------------------------------
/utils/utils_args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | def get_args():
4 |     parser = argparse.ArgumentParser(description='DeepACM training and evaluation arguments')
5 |     parser.add_argument('-it', '--DeepIt', default=3, help='number of snake iterations', required=False)
6 |     parser.add_argument('-lr', '--learning_rate', default=0.001, help='learning rate', required=False)
7 |     parser.add_argument('-bs', '--Batch_size', default=20, help='batch size', required=False)
8 |     parser.add_argument('-ep', '--epochs', default=25000, help='number of epochs', required=False)
9 |     parser.add_argument('-drop', '--drop', default=0, help='dropout value', required=False)
10 |     parser.add_argument('-R', '--Radius', default=64, help='initial guess circle radius', required=False)
11 |     parser.add_argument('-D_rate', '--D_rate', default=50, help='LR-schedule drop parameter', required=False)
12 |     parser.add_argument('-opt', '--opt', default='adam', help='optimizer type', required=False)
13 |     parser.add_argument('-a', '--a', default=0.5, help='mask coeff with image', required=False)
14 |     parser.add_argument('-nW', '--nW', default=0, help='number of workers', required=False)
15 |     parser.add_argument('-WD', '--WD', default=0.00005, help='weight decay', required=False)
16 |     parser.add_argument('-nP', '--nP', default=32, help='number of contour points', required=False)
17 |     parser.add_argument('-order', '--order', default=85, help='backbone dimension', required=False)
18 |     parser.add_argument('-depth_wise', '--depth_wise', default=0, help='depth_wise hardnet', required=False)
19 |     parser.add_argument('-outlayer', '--outlayer', default=3, help='number of hardnet blocks', required=False)
20 |     parser.add_argument('-im_size', '--im_size', default=256, help='image size', required=False)
21 |     parser.add_argument('-task', '--task', default='bing', help='which dataset to use?', required=False)
22 |     args = vars(parser.parse_args())
23 |     return args
24 | 
25 | 
26 | def save_args(args):
27 |     path = r'results/' + args['task'] + '/params.csv'
28 |     f = open(path, 'w')
29 |     keys = list(args.keys())
30 |     vals = list(args.values())
31 |     for i in range(len(keys)):
32 |         f.write(str(keys[i])+','+str(vals[i])+'\n')
33 |     f.flush()
34 |     f.close()
--------------------------------------------------------------------------------
/utils/utils_eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.morphology import binary_dilation, disk
3 | import os
4 | import torch
5 | from utils.utils_vis import send_image_to_TB
6 | from utils.utils_train import norm_input
7 | import cv2
8 | 
9 | def image_norm(img):
10 |     return (img - img.min()) / (img.max() - img.min())
11 | 
12 | 
13 | def vis_ds(ds, model, segnet, PTrain, faces, args, num_of_ex=5):
14 |     model.eval()
15 |     for ix, (_x, _y) in enumerate(ds):
16 |         if ix > num_of_ex: break
17 |         _x = _x.float().cuda()
18 |         img = image_norm(_x.squeeze(dim=0).detach().cpu().numpy().transpose(1, 2, 0))
19 |         _p = PTrain.float().cuda().clone()
20 |         _y = _y.float().cuda()
21 |         seg_out = segnet(_x)
22 |         _x = norm_input(_x, seg_out, float(args['a']))
23 |         n_it = int(args['DeepIt'])
24 |         net_out = model(_x, _p, faces, n_it)
25 |         Mask = net_out[0][n_it - 1]
26 |         (_, cIoU) = get_dice_ji(Mask, _y)
27 |         P = net_out[1][n_it - 1]
28 |         P_init = PTrain.squeeze().detach().cpu().numpy().transpose(1, 0)
29 |         Mask = Mask.squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy()
30 |         P = P.squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy().transpose(1, 0)
31 |         P = np.concatenate((P, P[:, 0:1]), 1)
32 |         Ix = net_out[2].squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy()
33 |         Iy = net_out[3].squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy()
34 |         GT = _y.squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy()
35 |         im = cv2.cvtColor(send_image_to_TB(img, P_init, Mask, P, Ix, Iy, GT, cIoU), cv2.COLOR_RGBA2BGR)
36 |         cv2.imwrite('out/' + str(ix) + '.jpg', im)
37 |     model.train()
38 | 
39 | 
40 | def eval_ds(ds, model, segnet, PTrain, faces, args):
41 |     model.eval()
42 |     TestIoU_list = []
43 | 
44 |     with torch.no_grad():
45 |         for ix, (_x, _y) in enumerate(ds):
46 |             _x = _x.float().cuda()
47 |             _p = PTrain.float().cuda().clone()
48 |             _y = _y.float().cuda()
49 |             seg_out = segnet(_x).detach()
50 |             _x = norm_input(_x, seg_out, float(args['a']))
51 |             n_it = int(args['DeepIt'])
52 |             net_out = model(_x, _p, faces, n_it)
53 |             Mask = net_out[0][n_it-1]
54 |             _, cIoU = get_dice_ji(Mask, _y)
55 |             TestIoU_list.append(cIoU)
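    # mean IoU over the whole split; train.py keeps the ACM checkpoint that
    # maximizes this value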
56 |     IoU = np.mean(TestIoU_list)
57 |     model.train()
58 |     return IoU
59 | 
60 | 
61 | def get_dice_ji(predict, target):
62 |     predict = predict.data.cpu().numpy() + 1
63 |     target = target.data.cpu().numpy() + 1
64 |     tp = np.sum(((predict == 2) * (target == 2)) * (target > 0))
65 |     fp = np.sum(((predict == 2) * (target == 1)) * (target > 0))
66 |     fn = np.sum(((predict == 1) * (target == 2)) * (target > 0))
67 |     ji = float(np.nan_to_num(tp / (tp + fp + fn)))
68 |     dice = float(np.nan_to_num(2 * tp / (2 * tp + fp + fn)))
69 |     return dice, ji
70 | 
71 | 
72 | def update_net_list(PATH):
73 |     a = os.listdir(PATH)
74 |     PATH1_list = []
75 |     for i in a:
76 |         if i[0:5] == 'model': PATH1_list.append(i)
77 |     PATH1_list.sort()
78 |     return PATH1_list
79 | 
80 | 
81 | def dice_metric(X, Y):
82 |     return np.sum(X[Y==1])*2.0 / (np.sum(X) + np.sum(Y) + 1e-6)
83 | 
84 | 
85 | def IoU_metric(y_pred, y_true):
86 |     intersection = np.sum(y_true * y_pred, axis=None)
87 |     union = np.sum(y_true, axis=None) + np.sum(y_pred, axis=None) - intersection
88 |     if float(union) == 0: return 0.0
89 |     else: return float(intersection) / float(union)
90 | 
91 | 
92 | def WCov_metric(X, Y):
93 |     A1 = float(np.count_nonzero(X))
94 |     A2 = float(np.count_nonzero(Y))
95 |     if A1 == 0 or A2 == 0: return 0.0
96 |     return min(A1, A2) / max(A1, A2)
97 | 
98 | 
99 | def FBound_metric(X, Y):
100 |     tmp1 = db_eval_boundary(X,Y,1)[0]
101 |     tmp2 = db_eval_boundary(X,Y,2)[0]
102 |     tmp3 = db_eval_boundary(X,Y,3)[0]
103 |     tmp4 = db_eval_boundary(X,Y,4)[0]
104 |     tmp5 = db_eval_boundary(X,Y,5)[0]
105 |     return (tmp1+tmp2+tmp3+tmp4+tmp5)/5.0
106 | 
107 | def db_eval_boundary(foreground_mask, gt_mask, bound_th):
108 |     """
109 |     Compute the boundary F-measure between a predicted and an annotated mask.
110 |     Calculates precision/recall for boundaries between foreground_mask and
111 |     gt_mask using morphological operators to speed it up.
112 |     Arguments:
113 |         foreground_mask (ndarray): binary segmentation image.
114 |         gt_mask (ndarray): binary annotated image.
115 |     Returns:
116 |         F (float): boundaries F-measure
117 |         P (float): boundaries precision
118 |         R (float): boundaries recall (the raw match/boundary counts are appended)
119 |     """
120 |     assert np.atleast_3d(foreground_mask).shape[2] == 1
121 | 
122 |     bound_pix = bound_th if bound_th >= 1 else \
123 |         np.ceil(bound_th*np.linalg.norm(foreground_mask.shape))
124 | 
125 |     # Get the pixel boundaries of both masks
126 |     fg_boundary = seg2bmap(foreground_mask)
127 |     gt_boundary = seg2bmap(gt_mask)
128 | 
129 |     fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
130 |     gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
131 | 
132 |     # Get the intersection
133 |     gt_match = gt_boundary * fg_dil
134 |     fg_match = fg_boundary * gt_dil
135 | 
136 |     # Area of the intersection
137 |     n_fg = np.sum(fg_boundary)
138 |     n_gt = np.sum(gt_boundary)
139 | 
140 |     # Compute precision and recall
141 |     if n_fg == 0 and n_gt > 0:
142 |         precision = 1
143 |         recall = 0
144 |     elif n_fg > 0 and n_gt == 0:
145 |         precision = 0
146 |         recall = 1
147 |     elif n_fg == 0 and n_gt == 0:
148 |         precision = 1
149 |         recall = 1
150 |     else:
151 |         precision = np.sum(fg_match)/float(n_fg)
152 |         recall = np.sum(gt_match)/float(n_gt)
153 | 
154 |     # Compute F measure
155 |     if precision + recall == 0:
156 |         F = 0
157 |     else:
158 |         F = 2*precision*recall/(precision+recall)
159 | 
160 |     return F, precision, recall, np.sum(fg_match), n_fg, np.sum(gt_match), n_gt
161 | 
162 | 
163 | def seg2bmap(seg, width=None, height=None):
164 |     """
165 |     From a segmentation, compute a binary boundary map with 1 pixel wide
166 |     boundaries. The boundary pixels are offset by 1/2 pixel towards the
167 |     origin from the actual segment boundary.
168 |     Arguments:
169 |         seg    : Segments labeled from 1..k.
170 |         width  : Width of desired bmap  <= seg.shape[1]
171 |         height : Height of desired bmap <= seg.shape[0]
172 |     Returns:
173 |         bmap (ndarray): Binary boundary map.
174 |     David Martin 
175 |     January 2003
176 |     """
177 |     seg = seg.astype(bool)
178 |     seg[seg > 0] = 1
179 | 
180 |     assert np.atleast_3d(seg).shape[2] == 1
181 | 
182 |     width = seg.shape[1] if width is None else width
183 |     height = seg.shape[0] if height is None else height
184 | 
185 |     h, w = seg.shape[:2]
186 | 
187 |     ar1 = float(width) / float(height)
188 |     ar2 = float(w) / float(h)
189 | 
190 |     assert not (width > w or height > h or abs(ar1 - ar2) > 0.01), \
191 |         "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
192 | 
193 |     e = np.zeros_like(seg)
194 |     s = np.zeros_like(seg)
195 |     se = np.zeros_like(seg)
196 | 
197 |     e[:, :-1] = seg[:, 1:]
198 |     s[:-1, :] = seg[1:, :]
199 |     se[:-1, :-1] = seg[1:, 1:]
200 | 
201 |     b = seg^e | seg^s | seg^se
202 |     b[-1, :] = seg[-1, :]^e[-1, :]
203 |     b[:, -1] = seg[:, -1]^s[:, -1]
204 |     b[-1, -1] = 0
205 | 
206 |     if w == width and h == height:
207 |         bmap = b
208 |     else:
209 |         bmap = np.zeros((height, width))
210 |         for x in range(w):
211 |             for y in range(h):
212 |                 if b[y, x]:
213 |                     # nearest target-grid pixel for source pixel (y, x), clamped to bounds
214 |                     j = min(1 + int((y - 1) * height / h), height - 1)
215 |                     i = min(1 + int((x - 1) * width / w), width - 1)
216 |                     bmap[j, i] = 1
217 | 
218 |     return bmap
--------------------------------------------------------------------------------
/utils/utils_lr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | 
4 | import math
5 | from bisect import bisect_right
6 | import torch
7 | import numpy as np
8 | 
9 | 
10 | class WarmupLrScheduler(torch.optim.lr_scheduler._LRScheduler):
11 | 
12 |     def __init__(
13 |             self,
14 |             optimizer,
15 |             warmup_iter=500,
16 |             warmup_ratio=5e-4,
17 |             warmup='exp',
18 |             last_epoch=-1,
19 |     ):
20 |         self.warmup_iter = warmup_iter
21 |         self.warmup_ratio = warmup_ratio
22 |         self.warmup = warmup
23 |         super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
24 | 
25 |     def get_lr(self):
26 |         ratio = self.get_lr_ratio()
27 |         lrs = [ratio * lr for lr in self.base_lrs]
28 |         return lrs
29 | 
30 |     def get_lr_ratio(self):
31 |         if self.last_epoch < self.warmup_iter:
32 |             ratio = self.get_warmup_ratio()
33 |         else:
34 |             ratio = self.get_main_ratio()
35 |         return ratio
36 | 
37 |     def get_main_ratio(self):
38 |         raise NotImplementedError
39 | 
40 |     def get_warmup_ratio(self):
41 |         assert self.warmup in ('linear', 'exp')
42 |         alpha = self.last_epoch / self.warmup_iter
43 |         if self.warmup == 'linear':
44 |             ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
45 |         elif self.warmup == 'exp':
46 |             ratio = self.warmup_ratio ** (1. 
- alpha) 47 | return ratio 48 | 49 | 50 | class WarmupPolyLrScheduler(WarmupLrScheduler): 51 | 52 | def __init__( 53 | self, 54 | optimizer, 55 | power, 56 | max_iter, 57 | warmup_iter=500, 58 | warmup_ratio=5e-4, 59 | warmup='exp', 60 | last_epoch=-1, 61 | ): 62 | self.power = power 63 | self.max_iter = max_iter 64 | super(WarmupPolyLrScheduler, self).__init__( 65 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 66 | 67 | def get_main_ratio(self): 68 | real_iter = self.last_epoch - self.warmup_iter 69 | real_max_iter = self.max_iter - self.warmup_iter 70 | alpha = real_iter / real_max_iter 71 | ratio = (1 - alpha) ** self.power 72 | return ratio 73 | 74 | 75 | class WarmupExpLrScheduler(WarmupLrScheduler): 76 | 77 | def __init__( 78 | self, 79 | optimizer, 80 | gamma, 81 | interval=1, 82 | warmup_iter=500, 83 | warmup_ratio=5e-4, 84 | warmup='exp', 85 | last_epoch=-1, 86 | ): 87 | self.gamma = gamma 88 | self.interval = interval 89 | super(WarmupExpLrScheduler, self).__init__( 90 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 91 | 92 | def get_main_ratio(self): 93 | real_iter = self.last_epoch - self.warmup_iter 94 | ratio = self.gamma ** (real_iter // self.interval) 95 | return ratio 96 | 97 | 98 | class WarmupCosineLrScheduler(WarmupLrScheduler): 99 | 100 | def __init__( 101 | self, 102 | optimizer, 103 | max_iter, 104 | eta_ratio=0, 105 | warmup_iter=500, 106 | warmup_ratio=5e-4, 107 | warmup='exp', 108 | last_epoch=-1, 109 | ): 110 | self.eta_ratio = eta_ratio 111 | self.max_iter = max_iter 112 | super(WarmupCosineLrScheduler, self).__init__( 113 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 114 | 115 | def get_main_ratio(self): 116 | real_iter = self.last_epoch - self.warmup_iter 117 | real_max_iter = self.max_iter - self.warmup_iter 118 | return self.eta_ratio + (1 - self.eta_ratio) * ( 119 | 1 + math.cos(math.pi * self.last_epoch / real_max_iter)) / 2 120 | 121 | 122 | class WarmupStepLrScheduler(WarmupLrScheduler): 123 | 124 | def __init__( 125 | self, 126 | optimizer, 127 | milestones: list, 128 | gamma=0.1, 129 | warmup_iter=500, 130 | warmup_ratio=5e-4, 131 | warmup='exp', 132 | last_epoch=-1, 133 | ): 134 | self.milestones = milestones 135 | self.gamma = gamma 136 | super(WarmupStepLrScheduler, self).__init__( 137 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 138 | 139 | def get_main_ratio(self): 140 | real_iter = self.last_epoch - self.warmup_iter 141 | ratio = self.gamma ** bisect_right(self.milestones, real_iter) 142 | return ratio 143 | 144 | 145 | if __name__ == "__main__": 146 | model = torch.nn.Conv2d(3, 16, 3, 1, 1) 147 | optim = torch.optim.SGD(model.parameters(), lr=5e-2) 148 | 149 | max_iter = 10000 150 | lr_scheduler = WarmupPolyLrScheduler(optim, 0.9, max_iter, 200, 0.001, 'exp', -1) 151 | lrs = [] 152 | for _ in range(max_iter+1000): 153 | lr = lr_scheduler.get_lr()[0] 154 | print(lr) 155 | lrs.append(lr) 156 | lr_scheduler.step() 157 | 158 | 159 | -------------------------------------------------------------------------------- /utils/utils_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | 6 | def norm_tensor(_x): 7 | bs = _x.shape[0] 8 | min_x = _x.view(bs, 3, -1).min(dim=2)[0].view(bs, 3, 1, 1) 9 | max_x = _x.view(bs, 3, -1).max(dim=2)[0].view(bs, 3, 1, 1) 10 | return (_x - min_x) / (max_x - min_x + 1e-2) 11 | 12 | 13 | def norm_input(_x, seg_out, a): 14 | _x = norm_tensor(_x) 15 | _x = a * _x + 
(1 - a) * seg_out 16 | return _x 17 | 18 | 19 | def train(ds, model, segnet, optimizer, snake_loss, PTrain, faces, args): 20 | loss_list = [] 21 | for ix, (_x, _y) in tqdm(enumerate(ds)): 22 | _x = _x.float().cuda() 23 | _p = PTrain.float().cuda().clone() 24 | _y = _y.float().cuda() 25 | optimizer.zero_grad() 26 | seg_out = segnet(_x) 27 | _x = norm_input(_x, seg_out, float(args['a'])).detach() 28 | num_of_it = int(args['DeepIt']) 29 | net_out = model(_x, _p, faces, num_of_it) 30 | loss = snake_loss.snake_loss(num_of_it, net_out, _y) 31 | loss_list.append(loss.item()/num_of_it) 32 | loss.backward() 33 | optimizer.step() 34 | return loss_list 35 | 36 | 37 | def logger_train(epoch, writer, loss_list): 38 | print('************************************************************************') 39 | print('Epoch: ' + str(epoch) + ' Mask mean loss: ' + str(np.mean(loss_list)) + ' Mask max loss: ' + str( 40 | np.max(loss_list)) + ' Mask min loss: ' + str(np.min(loss_list))) 41 | writer.add_scalar('MaskLoss', np.mean(loss_list), global_step=epoch) 42 | print('************************************************************************') 43 | 44 | 45 | -------------------------------------------------------------------------------- /utils/utils_tri.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | from scipy.spatial import Delaunay 5 | 6 | 7 | def get_faces(P): 8 | N = P.shape[2]*2 9 | faces = torch.zeros(P.shape[0], N, 3) 10 | for i in range(P.shape[0]): 11 | cP = P[i, :, :, :].squeeze(dim=0).squeeze(dim=0) 12 | tri = Delaunay(cP.detach().cpu().numpy()) 13 | tri = torch.tensor(tri.simplices.copy()) 14 | nP = tri.shape[0] 15 | last = tri[nP-1, :].unsqueeze(dim=0) 16 | for j in range(N-nP): 17 | tri = torch.cat((tri, last), dim=0) 18 | faces[i, :, :] = tri.unsqueeze(dim=0) 19 | return faces.type(torch.int32) 20 | 21 | 22 | def get_poly(dim=128, n=16, R=16, xx=64, yy=64): 23 | half_dim = dim / 2 24 | P = [np.array([xx + math.floor(math.cos(2 * math.pi / n * x) * R), 25 | yy + math.floor(math.sin(2 * math.pi / n * x) * R)]) for x in range(0, n)] 26 | train_data = torch.zeros(1, 1, n, 2) 27 | for i in range(n): 28 | train_data[0, 0, i, 0] = torch.tensor((P[i][0] - half_dim) / half_dim).clone() 29 | train_data[0, 0, i, 1] = torch.tensor((P[i][1] - half_dim) / half_dim).clone() 30 | vertices = torch.ones((n, 3)) 31 | tmp = train_data.squeeze(dim=0).squeeze(dim=0) 32 | vertices[:, 0] = tmp[:, 0] 33 | vertices[:, 1] = tmp[:, 1] * -1 34 | tri = Delaunay(vertices[:, 0:2].numpy()) 35 | faces = torch.tensor(tri.simplices.copy()) 36 | return train_data, faces 37 | -------------------------------------------------------------------------------- /utils/utils_vis.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | import matplotlib 9 | matplotlib.use('agg') 10 | import matplotlib.pyplot as plt 11 | 12 | import importlib 13 | import cv2 14 | import xlrd 15 | # import dar_package.config 16 | 17 | 18 | def fig2img(fig): 19 | """ 20 | Convert a Matplotlib figure to a PIL Image in RGBA format 21 | Copied from http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure 22 | """ 23 | # put the figure pixmap into a np array 24 | buf = fig2data(fig) 25 | w, h, d = buf.shape 26 | return Image.frombytes("RGBA", (w, h), buf.tostring()) 27 | 28 | 29 | def 
fig2data(fig):
30 |     """
31 |     Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
32 |     Copied from http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure
33 |     """
34 |     # draw the renderer
35 |     fig.canvas.draw()
36 | 
37 |     # Get the RGBA buffer from the figure
38 |     w, h = fig.canvas.get_width_height()
39 |     buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
40 |     buf.shape = (w, h, 4)  # safe here because send_image_to_TB builds a square figure (w == h)
41 | 
42 |     # canvas.tostring_argb gives the pixmap in ARGB mode; roll the alpha channel to get RGBA
43 |     buf = np.roll(buf, 3, axis=2)
44 |     return buf
45 | 
46 | 
47 | def send_image_to_TB(img, P_init, Mask, P, Ix, Iy, GT, cIoU):
48 |     dim = Mask.shape[0]
49 |     dim2 = img.shape[0]
50 |     fig, ax = plt.subplots(nrows=3, ncols=3, figsize=[10, 10])
51 |     fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
52 |     ax[0, 0].imshow(img)
53 |     ax[0, 1].plot(P_init[0, :]*0.5*dim2+0.5*dim2, P_init[1, :]*0.5*dim2+0.5*dim2, 'r--', linewidth=2.0)
54 |     ax[0, 1].plot(P[0, :]*0.5*dim2+0.5*dim2, P[1, :]*0.5*dim2+0.5*dim2, color=[0, 1, 0], linewidth=2.0, marker='*')
55 |     ax[0, 1].imshow(img)
56 |     ax[0, 2].imshow(img)
57 |     ax[1, 0].imshow(GT)
58 |     ax[1, 1].imshow(Mask)
59 |     ax[1, 2].plot(P_init[0, :]*0.5*dim+0.5*dim, P_init[1, :]*0.5*dim+0.5*dim, 'ro')
60 |     ax[1, 2].plot(P[0, :]*0.5*dim+0.5*dim, P[1, :]*0.5*dim+0.5*dim, color=[0, 1, 0], linewidth=2.0, marker='*')
61 |     ax[1, 2].imshow(Mask)
62 |     ax[2, 0].imshow(Ix)
63 |     ax[2, 1].imshow(Iy)
64 |     ax[2, 2].imshow((Ix**2+Iy**2)**0.5)
65 |     ax[0, 0].axis('off')
66 |     ax[0, 1].axis('off')
67 |     ax[0, 2].axis('off')
68 |     ax[1, 0].axis('off')
69 |     ax[1, 1].axis('off')
70 |     ax[1, 2].axis('off')
71 |     ax[2, 0].axis('off')
72 |     ax[2, 1].axis('off')
73 |     ax[2, 2].axis('off')
74 |     fig.suptitle('IoU: ' + str(cIoU))
75 |     fig.tight_layout(pad=0, w_pad=0, h_pad=0)
76 |     return np.asarray(fig2img(fig))
--------------------------------------------------------------------------------
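For orientation, a minimal sketch of how the initial contour fed to the model by train.py and eval.py is produced by `utils/utils_tri.py`. It assumes the repository root is on `PYTHONPATH` and that torch and scipy from requirements.txt are installed:
```python
from utils.utils_tri import get_poly

# 32-point circle of radius 64 px, centred in a 256 x 256 image, with
# coordinates normalised to [-1, 1] -- the defaults used by the scripts above.
P, faces = get_poly(dim=256, n=32, R=64, xx=128, yy=128)
print(P.shape)      # torch.Size([1, 1, 32, 2])
print(faces.shape)  # Delaunay triangulation of the polygon: (n - 2, 3) triangles
```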