├── .gitattributes
├── Data
│   └── __init__.py
├── README.md
├── eval.py
├── eval_seg.py
├── hd5_maker
│   ├── __init__.py
│   ├── bing.py
│   └── viah.py
├── loader
│   ├── __init__.py
│   ├── bing_loader.py
│   ├── transforms.py
│   └── viah_loader.py
├── models
│   ├── __init__.py
│   ├── hardnet.py
│   ├── model.py
│   ├── model_seg.py
│   └── vanilla.py
├── out
│   └── __init__.py
├── pics
│   ├── 16.jpg
│   └── 8.jpg
├── requirements.txt
├── results
│   └── __init__.py
├── train.py
├── train_seg.py
└── utils
    ├── __init__.py
    ├── loss.py
    ├── snake_loss.py
    ├── utils_TB.py
    ├── utils_args.py
    ├── utils_eval.py
    ├── utils_lr.py
    ├── utils_train.py
    ├── utils_tri.py
    └── utils_vis.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/Data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/Data/__init__.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Active Contour Model (ACM) for Building Segmentation
2 | 
3 | 
4 | 
5 | A deep snake algorithm for 2D images, based on the [ICLR 2020 paper (revisiting the results)](https://arxiv.org/abs/1912.00367).
6 | Architecture based on [HarDNet85](https://arxiv.org/abs/1909.00948).
7 | Data and weights: https://drive.google.com/drive/folders/1fBSjPse3d8geV_iI3-PXV3x2qmLoUnzL?usp=sharing
8 | 
9 | ### Get Started
10 | **To train a segmentation model:**
11 | ```
12 | python train_seg.py -bs 50 -WD 0.00005 -D_rate 3000 -task bing -opt sgd -lr 0.02 -nW 8
13 | ```
14 | **To train an ACM model:**
15 | ```
16 | python train.py -bs 25 -WD 0.00005 -D_rate 30 -it 2
17 | ```
18 | **To evaluate an ACM model:**
19 | ```
20 | python eval.py -task viah -nP 100 -it 3 -a 0.4
21 | ```
22 | **To evaluate a segmentation model:**
23 | ```
24 | python eval_seg.py -task bing -nP 24 -it 2 -a 0.4
25 | ```
26 | ### Results
27 | 
28 | | Method | Viah mIoU | Bing mIoU |
29 | | :---: | :---: | :---: |
30 | | DARNet | 88.24 | 75.29 |
31 | | DSAC | 71.10 | 38.74 |
32 | | **ours** | **90.33** | **75.53** |
33 | 
34 | 
35 | 
36 | 
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from tensorboardX import SummaryWriter
3 |
4 | from models.model import *
5 | from models.model_seg import Segmentation
6 | from loader.viah_loader import *
7 | from loader.bing_loader import *
8 | from utils.utils_args import *
9 | from utils.utils_eval import get_dice_ji, vis_ds
10 | from utils.utils_train import *
11 | from utils.utils_tri import *
12 | from utils.loss import *
13 | from utils.snake_loss import Snakeloss
14 | import random
15 |
16 |
17 | def eval_ds(ds, model, segnet, PTrain, faces, args):
18 | model.eval()
19 | IoU_list = []
20 | Dice_list = []
21 | 
22 | with torch.no_grad():
23 | for ix, (_x, _y) in enumerate(ds):
24 | _x = _x.float().cuda()
25 | _p = PTrain.float().cuda().clone()
26 | _y = _y.float().cuda()
27 | seg_out = segnet(_x).detach()
28 | _x = norm_input(_x, seg_out, float(args['a']))
29 |             n_it = int(args['DeepIt'])
30 |             net_out = model(_x, _p, faces, n_it)
31 |             Mask = net_out[0][n_it - 1]
32 | cDice, cIoU = get_dice_ji(Mask, _y)
33 | IoU_list.append(cIoU)
34 | Dice_list.append(cDice)
35 | IoU = np.mean(IoU_list)
36 | Dice = np.mean(Dice_list)
37 | model.train()
38 | return Dice, IoU
39 |
40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41 | torch.backends.cudnn.benchmark = True
42 | args = get_args()
43 |
44 | segnet = Segmentation(args)
45 | model = DeepACM(args)
46 |
47 |
48 | P_test, faces_test = get_poly(int(args['im_size']), int(args['nP']),
49 | int(args['Radius']), int(args['im_size']) / 2, int(args['im_size']) / 2)
50 | faces_test = faces_test.unsqueeze(dim=0).unsqueeze(dim=0).cuda()
51 | faces = faces_test.repeat(1, 1, 1, 1).cuda()
52 | PTrain = P_test.repeat(1, 1, 1, 1).cuda()
53 | if args['task'] == 'viah':
54 | PATH = r'results/viah/best/'
55 | testset = viah_segmentation(ann='test', args=args)
56 | elif args['task'] == 'bing':
57 | testset = bing_segmentation(ann='test', args=args)
58 | PATH = r'results/bing/best/'
59 | ds_val = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False,
60 | num_workers=1, drop_last=False)
61 |
62 | model1 = torch.load(PATH + 'ACM.pt')
63 | model.load_state_dict(model1.state_dict())
64 | model.eval().to(device)
65 | segnet1 = torch.load(PATH + 'SEG.pt')
66 | segnet.load_state_dict(segnet1.state_dict())
67 | segnet.eval().to(device)
68 | vis_ds(ds_val, model, segnet, PTrain, faces, args, num_of_ex=20)
69 | dice, iou = eval_ds(ds_val, model, segnet, P_test, faces_test, args)
70 | print((dice, iou))
71 |
--------------------------------------------------------------------------------
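`get_dice_ji` is imported from utils/utils_eval.py, which is not included in this dump; a minimal stand-in consistent with how it is called above (binary mask tensors in, a `(dice, iou)` pair of floats out) could look like the sketch below. The function name matches the import, but the exact reduction is an assumption.

```
import torch

def get_dice_ji(pred, gt, eps=1e-6):
    # Assumed behavior: binarize both masks, then compute Dice and Jaccard (IoU).
    pred = (pred >= 0.5).float()
    gt = (gt > 0).float()
    inter = (pred * gt).sum()
    union = pred.sum() + gt.sum() - inter
    dice = (2 * inter + eps) / (pred.sum() + gt.sum() + eps)
    iou = (inter + eps) / (union + eps)
    return dice.item(), iou.item()
```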
/eval_seg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from loader.viah_loader import *
4 | from loader.bing_loader import *
5 | from utils.utils_args import *
6 | from utils.utils_eval import *
7 | from utils.utils_train import *
8 | from utils.utils_tri import *
9 | from utils.utils_vis import *
10 | from utils.loss import *
11 | from models.model_seg import *
12 |
13 |
14 | def eval_ds(ds, model):
15 | TestDice_list = []
16 | TestIoU_list = []
17 | for ix, (_x, _y) in enumerate(ds):
18 | _x = _x.float().cpu()
19 | _y = _y.float().cpu()
20 |         Mask = model(_x).detach()
21 | Mask[Mask >= 0.5] = 1
22 | Mask[Mask < 0.5] = 0
23 | (cDice, cIoU) = get_dice_ji(Mask, _y)
24 | TestDice_list.append(cDice)
25 | TestIoU_list.append(cIoU)
26 | Dice = np.mean(TestDice_list)
27 | IoU = np.mean(TestIoU_list)
28 | print((Dice, IoU))
29 |
30 |
31 | def main():
32 | torch.backends.cudnn.benchmark = True
33 | args = get_args()
34 | save_args(args)
35 |
36 | if args['task'] == 'viah':
37 | PATH = r'results/viah/best/'
38 | testset = viah_segmentation(ann='test', args=args)
39 | elif args['task'] == 'bing':
40 | testset = bing_segmentation(ann='test', args=args)
41 | PATH = r'results/bing/best/'
42 | segnet = Segmentation(args)
43 | segnet1 = torch.load(PATH + 'SEG.pt')
44 | segnet.load_state_dict(segnet1.state_dict())
45 | segnet.cpu().eval()
46 | ds_val = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False,
47 | num_workers=1, drop_last=False)
48 | eval_ds(ds_val, segnet)
49 |
50 | if __name__ == '__main__':
51 | main()
52 |
53 |
--------------------------------------------------------------------------------
/hd5_maker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/hd5_maker/__init__.py
--------------------------------------------------------------------------------
/hd5_maker/bing.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import os
3 | import numpy as np
4 | from glob import glob
5 | from PIL import Image
6 | import torchvision.transforms as transforms
7 |
8 |
9 | def get_img(cfile):
10 | image = Image.open(cfile)
11 | image = transforms.functional.resize(image, (256, 256), 3)
12 | return np.asarray(image)
13 |
14 | def get_mask(cfile):
15 | image = Image.open(cfile)
16 | image = np.asarray(image).copy()
17 | image[image > 0] = 255
18 | image = image.astype(np.uint8)
19 | image = Image.fromarray(image).convert('1')
20 | mask_one = transforms.functional.resize(image, (256, 256))
21 |     mask = np.asarray(mask_one).astype(float)
22 | return mask
23 |
24 | src = '/media/data1/talshah/DAR/single_buildings/'
25 | hf_tri = h5py.File('/media/data1/talshah/DeepACM/Data/full_training_Bing.h5', 'w')
26 | hf_test = h5py.File('/media/data1/talshah/DeepACM/Data/full_test_Bing.h5', 'w')
27 | a = os.listdir(src)
28 | a.sort()
29 |
30 | img_list = glob(src + 'building_*')
31 | mask_list = glob(src + 'building_mask_*')
32 | mask_all_list = glob(src + 'building_mask_all_*')
33 | img_list.sort()
34 | mask_list.sort()
35 | mask_all_list.sort()
36 |
37 | imgs_tri = hf_tri.create_group('imgs')
38 | mask_tri = hf_tri.create_group('mask')
39 | mask_single_tri = hf_tri.create_group('mask_single')
40 |
41 | imgs_test = hf_test.create_group('imgs')
42 | mask_test = hf_test.create_group('mask')
43 | mask_single_test = hf_test.create_group('mask_single')
44 |
45 | for folder in img_list[0:335]:
46 | print('training: ' + folder)
47 | img = get_img(folder)
48 | imgs_tri.create_dataset(folder.split('/')[-1], data=img, dtype=np.uint8)
49 |
50 | for folder in img_list[335:606]:
51 | print('validation: ' + folder)
52 | img = get_img(folder)
53 | imgs_test.create_dataset(folder.split('/')[-1], data=img, dtype=np.uint8)
54 |
55 | for folder in mask_list[0:335]:
56 | print('training: ' + folder)
57 | mask = get_mask(folder)
58 | mask_single_tri.create_dataset(folder.split('/')[-1], data=mask, dtype=np.uint8)
59 |
60 | for folder in mask_list[335:606]:
61 | print('validation: ' + folder)
62 | mask = get_mask(folder)
63 | mask_single_test.create_dataset(folder.split('/')[-1], data=mask, dtype=np.uint8)
64 |
65 | for folder in mask_all_list[0:335]:
66 | print('training: ' + folder)
67 | mask = get_mask(folder)
68 | mask_tri.create_dataset(folder.split('/')[-1], data=mask, dtype=np.uint8)
69 |
70 | for folder in mask_all_list[335:606]:
71 | print('validation: ' + folder)
72 | mask = get_mask(folder)
73 | mask_test.create_dataset(folder.split('/')[-1], data=mask, dtype=np.uint8)
74 |
75 | hf_tri.close()
76 | hf_test.close()
77 |
78 |
--------------------------------------------------------------------------------
/hd5_maker/viah.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import os
3 | import cv2
4 | import numpy as np
5 | from skimage.transform import resize
6 |
7 |
8 | def get_img(cfile, shape=(256, 256)):
9 | img = cv2.cvtColor(cv2.imread(cfile, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
10 | img = resize(img, shape)*255
11 | return img
12 |
13 |
14 | def get_mask(cfile, shape=(256, 256)):
15 | GT = cv2.imread(cfile, 0)
16 | GT = resize(GT, shape)
17 | GT[GT >= 0.5] = 1
18 | GT[GT < 0.5] = 0
19 | return GT
20 |
21 |
22 | img_src = '/media/data1/talshah/GenA/buildings_vaihingen/buildings/img/'
23 | mask_src = '/media/data1/talshah/GenA/buildings_vaihingen/buildings/AllBuildingsMask/'
24 | mask_single_src = '/media/data1/talshah/GenA/buildings_vaihingen/buildings/mask_sizeold/'
25 | hf_tri = h5py.File('/media/data1/talshah/DeepACM/Data/full_training_viah.h5', 'w')
26 | hf_test = h5py.File('/media/data1/talshah/DeepACM/Data/full_test_viah.h5', 'w')
27 | a_img = os.listdir(img_src)
28 | a_img.sort()
29 | a_mask = os.listdir(mask_src)
30 | a_mask.sort()
31 | a_mask_single = os.listdir(mask_single_src)
32 | a_mask_single.sort()
33 |
34 | imgs_tri = hf_tri.create_group('imgs')
35 | mask_tri = hf_tri.create_group('mask')
36 | mask_single_tri = hf_tri.create_group('mask_single')
37 |
38 | imgs_test = hf_test.create_group('imgs')
39 | mask_test = hf_test.create_group('mask')
40 | mask_single_test = hf_test.create_group('mask_single')
41 | shape_mask = (256, 256)
42 | shape_img = (256, 256)
43 |
44 | for folder in a_img[:100]:
45 | print('training: ' + folder)
46 | cfile = img_src + folder
47 | img = get_img(cfile, shape_img)
48 | imgs_tri.create_dataset(folder, data=img, dtype=np.uint8)
49 |
50 | for folder in a_img[100:]:
51 | print('validation: ' + folder)
52 | cfile = img_src + folder
53 | img = get_img(cfile, shape_img)
54 | imgs_test.create_dataset(folder, data=img, dtype=np.uint8)
55 |
56 | for folder in a_mask[:100]:
57 | print('training: ' + folder)
58 | cfile = mask_src + folder
59 | mask = get_mask(cfile, shape_mask)
60 | mask_tri.create_dataset(folder, data=mask, dtype=np.uint8)
61 |
62 | for folder in a_mask[100:]:
63 | print('validation: ' + folder)
64 | cfile = mask_src + folder
65 | mask = get_mask(cfile, shape_mask)
66 | mask_test.create_dataset(folder, data=mask, dtype=np.uint8)
67 |
68 | for folder in a_mask_single[:100]:
69 | print('training: ' + folder)
70 | cfile = mask_single_src + folder
71 | mask = get_mask(cfile, shape_mask)
72 | mask_single_tri.create_dataset(folder, data=mask, dtype=np.uint8)
73 |
74 | for folder in a_mask_single[100:]:
75 | print('validation: ' + folder)
76 | cfile = mask_single_src + folder
77 | mask = get_mask(cfile, shape_mask)
78 | mask_single_test.create_dataset(folder, data=mask, dtype=np.uint8)
79 |
80 | hf_tri.close()
81 | hf_test.close()
82 |
83 |
--------------------------------------------------------------------------------
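Both hd5_maker scripts write the same HDF5 layout: top-level groups `imgs`, `mask`, and `mask_single`, with one dataset per source image. A quick sanity check of a produced file (using one of the paths hard-coded above; the dataset name in the comment is illustrative):

```
import h5py

with h5py.File('/media/data1/talshah/DeepACM/Data/full_training_viah.h5', 'r') as hf:
    print(list(hf.keys()))           # ['imgs', 'mask', 'mask_single']
    imgs = hf['imgs']
    name = sorted(imgs.keys())[0]
    print(name, imgs[name].shape)    # e.g. some_image.png (256, 256, 3)
```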
/loader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/loader/__init__.py
--------------------------------------------------------------------------------
/loader/bing_loader.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings('ignore')
3 | import torch
4 | from torch.utils.data import Dataset
5 | import h5py
6 | import numpy as np
7 | import loader.transforms as transforms
8 | from PIL import Image
9 |
10 |
11 | class bing_segmentation(Dataset):
12 | def __init__(self, ann='training', args=None):
13 | self.ann = ann
14 | self.MEAN = np.array([101.87901, 100.81404, 110.389275])
15 | self.STD = np.array([17.022379, 17.664776, 20.302572])
16 | if ann == 'training':
17 | self.transformations = transforms.Compose([transforms.ToPILImage(),
18 | transforms.ColorJitter(brightness=0.3,
19 | contrast=0.3,
20 | saturation=0.3,
21 | hue=0.01),
22 | transforms.RandomRotation(90),
23 | transforms.RandomVerticalFlip(),
24 | transforms.RandomHorizontalFlip(),
25 | transforms.ToTensor(),
26 | transforms.Normalize(self.MEAN, self.STD)])
27 | else:
28 | self.transformations = transforms.Compose([transforms.ToPILImage(),
29 | transforms.ToTensor(),
30 | transforms.Normalize(self.MEAN, self.STD)])
31 | if ann == 'training':
32 | self.data_length = 335
33 | else:
34 | self.data_length = 270
35 | self.args = args
36 |
37 | def __len__(self):
38 | return self.data_length
39 |
40 | def __getitem__(self, item):
41 | if self.ann == 'training':
42 | self.data = h5py.File('Data/full_training_Bing.h5', 'r')
43 | else:
44 | self.data = h5py.File('Data/full_test_Bing.h5', 'r')
45 | self.mask = self.data['mask_single']
46 | self.imgs = self.data['imgs']
47 | self.img_list = list(self.imgs)
48 | self.mask_list = list(self.mask)
49 | cimage = self.img_list[item]
50 | img = self.imgs.get(cimage).value
51 | cmask = self.mask_list[item]
52 | mask = self.mask.get(cmask).value
53 | img = img.astype(np.uint8)
54 | mask = mask.astype(np.uint8)
55 | img, mask = self.transformations(img, mask)
56 | return img, mask
57 |
--------------------------------------------------------------------------------
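The loader returns an already-normalized `(img, mask)` pair, so consuming it only needs a standard `DataLoader`, exactly as eval.py does. A minimal sketch, assuming `Data/full_test_Bing.h5` is in place (the loader stores but never reads `args` in `__getitem__`, so an empty dict suffices here):

```
import torch
from loader.bing_loader import bing_segmentation

testset = bing_segmentation(ann='test', args={})
loader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False,
                                     num_workers=1, drop_last=False)
img, mask = next(iter(loader))
print(img.shape, mask.shape)  # torch.Size([1, 3, 256, 256]) torch.Size([1, 256, 256])
```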
/loader/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import math
4 | import sys
5 | import random
6 | from PIL import Image
7 |
8 | try:
9 | import accimage
10 | except ImportError:
11 | accimage = None
12 | import numpy as np
13 | import numbers
14 | import types
15 | import collections
16 | import warnings
17 |
18 | from torchvision.transforms import functional as F
19 |
20 | if sys.version_info < (3, 3):
21 | Sequence = collections.Sequence
22 | Iterable = collections.Iterable
23 | else:
24 | Sequence = collections.abc.Sequence
25 | Iterable = collections.abc.Iterable
26 |
27 | __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad",
28 | "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
29 | "RandomVerticalFlip", "RandomResizedCrop", "FiveCrop", "TenCrop",
30 | "ColorJitter", "RandomRotation", "RandomAffine",
31 | "RandomPerspective"]
32 |
33 | _pil_interpolation_to_str = {
34 | Image.NEAREST: 'PIL.Image.NEAREST',
35 | Image.BILINEAR: 'PIL.Image.BILINEAR',
36 | Image.BICUBIC: 'PIL.Image.BICUBIC',
37 | Image.LANCZOS: 'PIL.Image.LANCZOS',
38 | Image.HAMMING: 'PIL.Image.HAMMING',
39 | Image.BOX: 'PIL.Image.BOX',
40 | }
41 |
42 |
43 | class Compose(object):
44 | def __init__(self, transforms):
45 | self.transforms = transforms
46 |
47 | def __call__(self, img, mask):
48 | for t in self.transforms:
49 | img, mask = t(img, mask)
50 | return img, mask
51 |
52 |
53 | class ToTensor(object):
54 | def __call__(self, img, mask):
55 | # return F.to_tensor(img), F.to_tensor(mask)
56 | img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float()
57 | mask = torch.from_numpy(np.array(mask)).long()
58 | return img, mask
59 |
60 |
61 | class ToPILImage(object):
62 | def __init__(self, mode=None):
63 | self.mode = mode
64 |
65 | def __call__(self, img, mask):
66 | return F.to_pil_image(img, self.mode), F.to_pil_image(mask, self.mode)
67 |
68 |
69 | class Normalize(object):
70 | def __init__(self, mean, std, inplace=False):
71 | self.mean = mean
72 | self.std = std
73 | self.inplace = inplace
74 |
75 | def __call__(self, img, mask):
76 | return F.normalize(img, self.mean, self.std, self.inplace), mask
77 |
78 |
79 | class Resize(object):
80 | def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True):
81 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
82 | self.size = size
83 | self.interpolation = interpolation
84 | self.do_mask = do_mask
85 |
86 | def __call__(self, img, mask):
87 | if self.do_mask:
88 | return F.resize(img, self.size, self.interpolation), F.resize(mask, self.size, Image.NEAREST)
89 | else:
90 | return F.resize(img, self.size, self.interpolation), mask
91 |
92 |
93 | class CenterCrop(object):
94 | def __init__(self, size):
95 | if isinstance(size, numbers.Number):
96 | self.size = (int(size), int(size))
97 | else:
98 | self.size = size
99 |
100 | def __call__(self, img, mask):
101 | return F.center_crop(img, self.size), F.center_crop(mask, self.size)
102 |
103 |
104 | class Pad(object):
105 | def __init__(self, padding, fill=0, padding_mode='constant'):
106 | assert isinstance(padding, (numbers.Number, tuple))
107 | assert isinstance(fill, (numbers.Number, str, tuple))
108 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
109 | if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
110 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
111 | "{} element tuple".format(len(padding)))
112 |
113 | self.padding = padding
114 | self.fill = fill
115 | self.padding_mode = padding_mode
116 |
117 | def __call__(self, img, mask):
118 | return F.pad(img, self.padding, self.fill, self.padding_mode), \
119 | F.pad(mask, self.padding, self.fill, self.padding_mode)
120 |
121 |
122 | class Lambda(object):
123 | def __init__(self, lambd):
124 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
125 | self.lambd = lambd
126 |
127 | def __call__(self, img, mask):
128 | return self.lambd(img), self.lambd(mask)
129 |
130 |
131 | class Lambda_image(object):
132 | def __init__(self, lambd):
133 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
134 | self.lambd = lambd
135 |
136 | def __call__(self, img, mask):
137 | return self.lambd(img), mask
138 |
139 |
140 | class RandomTransforms(object):
141 | def __init__(self, transforms):
142 | assert isinstance(transforms, (list, tuple))
143 | self.transforms = transforms
144 |
145 | def __call__(self, *args, **kwargs):
146 | raise NotImplementedError()
147 |
148 |
149 | class RandomApply(RandomTransforms):
150 | def __init__(self, transforms, p=0.5):
151 | super(RandomApply, self).__init__(transforms)
152 | self.p = p
153 |
154 | def __call__(self, img, mask):
155 | if self.p < random.random():
156 | return img, mask
157 | for t in self.transforms:
158 | img, mask = t(img, mask)
159 | return img, mask
160 |
161 |
162 | class RandomOrder(RandomTransforms):
163 | def __call__(self, img, mask):
164 | order = list(range(len(self.transforms)))
165 | random.shuffle(order)
166 | for i in order:
167 | img, mask = self.transforms[i](img, mask)
168 | return img, mask
169 |
170 |
171 | class RandomChoice(RandomTransforms):
172 | def __call__(self, img, mask):
173 | t = random.choice(self.transforms)
174 | return t(img, mask)
175 |
176 |
177 | class RandomCrop(object):
178 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
179 | if isinstance(size, numbers.Number):
180 | self.size = (int(size), int(size))
181 | else:
182 | self.size = size
183 | self.padding = padding
184 | self.pad_if_needed = pad_if_needed
185 | self.fill = fill
186 | self.padding_mode = padding_mode
187 |
188 | @staticmethod
189 | def get_params(img, output_size):
190 | w, h = img.size
191 | th, tw = output_size
192 | if w == tw and h == th:
193 | return 0, 0, h, w
194 |
195 | i = random.randint(0, h - th)
196 | j = random.randint(0, w - tw)
197 | return i, j, th, tw
198 |
199 | def __call__(self, img, mask):
200 | if self.padding is not None:
201 | img = F.pad(img, self.padding, self.fill, self.padding_mode)
202 |
203 | # pad the width if needed
204 | if self.pad_if_needed and img.size[0] < self.size[1]:
205 | img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
206 | # pad the height if needed
207 | if self.pad_if_needed and img.size[1] < self.size[0]:
208 | img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
209 |
210 | i, j, h, w = self.get_params(img, self.size)
211 |
212 | return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w)
213 |
214 |
215 | class RandomHorizontalFlip(object):
216 | def __init__(self, p=0.5):
217 | self.p = p
218 |
219 | def __call__(self, img, mask):
220 | if random.random() < self.p:
221 | return F.hflip(img), F.hflip(mask)
222 | return img, mask
223 |
224 |
225 | class RandomVerticalFlip(object):
226 | def __init__(self, p=0.5):
227 | self.p = p
228 |
229 | def __call__(self, img, mask):
230 | if random.random() < self.p:
231 | return F.vflip(img), F.vflip(mask)
232 | return img, mask
233 |
234 |
235 | class RandomPerspective(object):
236 | def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC):
237 | self.p = p
238 | self.interpolation = interpolation
239 | self.distortion_scale = distortion_scale
240 |
241 | def __call__(self, img, mask):
242 | if not F._is_pil_image(img):
243 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
244 |
245 | if random.random() < self.p:
246 | width, height = img.size
247 | startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
248 | return F.perspective(img, startpoints, endpoints, self.interpolation), \
249 | F.perspective(mask, startpoints, endpoints, Image.NEAREST)
250 | return img, mask
251 |
252 | @staticmethod
253 | def get_params(width, height, distortion_scale):
254 | half_height = int(height / 2)
255 | half_width = int(width / 2)
256 | topleft = (random.randint(0, int(distortion_scale * half_width)),
257 | random.randint(0, int(distortion_scale * half_height)))
258 | topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
259 | random.randint(0, int(distortion_scale * half_height)))
260 | botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
261 | random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
262 | botleft = (random.randint(0, int(distortion_scale * half_width)),
263 | random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
264 | startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
265 | endpoints = [topleft, topright, botright, botleft]
266 | return startpoints, endpoints
267 |
268 |
269 | class RandomResizedCrop(object):
270 | def __init__(self, size, mask_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
271 | if isinstance(size, tuple):
272 | self.size = size
273 | self.mask_size = mask_size
274 | else:
275 | self.size = (size, size)
276 | self.mask_size = (mask_size, mask_size)
277 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
278 | warnings.warn("range should be of kind (min, max)")
279 |
280 | self.interpolation = interpolation
281 | self.scale = scale
282 | self.ratio = ratio
283 |
284 | @staticmethod
285 | def get_params(img, scale, ratio):
286 | area = img.size[0] * img.size[1]
287 |
288 | for attempt in range(10):
289 | target_area = random.uniform(*scale) * area
290 | log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
291 | aspect_ratio = math.exp(random.uniform(*log_ratio))
292 |
293 | w = int(round(math.sqrt(target_area * aspect_ratio)))
294 | h = int(round(math.sqrt(target_area / aspect_ratio)))
295 |
296 | if w <= img.size[0] and h <= img.size[1]:
297 | i = random.randint(0, img.size[1] - h)
298 | j = random.randint(0, img.size[0] - w)
299 | return i, j, h, w
300 |
301 | # Fallback to central crop
302 | in_ratio = img.size[0] / img.size[1]
303 | if (in_ratio < min(ratio)):
304 | w = img.size[0]
305 | h = w / min(ratio)
306 | elif (in_ratio > max(ratio)):
307 | h = img.size[1]
308 | w = h * max(ratio)
309 | else: # whole image
310 | w = img.size[0]
311 | h = img.size[1]
312 | i = (img.size[1] - h) // 2
313 | j = (img.size[0] - w) // 2
314 | return i, j, h, w
315 |
316 | def __call__(self, img, mask):
317 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
318 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \
319 | F.resized_crop(mask, i, j, h, w, self.mask_size, Image.NEAREST)
320 |
321 |
322 | class FiveCrop(object):
323 | def __init__(self, size):
324 | self.size = size
325 | if isinstance(size, numbers.Number):
326 | self.size = (int(size), int(size))
327 | else:
328 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
329 | self.size = size
330 |
331 | def __call__(self, img, mask):
332 | return F.five_crop(img, self.size), F.five_crop(mask, self.size)
333 |
334 |
335 | class TenCrop(object):
336 | def __init__(self, size, vertical_flip=False):
337 | self.size = size
338 | if isinstance(size, numbers.Number):
339 | self.size = (int(size), int(size))
340 | else:
341 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
342 | self.size = size
343 | self.vertical_flip = vertical_flip
344 |
345 | def __call__(self, img, mask):
346 | return F.ten_crop(img, self.size, self.vertical_flip), F.ten_crop(mask, self.size, self.vertical_flip)
347 |
348 |
349 | class ColorJitter(object):
350 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
351 | self.brightness = self._check_input(brightness, 'brightness')
352 | self.contrast = self._check_input(contrast, 'contrast')
353 | self.saturation = self._check_input(saturation, 'saturation')
354 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
355 | clip_first_on_zero=False)
356 |
357 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
358 | if isinstance(value, numbers.Number):
359 | if value < 0:
360 | raise ValueError("If {} is a single number, it must be non negative.".format(name))
361 | value = [center - value, center + value]
362 | if clip_first_on_zero:
363 | value[0] = max(value[0], 0)
364 | elif isinstance(value, (tuple, list)) and len(value) == 2:
365 | if not bound[0] <= value[0] <= value[1] <= bound[1]:
366 | raise ValueError("{} values should be between {}".format(name, bound))
367 | else:
368 |             raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
369 |
370 | # if value is 0 or (1., 1.) for brightness/contrast/saturation
371 | # or (0., 0.) for hue, do nothing
372 | if value[0] == value[1] == center:
373 | value = None
374 | return value
375 |
376 | @staticmethod
377 | def get_params(brightness, contrast, saturation, hue):
378 | transforms = []
379 |
380 | if brightness is not None:
381 | brightness_factor = random.uniform(brightness[0], brightness[1])
382 | transforms.append(Lambda_image(lambda img: F.adjust_brightness(img, brightness_factor)))
383 |
384 | if contrast is not None:
385 | contrast_factor = random.uniform(contrast[0], contrast[1])
386 | transforms.append(Lambda_image(lambda img: F.adjust_contrast(img, contrast_factor)))
387 |
388 | if saturation is not None:
389 | saturation_factor = random.uniform(saturation[0], saturation[1])
390 | transforms.append(Lambda_image(lambda img: F.adjust_saturation(img, saturation_factor)))
391 |
392 | if hue is not None:
393 | hue_factor = random.uniform(hue[0], hue[1])
394 | transforms.append(Lambda_image(lambda img: F.adjust_hue(img, hue_factor)))
395 |
396 | random.shuffle(transforms)
397 | transform = Compose(transforms)
398 |
399 | return transform
400 |
401 | def __call__(self, img, mask):
402 | transform = self.get_params(self.brightness, self.contrast,
403 | self.saturation, self.hue)
404 | return transform(img, mask)
405 |
406 |
407 | class RandomRotation(object):
408 | def __init__(self, degrees, resample=False, expand=False, center=None):
409 | if isinstance(degrees, numbers.Number):
410 | if degrees < 0:
411 | raise ValueError("If degrees is a single number, it must be positive.")
412 | self.degrees = (-degrees, degrees)
413 | else:
414 | if len(degrees) != 2:
415 | raise ValueError("If degrees is a sequence, it must be of len 2.")
416 | self.degrees = degrees
417 |
418 | self.resample = resample
419 | self.expand = expand
420 | self.center = center
421 |
422 | @staticmethod
423 | def get_params(degrees):
424 | angle = random.uniform(degrees[0], degrees[1])
425 |
426 | return angle
427 |
428 | def __call__(self, img, mask):
429 | angle = self.get_params(self.degrees)
430 |
431 | return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), \
432 | F.rotate(mask, angle, Image.NEAREST, self.expand, self.center)
433 |
434 |
435 | class RandomAffine(object):
436 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
437 | if isinstance(degrees, numbers.Number):
438 | if degrees < 0:
439 | raise ValueError("If degrees is a single number, it must be positive.")
440 | self.degrees = (-degrees, degrees)
441 | else:
442 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
443 | "degrees should be a list or tuple and it must be of length 2."
444 | self.degrees = degrees
445 |
446 | if translate is not None:
447 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
448 | "translate should be a list or tuple and it must be of length 2."
449 | for t in translate:
450 | if not (0.0 <= t <= 1.0):
451 | raise ValueError("translation values should be between 0 and 1")
452 | self.translate = translate
453 |
454 | if scale is not None:
455 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
456 | "scale should be a list or tuple and it must be of length 2."
457 | for s in scale:
458 | if s <= 0:
459 | raise ValueError("scale values should be positive")
460 | self.scale = scale
461 |
462 | if shear is not None:
463 | if isinstance(shear, numbers.Number):
464 | if shear < 0:
465 | raise ValueError("If shear is a single number, it must be positive.")
466 | self.shear = (-shear, shear)
467 | else:
468 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
469 | "shear should be a list or tuple and it must be of length 2."
470 | self.shear = shear
471 | else:
472 | self.shear = shear
473 |
474 | self.resample = resample
475 | self.fillcolor = fillcolor
476 |
477 | @staticmethod
478 | def get_params(degrees, translate, scale_ranges, shears, img_size):
479 | angle = random.uniform(degrees[0], degrees[1])
480 | if translate is not None:
481 | max_dx = translate[0] * img_size[0]
482 | max_dy = translate[1] * img_size[1]
483 | translations = (np.round(random.uniform(-max_dx, max_dx)),
484 | np.round(random.uniform(-max_dy, max_dy)))
485 | else:
486 | translations = (0, 0)
487 |
488 | if scale_ranges is not None:
489 | scale = random.uniform(scale_ranges[0], scale_ranges[1])
490 | else:
491 | scale = 1.0
492 |
493 | if shears is not None:
494 | shear = random.uniform(shears[0], shears[1])
495 | else:
496 | shear = 0.0
497 |
498 | return angle, translations, scale, shear
499 |
500 | def __call__(self, img, mask):
501 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
502 | return F.affine(img, *ret, resample=Image.BILINEAR, fillcolor=self.fillcolor), \
503 | F.affine(mask, *ret, resample=Image.NEAREST, fillcolor=self.fillcolor)
504 |
--------------------------------------------------------------------------------
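Unlike torchvision's single-input transforms, every transform in this module takes and returns an `(img, mask)` pair: random parameters (flip decision, rotation angle, crop box) are sampled once and applied to both, and masks always use nearest-neighbor resampling so labels stay discrete. A small sketch of that contract:

```
import numpy as np
import loader.transforms as transforms

pair_tf = transforms.Compose([transforms.ToPILImage(),
                              transforms.RandomHorizontalFlip(),
                              transforms.RandomRotation(25),
                              transforms.ToTensor()])

img = np.zeros((256, 256, 3), dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.uint8)
img_t, mask_t = pair_tf(img, mask)  # the same flip/angle is applied to both
print(img_t.shape, mask_t.shape)    # torch.Size([3, 256, 256]) torch.Size([256, 256])
```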
/loader/viah_loader.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings('ignore')
3 | import torch
4 | from torch.utils.data import Dataset
5 | import h5py
6 | import numpy as np
7 | import loader.transforms as transforms
8 | from PIL import Image
9 |
10 |
11 | class viah_segmentation(Dataset):
12 | def __init__(self, ann='training', args=None):
13 | self.ann = ann
14 | self.MEAN = np.array([0.47341759*255, 0.28791303*255, 0.2850705*255])
15 | self.STD = np.array([0.22645572*255, 0.15276193*255, 0.140702*255])
16 | if ann == 'training':
17 | self.transformations = transforms.Compose([transforms.ToPILImage(),
18 | transforms.RandomResizedCrop(size=(256, 256),
19 | mask_size=(256, 256),
20 | scale=(0.75, 2)),
21 | transforms.ColorJitter(brightness=0.4,
22 | contrast=0.4,
23 | saturation=0.4,
24 | hue=0.1),
25 | transforms.RandomRotation(25),
26 | transforms.RandomHorizontalFlip(),
27 | transforms.ToTensor(),
28 | transforms.Normalize(self.MEAN, self.STD)])
29 | else:
30 | self.transformations = transforms.Compose([transforms.ToPILImage(),
31 | transforms.ToTensor(),
32 | transforms.Normalize(self.MEAN, self.STD)])
33 | if ann == 'training':
34 | self.data_length = 100
35 | else:
36 | self.data_length = 68
37 | self.args = args
38 |
39 | def __len__(self):
40 | return self.data_length
41 |
42 | def __getitem__(self, item):
43 | if self.ann == 'training':
44 | self.data = h5py.File('Data/full_training_viah.h5', 'r')
45 | else:
46 | self.data = h5py.File('Data/full_test_viah.h5', 'r')
47 | self.mask = self.data['mask_single']
48 | self.imgs = self.data['imgs']
49 | self.img_list = list(self.imgs)
50 | self.mask_list = list(self.mask)
51 | cimage = self.img_list[item]
52 | img = self.imgs.get(cimage).value
53 | cmask = self.mask_list[item]
54 | mask = self.mask.get(cmask).value
55 | img = img.astype(np.uint8)
56 | mask = mask.astype(np.uint8)
57 | img, mask = self.transformations(img, mask)
58 | return img, mask
59 |
60 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/models/__init__.py
--------------------------------------------------------------------------------
/models/hardnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class AdaptationMismatch(Exception): pass
8 |
9 |
10 | class Flatten(nn.Module):
11 | def __init__(self):
12 | super().__init__()
13 |
14 | def forward(self, x):
15 | return x.view(x.data.size(0), -1)
16 |
17 |
18 | class CombConvLayer(nn.Sequential):
19 | def __init__(self, in_channels, out_channels, kernel=1, stride=1, dropout=0.1, bias=False):
20 | super().__init__()
21 | self.add_module('layer1', ConvLayer(in_channels, out_channels, kernel))
22 | self.add_module('layer2', DWConvLayer(out_channels, out_channels, stride=stride))
23 |
24 | def forward(self, x):
25 | return super().forward(x)
26 |
27 |
28 | class DWConvLayer(nn.Sequential):
29 | def __init__(self, in_channels, out_channels, stride=1, bias=False):
30 | super().__init__()
31 | out_ch = out_channels
32 |
33 | groups = in_channels
34 | kernel = 3
35 | # print(kernel, 'x', kernel, 'x', out_channels, 'x', out_channels, 'DepthWise')
36 |
37 | self.add_module('dwconv', nn.Conv2d(groups, groups, kernel_size=3,
38 | stride=stride, padding=1, groups=groups, bias=bias))
39 | self.add_module('norm', nn.BatchNorm2d(groups))
40 |
41 | def forward(self, x):
42 | return super().forward(x)
43 |
44 |
45 | class ConvLayer(nn.Sequential):
46 | def __init__(self, in_channels, out_channels, kernel=3, stride=1, dropout=0.1, bias=False):
47 | super().__init__()
48 | out_ch = out_channels
49 | groups = 1
50 | # print(kernel, 'x', kernel, 'x', in_channels, 'x', out_channels)
51 | self.add_module('conv', nn.Conv2d(in_channels, out_ch, kernel_size=kernel,
52 | stride=stride, padding=kernel // 2, groups=groups, bias=bias))
53 | self.add_module('norm', nn.BatchNorm2d(out_ch))
54 | self.add_module('relu', nn.ReLU6(True))
55 |
56 | def forward(self, x):
57 | return super().forward(x)
58 |
59 |
60 | class HarDBlock(nn.Module):
61 | def get_link(self, layer, base_ch, growth_rate, grmul):
62 | if layer == 0:
63 | return base_ch, 0, []
64 | out_channels = growth_rate
65 | link = []
66 | for i in range(10):
67 | dv = 2 ** i
68 | if layer % dv == 0:
69 | k = layer - dv
70 | link.append(k)
71 | if i > 0:
72 | out_channels *= grmul
73 | out_channels = int(int(out_channels + 1) / 2) * 2
74 | in_channels = 0
75 | for i in link:
76 | ch, _, _ = self.get_link(i, base_ch, growth_rate, grmul)
77 | in_channels += ch
78 | return out_channels, in_channels, link
79 |
80 | def get_out_ch(self):
81 | return self.out_channels
82 |
83 | def __init__(self, in_channels, growth_rate, grmul, n_layers, keepBase=False, residual_out=False, dwconv=False):
84 | super().__init__()
85 | self.keepBase = keepBase
86 | self.links = []
87 | layers_ = []
88 | self.out_channels = 0 # if upsample else in_channels
89 | for i in range(n_layers):
90 | outch, inch, link = self.get_link(i + 1, in_channels, growth_rate, grmul)
91 | self.links.append(link)
92 | use_relu = residual_out
93 | if dwconv:
94 | layers_.append(CombConvLayer(inch, outch))
95 | else:
96 | layers_.append(ConvLayer(inch, outch))
97 |
98 | if (i % 2 == 0) or (i == n_layers - 1):
99 | self.out_channels += outch
100 | # print("Blk out =",self.out_channels)
101 | self.layers = nn.ModuleList(layers_)
102 |
103 | def forward(self, x):
104 | layers_ = [x]
105 |
106 | for layer in range(len(self.layers)):
107 | link = self.links[layer]
108 | tin = []
109 | for i in link:
110 | tin.append(layers_[i])
111 | if len(tin) > 1:
112 | x = torch.cat(tin, 1)
113 | else:
114 | x = tin[0]
115 | out = self.layers[layer](x)
116 | layers_.append(out)
117 |
118 | t = len(layers_)
119 | out_ = []
120 | for i in range(t):
121 | if (i == 0 and self.keepBase) or \
122 | (i == t - 1) or (i % 2 == 1):
123 | out_.append(layers_[i])
124 | out = torch.cat(out_, 1)
125 | return out
126 |
127 |
128 | class HarDNet(nn.Module):
129 | def __init__(self, depth_wise=False, arch=85, pretrained=True, weight_path='', out=1, args=None):
130 | super().__init__()
131 | first_ch = [32, 64]
132 | second_kernel = 3
133 | max_pool = True
134 | grmul = 1.7
135 | drop_rate = 0.1
136 | args['order'] = arch
137 | # HarDNet68
138 | ch_list = [128, 256, 320, 640, 1024]
139 | gr = [14, 16, 20, 40, 160]
140 | n_layers = [8, 16, 16, 16, 4]
141 | downSamp = [1, 0, 1, 1, 0]
142 |
143 | if arch == 85:
144 | # HarDNet85
145 | first_ch = [48, 96]
146 | ch_list = [192, 256, 320, 480, 720, 1280]
147 | gr = [24, 24, 28, 36, 48, 256]
148 | n_layers = [8, 16, 16, 16, 16, 4]
149 | downSamp = [1, 0, 1, 0, 1, 0]
150 | drop_rate = 0.2
151 | elif arch == 39:
152 | # HarDNet39
153 | first_ch = [24, 48]
154 | ch_list = [96, 320, 640, 1024]
155 | grmul = 1.6
156 | gr = [16, 20, 64, 160]
157 | n_layers = [4, 16, 8, 4]
158 | downSamp = [1, 1, 1, 0]
159 |
160 | if depth_wise:
161 | second_kernel = 1
162 | max_pool = False
163 | drop_rate = 0.05
164 |
165 | blks = len(n_layers)
166 | self.base = nn.ModuleList([])
167 |
168 | # First Layer: Standard Conv3x3, Stride=2
169 | self.base.append(
170 | ConvLayer(in_channels=3, out_channels=first_ch[0], kernel=3,
171 | stride=2, bias=False))
172 |
173 | # Second Layer
174 | self.base.append(ConvLayer(first_ch[0], first_ch[1], kernel=second_kernel))
175 |
176 | # Maxpooling or DWConv3x3 downsampling
177 | if max_pool:
178 | self.base.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
179 | else:
180 | self.base.append(DWConvLayer(first_ch[1], first_ch[1], stride=2))
181 |
182 | # Build all HarDNet blocks
183 | ch = first_ch[1]
184 | for i in range(blks):
185 | blk = HarDBlock(ch, gr[i], grmul, n_layers[i], dwconv=depth_wise)
186 | ch = blk.get_out_ch()
187 | self.base.append(blk)
188 |
189 | if i == blks - 1 and arch == 85:
190 | self.base.append(nn.Dropout(0.1))
191 |
192 | self.base.append(ConvLayer(ch, ch_list[i], kernel=1))
193 | ch = ch_list[i]
194 | if downSamp[i] == 1:
195 | if max_pool:
196 | self.base.append(nn.MaxPool2d(kernel_size=2, stride=2))
197 | else:
198 | self.base.append(DWConvLayer(ch, ch, stride=2))
199 |
200 | ch = ch_list[blks - 1]
201 | self.base.append(
202 | nn.Sequential(
203 | nn.AdaptiveAvgPool2d((1, 1)),
204 | Flatten(),
205 | nn.Dropout(drop_rate),
206 | nn.Linear(ch, 1000)))
207 |
208 | # print(self.base)
209 |
210 | if pretrained:
211 | if hasattr(torch, 'hub'):
212 |
213 | if arch == 68 and not depth_wise:
214 | checkpoint = 'https://ping-chao.com/hardnet/hardnet68-5d684880.pth'
215 | elif arch == 85 and not depth_wise:
216 | checkpoint = 'https://ping-chao.com/hardnet/hardnet85-a28faa00.pth'
217 | elif arch == 68 and depth_wise:
218 | checkpoint = 'https://ping-chao.com/hardnet/hardnet68ds-632474d2.pth'
219 | else:
220 | checkpoint = 'https://ping-chao.com/hardnet/hardnet39ds-0e6c6fa9.pth'
221 |
222 | self.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
223 | else:
224 | postfix = 'ds' if depth_wise else ''
225 | weight_file = '%shardnet%d%s.pth' % (weight_path, arch, postfix)
226 | if not os.path.isfile(weight_file):
227 | print(weight_file, 'is not found')
228 | exit(0)
229 | weights = torch.load(weight_file)
230 | self.load_state_dict(weights)
231 |
232 | postfix = 'DS' if depth_wise else ''
233 |         print('ImageNet pretrained weights for HarDNet%d%s are loaded' % (arch, postfix))
234 |
235 | if int(args['outlayer']) == 1:
236 | if int(args['order']) == 39:
237 | self.features = 96
238 | self.base = self.base[0:5]
239 | if int(args['order']) == 68:
240 | self.features = 128
241 | self.base = self.base[0:5]
242 | if int(args['order']) == 85:
243 | self.features = 192
244 | self.base = self.base[0:5]
245 | elif int(args['outlayer']) == 2:
246 | if int(args['order']) == 39:
247 | self.features = 320
248 | self.base = self.base[0:8]
249 | if int(args['order']) == 68:
250 | self.features = 320
251 | self.base = self.base[0:10]
252 | if int(args['order']) == 85:
253 | self.features = 320
254 | self.base = self.base[0:10]
255 | elif int(args['outlayer']) == 3:
256 | if int(args['order']) == 39:
257 | self.features = 640
258 | self.base = self.base[0:11]
259 | if int(args['order']) == 68:
260 | self.features = 640
261 | self.base = self.base[0:13]
262 | if int(args['order']) == 85:
263 | self.features = 720
264 | self.base = self.base[0:15]
265 | elif int(args['outlayer']) == 4:
266 | if int(args['order']) == 39:
267 | self.features = 1024
268 | self.base = self.base[0:14]
269 | if int(args['order']) == 68:
270 | self.features = 1024
271 | self.base = self.base[0:16]
272 | if int(args['order']) == 85:
273 | self.features = 1280
274 | self.base = self.base[0:19]
275 | if int(args['order']) == 39:
276 | self.full_features = [48, 96, 320, 640, 1024]
277 | self.list = [1, 4, 7, 10, 13]
278 | if int(args['order']) == 68:
279 | self.full_features = [64, 128, 320, 640, 1024]
280 | self.list = [1, 4, 9, 12, 15]
281 | if int(args['order']) == 85:
282 | self.full_features = [96, 192, 320, 720, 1280]
283 | self.list = [1, 4, 9, 14, 18]
284 |
285 | def forward(self, x):
286 | for inx, layer in enumerate(self.base):
287 | x = layer(x)
288 | if inx == self.list[0]:
289 | x2 = x
290 | if inx == len(self.base) - 1:
291 | return x2
292 | elif inx == self.list[1]:
293 | x4 = x
294 | if inx == len(self.base) - 1:
295 | return x2, x4
296 | elif inx == self.list[2]:
297 | x8 = x
298 | if inx == len(self.base) - 1:
299 | return x2, x4, x8
300 | elif inx == self.list[3]:
301 | x16 = x
302 | if inx == len(self.base) - 1:
303 | return x2, x4, x8, x16
304 | elif inx == self.list[4]:
305 | x32 = x
306 | if inx == len(self.base) - 1:
307 | return x2, x4, x8, x16, x32
308 |
309 |
310 |
--------------------------------------------------------------------------------
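The wiring in `HarDBlock.get_link` is the harmonic pattern from the HarDNet paper: layer L receives input from layers L - 2^i for every power 2^i that divides L, and the layer's width grows by `grmul` per extra link. The snippet below replicates just the link computation from `get_link` to make the pattern visible:

```
def links_of(layer):
    # Mirrors HarDBlock.get_link: connect layer L back to L - 2**i whenever 2**i divides L.
    link = []
    for i in range(10):
        dv = 2 ** i
        if layer % dv == 0:
            link.append(layer - dv)
    return link

for layer in range(1, 9):
    print(layer, links_of(layer))
# 1 [0]   2 [1, 0]   3 [2]   4 [3, 2, 0]   5 [4]   6 [5, 4]   7 [6]   8 [7, 6, 4, 0]
```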
/models/model.py:
--------------------------------------------------------------------------------
1 | from models.vanilla import *
2 | import neural_renderer as nr
3 | from models.hardnet import *
4 | from utils.utils_train import norm_input
5 |
6 | class Decoder(nn.Module):
7 | def __init__(self, full_features, args):
8 | super(Decoder, self).__init__()
9 | if int(args['outlayer']) == 2:
10 | self.up1 = UpBlock(full_features[1] + full_features[0], 2,
11 | func='tanh', drop=float(args['drop'])).cuda()
12 | if int(args['outlayer']) == 3:
13 | self.up1 = UpBlock(full_features[2] + full_features[1], full_features[1],
14 | func='relu', drop=float(args['drop'])).cuda()
15 | self.up2 = UpBlock(full_features[1] + full_features[0], 2,
16 | func='tanh', drop=float(args['drop'])).cuda()
17 | if int(args['outlayer']) == 4:
18 | self.up1 = UpBlock(full_features[3] + full_features[2], full_features[2],
19 | func='relu', drop=float(args['drop'])).cuda()
20 | self.up2 = UpBlock(full_features[2] + full_features[1], full_features[1],
21 | func='relu', drop=float(args['drop'])).cuda()
22 | self.up3 = UpBlock(full_features[1] + full_features[0], 2,
23 | func='tanh', drop=float(args['drop'])).cuda()
24 | self.args = args
25 |
26 | def forward(self, x, size):
27 | if int(self.args['outlayer']) == 2:
28 | shift_map = self.up1(x[1], x[0])
29 | if int(self.args['outlayer']) == 3:
30 | z = self.up1(x[2], x[1])
31 | shift_map = self.up2(z, x[0])
32 | if int(self.args['outlayer']) == 4:
33 | z = self.up1(x[3], x[2])
34 | z = self.up2(z, x[1])
35 | shift_map = self.up3(z, x[0])
36 | shift_map = F.interpolate(shift_map, size=size, mode='bilinear', align_corners=True)
37 | return shift_map[:, 0, :, :].unsqueeze(dim=1), shift_map[:, 1, :, :].unsqueeze(dim=1)
38 |
39 |
40 | class DeepACM(nn.Module):
41 | def __init__(self, args):
42 | super(DeepACM, self).__init__()
43 | self.backbone = HarDNet(depth_wise=bool(int(args['depth_wise'])), arch=int(args['order']), args=args)
44 | self.ACMDecoder = Decoder(self.backbone.full_features, args)
45 | self.nP = int(args['nP'])
46 | self.texture_size = 2
47 | self.camera_distance = 1
48 | self.elevation = 0
49 | self.azimuth = 0
50 | self.image_size = int(args['im_size'])
51 | self.renderer = nr.Renderer(camera_mode='look_at', image_size=self.image_size, light_intensity_ambient=1,
52 | light_intensity_directional=1, perspective=False)
53 |
54 | def forward(self, I, P, faces, it):
55 | size = I.size()[2:]
56 | z = self.backbone(I)
57 | Ix, Iy = self.ACMDecoder(z, size)
58 | masks = []
59 | Ps = []
60 | for i in range(it):
61 | Pxx = F.grid_sample(Ix, P).transpose(3, 2)
62 | Pyy = F.grid_sample(Iy, P).transpose(3, 2)
63 | Pedge = torch.cat((Pxx, Pyy), -1)
64 | P = Pedge + P
65 | z = torch.ones((P.shape[0], 1, P.shape[2], 1)).cuda()
66 | PP = torch.cat((P, z), 3)
67 | PP = torch.squeeze(PP, dim=1)
68 | PP[:, :, 1] = PP[:, :, 1]*-1
69 | faces = torch.squeeze(faces, dim=1)
70 | self.renderer.eye = nr.get_points_from_angles(self.camera_distance, self.elevation, self.azimuth)
71 | mask = self.renderer(PP, faces, mode='silhouettes').unsqueeze(dim=1)
72 | PP[:, :, 1] = PP[:, :, 1]*-1
73 | Ps.append(PP[:, :, 0:2].unsqueeze(dim=1))
74 | masks.append(mask)
75 | return masks, Ps, Ix, Iy, I
76 |
--------------------------------------------------------------------------------
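The contour update inside `DeepACM.forward` treats `Ix`/`Iy` as dense shift fields and the contour `P` as `grid_sample` coordinates in [-1, 1]: each iteration reads the shift under every vertex and adds it to that vertex, and the y-sign flip afterwards adapts the points to the neural renderer's coordinate convention. A standalone sketch of one snake step, with random tensors standing in for real network outputs:

```
import torch
import torch.nn.functional as F

B, nP = 2, 16
Ix = torch.randn(B, 1, 256, 256)            # per-pixel x-shift field
Iy = torch.randn(B, 1, 256, 256)            # per-pixel y-shift field
P = torch.rand(B, 1, nP, 2) * 2 - 1         # contour vertices in [-1, 1] coordinates

Pxx = F.grid_sample(Ix, P).transpose(3, 2)  # x-shift under each vertex: (B, 1, nP, 1)
Pyy = F.grid_sample(Iy, P).transpose(3, 2)  # y-shift under each vertex
P = torch.cat((Pxx, Pyy), -1) + P           # move every vertex by its sampled shift
print(P.shape)                              # torch.Size([2, 1, 16, 2])
```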
/models/model_seg.py:
--------------------------------------------------------------------------------
1 | from models.vanilla import *
2 | import neural_renderer as nr
3 | from models.hardnet import *
4 |
5 |
6 | class final_layer(nn.Module):
7 | def __init__(self):
8 | super(final_layer, self).__init__()
9 |
10 | def forward(self, z, size):
11 | return F.interpolate(z, size=size, mode='bilinear', align_corners=True)
12 |
13 |
14 | class Decoder(nn.Module):
15 | def __init__(self, full_features, args):
16 | super(Decoder, self).__init__()
17 | if int(args['outlayer']) == 2:
18 | self.up1 = UpBlock(full_features[1] + full_features[0], 1,
19 | func='sigmoid', drop=float(args['drop'])).cuda()
20 | if int(args['outlayer']) == 3:
21 | self.up1 = UpBlock(full_features[2] + full_features[1], full_features[1],
22 | func='relu', drop=float(args['drop'])).cuda()
23 | self.up2 = UpBlock(full_features[1] + full_features[0], 1,
24 | func='sigmoid', drop=float(args['drop'])).cuda()
25 | if int(args['outlayer']) == 4:
26 | self.up1 = UpBlock(full_features[3] + full_features[2], full_features[2],
27 | func='relu', drop=float(args['drop'])).cuda()
28 | self.up2 = UpBlock(full_features[2] + full_features[1], full_features[1],
29 | func='relu', drop=float(args['drop'])).cuda()
30 | self.up3 = UpBlock(full_features[1] + full_features[0], 1,
31 | func='sigmoid', drop=float(args['drop'])).cuda()
32 | self.args = args
33 | self.final = final_layer()
34 |
35 | def forward(self, x, size):
36 | if int(self.args['outlayer']) == 2:
37 | z = self.up1(x[1], x[0])
38 | if int(self.args['outlayer']) == 3:
39 | z = self.up1(x[2], x[1])
40 | z = self.up2(z, x[0])
41 | if int(self.args['outlayer']) == 4:
42 | z = self.up1(x[3], x[2])
43 | z = self.up2(z, x[1])
44 | z = self.up3(z, x[0])
45 | return self.final(z, size)
46 |
47 |
48 | class Segmentation(nn.Module):
49 | def __init__(self, args):
50 | super(Segmentation, self).__init__()
51 | self.backbone = HarDNet(depth_wise=bool(int(args['depth_wise'])), arch=int(args['order']), args=args)
52 | self.decoder = Decoder(self.backbone.full_features, args)
53 |
54 | def forward(self, I):
55 | size = I.size()[2:]
56 | z = self.backbone(I)
57 | return self.decoder(z, size)
58 |
59 |
--------------------------------------------------------------------------------
/models/vanilla.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 |     def __init__(self, in_channels, out_channels, kernel_size=3, drop=0):
8 | super(ResidualBlock, self).__init__()
9 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1)
10 | self.conv1_drop = nn.Dropout2d(drop)
11 | self.BN1 = nn.BatchNorm2d(out_channels)
12 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=1)
13 | self.conv2_drop = nn.Dropout2d(drop)
14 | self.BN2 = nn.BatchNorm2d(out_channels)
15 |
16 | def forward(self, x_in):
17 | x = self.conv1_drop(self.conv1(x_in))
18 | x = F.relu(self.BN1(x))
19 | x = self.conv2_drop(self.conv2(x))
20 | x = F.relu(self.BN2(x))
21 | return x
22 |
23 |
24 | class DownBlock(nn.Module):
25 | def __init__(self, in_channels, out_channels, kernel_size=3, drop=0):
26 | super(DownBlock, self).__init__()
27 |         P = int((kernel_size - 1) / 2)
28 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=P)
29 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=P)
30 | self.pool = nn.MaxPool2d((2, 2))
31 | self.conv1_drop = nn.Dropout2d(drop)
32 | self.conv2_drop = nn.Dropout2d(drop)
33 | self.BN = nn.BatchNorm2d(out_channels)
34 |
35 | def forward(self, x_in):
36 | x1 = self.conv2_drop(self.conv2(self.conv1_drop(self.conv1(x_in))))
37 | x1_pool = F.relu(self.BN(self.pool(x1)))
38 | return x1, x1_pool
39 |
40 |
41 | class UpBlock(nn.Module):
42 | def __init__(self, in_channels, out_channels, kernel_size=3, func=None, drop=0):
43 | super(UpBlock, self).__init__()
44 |         P = int((kernel_size - 1) / 2)
45 | self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')
46 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=P)
47 | self.conv1_drop = nn.Dropout2d(drop)
48 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=P)
49 | self.conv2_drop = nn.Dropout2d(drop)
50 | self.BN = nn.BatchNorm2d(out_channels)
51 | self.func = func
52 |
53 | def forward(self, x_in, x_up):
54 | x = self.Upsample(x_in)
55 | x_cat = torch.cat((x, x_up), 1)
56 | x1 = self.conv2_drop(self.conv2(self.conv1_drop(self.conv1(x_cat))))
57 |         if self.func == 'tanh':
58 |             return torch.tanh(self.BN(x1))
59 |         elif self.func == 'relu':
60 |             return F.relu(self.BN(x1))
61 |         elif self.func == 'sigmoid':
62 |             return torch.sigmoid(self.BN(x1))
63 |
64 |
65 | class Encoder(nn.Module):
66 |     def __init__(self, AEdim, drop=0):
67 |         super(Encoder, self).__init__()
68 |         a, b, c = int(AEdim / 4), int(AEdim / 2), AEdim
69 | self.down1 = DownBlock(3, a, drop=drop)
70 | self.down2 = DownBlock(a, b, drop=drop)
71 | self.down3 = DownBlock(b, c, drop=drop)
72 |
73 | def forward(self, x_in):
74 | x1, x1_pool = self.down1(x_in)
75 | x2, x2_pool = self.down2(x1_pool)
76 | x3, x3_pool = self.down3(x2_pool)
77 | return x1, x2, x3, x3_pool
78 |
--------------------------------------------------------------------------------
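`UpBlock` is the decoder building block used by both models/model.py and models/model_seg.py: it bilinearly upsamples the deep feature map, concatenates the matching encoder skip along channels, and applies two convolutions plus the chosen output nonlinearity. A shape sketch (the input channels mirror the HarDNet85 `full_features[1] + full_features[0]` combination used above; the output width here is arbitrary):

```
import torch
from models.vanilla import UpBlock

up = UpBlock(in_channels=192 + 96, out_channels=96, func='relu')
deep = torch.randn(1, 192, 32, 32)  # low-resolution features
skip = torch.randn(1, 96, 64, 64)   # matching encoder skip connection
out = up(deep, skip)
print(out.shape)                    # torch.Size([1, 96, 64, 64])
```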
/out/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/out/__init__.py
--------------------------------------------------------------------------------
/pics/16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/pics/16.jpg
--------------------------------------------------------------------------------
/pics/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/pics/8.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | absl-py=0.9.0=pypi_0
6 | attrs=19.3.0=py_0
7 | backcall=0.2.0=pyh9f0ad1d_0
8 | blas=1.0=mkl
9 | bleach=3.1.5=pyh9f0ad1d_0
10 | blinker=1.4=py38_0
11 | brotlipy=0.7.0=py38h1e0a361_1000
12 | c-ares=1.16.1=h7b6447c_0
13 | ca-certificates=2020.7.22=0
14 | cachetools=4.1.0=pypi_0
15 | certifi=2020.6.20=py38_0
16 | cffi=1.14.0=py38he30daa8_1
17 | chardet=3.0.4=py38h32f6830_1006
18 | click=7.1.2=py_0
19 | cryptography=2.9.2=py38h766eaa4_0
20 | cudatoolkit=10.0.130=0
21 | cycler=0.10.0=pypi_0
22 | cython=0.29.20=pypi_0
23 | decorator=4.4.2=py_0
24 | defusedxml=0.6.0=py_0
25 | entrypoints=0.3=py38h32f6830_1001
26 | freetype=2.9.1=h8a8886c_1
27 | google-auth=1.17.0=pypi_0
28 | google-auth-oauthlib=0.4.1=py_2
29 | grpcio=1.29.0=pypi_0
30 | h5py=2.10.0=pypi_0
31 | idna=2.9=pypi_0
32 | imageio=2.8.0=pypi_0
33 | importlib-metadata=1.7.0=py38h32f6830_0
34 | importlib_metadata=1.7.0=0
35 | intel-openmp=2020.1=217
36 | ipdb=0.13.2=pypi_0
37 | ipykernel=5.3.2=py38h23f93f0_0
38 | ipython=7.15.0=pypi_0
39 | ipython-genutils=0.2.0=pypi_0
40 | ipython_genutils=0.2.0=py_1
41 | jedi=0.17.0=pypi_0
42 | jinja2=2.11.2=pyh9f0ad1d_0
43 | joblib=0.15.1=pypi_0
44 | jpeg=9b=h024ee3a_2
45 | json5=0.9.5=pypi_0
46 | jsonpatch=1.25=pypi_0
47 | jsonpointer=2.0=pypi_0
48 | jsonschema=3.2.0=py38h32f6830_1
49 | jupyter_client=6.1.5=py_0
50 | jupyter_core=4.6.3=py38h32f6830_1
51 | jupyterlab=2.1.5=py_0
52 | jupyterlab_server=1.2.0=py_0
53 | kiwisolver=1.2.0=pypi_0
54 | ld_impl_linux-64=2.33.1=h53a641e_7
55 | libedit=3.1.20181209=hc058e9b_0
56 | libffi=3.3=he6710b0_1
57 | libgcc-ng=9.1.0=hdf63c60_0
58 | libgfortran-ng=7.3.0=hdf63c60_0
59 | libpng=1.6.37=hbc83047_0
60 | libprotobuf=3.12.4=hd408876_0
61 | libsodium=1.0.17=h516909a_0
62 | libstdcxx-ng=9.1.0=hdf63c60_0
63 | libtiff=4.1.0=h2733197_1
64 | lz4-c=1.9.2=he6710b0_0
65 | markdown=3.2.2=py38_0
66 | markupsafe=1.1.1=py38h1e0a361_1
67 | matplotlib=3.2.1=pypi_0
68 | mistune=0.8.4=py38h1e0a361_1001
69 | mkl=2020.1=217
70 | mkl-service=2.3.0=py38he904b0f_0
71 | mkl_fft=1.0.15=py38ha843d7b_0
72 | mkl_random=1.1.1=py38h0573a6f_0
73 | nbconvert=5.6.1=py38h32f6830_1
74 | nbformat=5.0.7=py_0
75 | ncurses=6.2=he6710b0_1
76 | networkx=2.4=pypi_0
77 | neural-renderer-pytorch=1.1.3=pypi_0
78 | ninja=1.9.0=py38hfd86e86_0
79 | notebook=6.0.3=py38h32f6830_1
80 | numpy=1.18.1=py38h4f9e942_0
81 | numpy-base=1.18.1=py38hde5b4d6_1
82 | oauthlib=3.1.0=py_0
83 | olefile=0.46=py_0
84 | opencv-python=4.2.0.34=pypi_0
85 | openssl=1.1.1h=h7b6447c_0
86 | packaging=20.4=pyh9f0ad1d_0
87 | pandoc=2.10=0
88 | pandocfilters=1.4.2=pypi_0
89 | parso=0.7.0=pyh9f0ad1d_0
90 | pexpect=4.8.0=py38h32f6830_1
91 | pickleshare=0.7.5=py38h32f6830_1001
92 | pillow=7.1.2=py38hb39fc2d_0
93 | pip=20.0.2=py38_3
94 | progressbar=2.5=pypi_0
95 | prometheus_client=0.8.0=pyh9f0ad1d_0
96 | prompt-toolkit=3.0.5=py_1
97 | protobuf=3.12.2=pypi_0
98 | ptyprocess=0.6.0=py_1001
99 | pyasn1=0.4.8=py_0
100 | pyasn1-modules=0.2.8=pypi_0
101 | pycocotools=2.0.0=pypi_0
102 | pycparser=2.20=pyh9f0ad1d_2
103 | pygments=2.6.1=py_0
104 | pyjwt=1.7.1=py38_0
105 | pyopengl=3.1.5=pypi_0
106 | pyopenssl=19.1.0=py_1
107 | pyparsing=2.4.7=pyh9f0ad1d_0
108 | pyrsistent=0.16.0=py38h1e0a361_0
109 | pysocks=1.7.1=py38h32f6830_1
110 | python=3.8.3=hcff3b4d_0
111 | python-dateutil=2.8.1=py_0
112 | python_abi=3.8=1_cp38
113 | pytorch=1.4.0=py3.8_cuda10.0.130_cudnn7.6.3_0
114 | pywavelets=1.1.1=pypi_0
115 | pyzmq=19.0.1=pypi_0
116 | readline=8.0=h7b6447c_0
117 | requests=2.23.0=pypi_0
118 | requests-oauthlib=1.3.0=py_0
119 | rsa=4.1=pypi_0
120 | scikit-image=0.17.2=pypi_0
121 | scikit-learn=0.23.1=pypi_0
122 | scipy=1.4.1=pypi_0
123 | send2trash=1.5.0=py_0
124 | setuptools=47.1.1=py38_0
125 | six=1.15.0=py_0
126 | sqlite=3.31.1=h62c20be_1
127 | tensorboard=2.2.2=pypi_0
128 | tensorboard-plugin-wit=1.6.0.post3=pypi_0
129 | tensorboardx=2.1=py_0
130 | terminado=0.8.3=pypi_0
131 | testpath=0.4.4=py_0
132 | threadpoolctl=2.1.0=pypi_0
133 | tifffile=2020.6.3=pypi_0
134 | tk=8.6.8=hbc83047_0
135 | torchfile=0.1.0=pypi_0
136 | torchvision=0.5.0=py38_cu100
137 | tornado=6.0.4=py38h1e0a361_1
138 | tqdm=4.46.1=pypi_0
139 | traitlets=4.3.3=py38h32f6830_1
140 | urllib3=1.25.9=py_0
141 | visdom=0.1.8.9=pypi_0
142 | wcwidth=0.2.4=pypi_0
143 | webencodings=0.5.1=pypi_0
144 | websocket-client=0.57.0=pypi_0
145 | werkzeug=1.0.1=py_0
146 | wheel=0.34.2=py38_0
147 | xlrd=1.2.0=py_0
148 | xz=5.2.5=h7b6447c_0
149 | zeromq=4.3.2=he1b5a44_2
150 | zipp=3.1.0=py_0
151 | zlib=1.2.11=h7b6447c_3
152 | zstd=1.4.4=h0b5b093_3
153 |
--------------------------------------------------------------------------------
/results/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/results/__init__.py
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from tensorboardX import SummaryWriter
3 |
4 | from models.model import *
5 | from models.model_seg import Segmentation
6 | from loader.viah_loader import *
7 | from loader.bing_loader import *
8 | from utils.utils_args import *
9 | from utils.utils_eval import *
10 | from utils.utils_train import *
11 | from utils.utils_tri import *
12 | from utils.utils_vis import *
13 | from utils.utils_lr import *
14 | from utils.loss import *
15 | from utils.snake_loss import Snakeloss
16 |
17 |
18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19 | torch.backends.cudnn.benchmark = True
20 | args = get_args()
21 | save_args(args)
22 | writer = SummaryWriter()
23 |
24 | PATH = r'results/' + args['task']
25 | segnet = Segmentation(args)
26 | segnet1 = torch.load(PATH + '/best/SEG.pt')
27 | segnet.load_state_dict(segnet1.state_dict())
28 | segnet.eval().to(device)
29 | model = DeepACM(args)
30 | model.train().to(device)
31 |
32 | P_test, faces_test = get_poly(int(args['im_size']), int(args['nP']),
33 | int(args['Radius']), int(args['im_size']) / 2, int(args['im_size']) / 2)
34 | faces_test = faces_test.unsqueeze(dim=0).unsqueeze(dim=0).cuda()
35 | faces = faces_test.repeat(int(args['Batch_size']), 1, 1, 1).cuda()
36 | PTrain = P_test.repeat(int(args['Batch_size']), 1, 1, 1).cuda()
37 | criterion = SoftDiceLoss()
38 | snake_loss = Snakeloss(criterion)
39 | optimizer = torch.optim.Adam(model.parameters(), lr=float(args['learning_rate']), weight_decay=float(args['WD']))
40 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args['D_rate']), gamma=0.3)
41 | if args['task'] == 'viah':
42 | trainset = viah_segmentation(ann='training', args=args)
43 | testset = viah_segmentation(ann='test', args=args)
44 | elif args['task'] == 'bing':
45 | trainset = bing_segmentation(ann='training', args=args)
46 | testset = bing_segmentation(ann='test', args=args)
47 |
48 | ds = torch.utils.data.DataLoader(trainset, batch_size=int(args['Batch_size']), shuffle=True,
49 | num_workers=int(args['nW']), drop_last=True)
50 | ds_tri = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False,
51 | num_workers=1, drop_last=False)
52 | ds_val = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False,
53 | num_workers=1, drop_last=False)
54 | best = 0
55 | max_iter = int(args['DeepIt'])
56 | for epoch in range(1, int(args['epochs'])):
57 | args['DeepIt'] = int(epoch/20) + 1
58 | if args['DeepIt'] > max_iter:
59 | args['DeepIt'] = max_iter
60 | loss_list = train(ds, model, segnet, optimizer, snake_loss, PTrain, faces, args)
61 | logger_train(epoch, writer, loss_list)
62 | scheduler.step()
63 | if epoch % 5 == 2:
64 | iou = eval_ds(ds_val, model, segnet, P_test, faces_test, args)
65 | writer.add_scalar('IoU', iou, global_step=epoch)
66 | if iou > best:
67 | best = iou
68 | print('best: ' + str(epoch) + ' with ' + str(best))
69 | torch.save(model, PATH + '/' + 'ACM_best.pt')
70 |
--------------------------------------------------------------------------------
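A note on the schedule above: `train.py` grows the number of snake iterations with the epoch (`int(epoch / 20) + 1`, capped at `-it`), so early epochs supervise a single contour step before longer unrolls kick in. A minimal sketch of that curriculum, with epoch values chosen only for illustration:

```
# DeepIt curriculum from train.py: one snake iteration for epochs 1-19,
# two for 20-39, and so on, capped at the -it value.
max_iter = 3  # e.g. -it 3
for epoch in [1, 19, 20, 45, 120]:
    deep_it = min(int(epoch / 20) + 1, max_iter)
    print(epoch, deep_it)  # 1->1, 19->1, 20->2, 45->3, 120->3
```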
/train_seg.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from tensorboardX import SummaryWriter
3 | from tqdm import tqdm
4 |
5 | from models.model_seg import *
6 | from loader.viah_loader import *
7 | from loader.bing_loader import *
8 | from utils.utils_args import *
9 | from utils.loss import *
10 | from utils.utils_eval import get_dice_ji
11 | from utils.utils_lr import *
12 |
13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14 | torch.backends.cudnn.benchmark = True
15 |
16 |
17 | def train(ds, model, optimizer, criterion, criterion2, scheduler, args):
18 | loss_list = []
19 | for ix, (_x, _y) in tqdm(enumerate(ds)):
20 | _x = _x.float().cuda()
21 | _y = _y.float().cuda().unsqueeze(dim=1)
22 | optimizer.zero_grad()
23 | mask = model(_x)
24 | loss = 0.1*criterion(mask, _y) + 1*criterion2(mask, _y)
25 | loss_list.append(loss.item())
26 | loss.backward()
27 | optimizer.step()
28 | if args['opt'] == 'sgd':
29 | scheduler.step()
30 | return loss_list
31 |
32 |
33 | def eval_ds(ds, model, writer, epoch, PATH1, best, label, args):
34 | model.eval()
35 | TestDice_list = []
36 | TestIoU_list = []
37 | for ix, (_x, _y) in enumerate(ds):
38 | _x = _x.float().cuda()
39 | _y = _y.float().cuda()
40 | Mask = model(_x)
41 | Mask[Mask >= 0.5] = 1
42 | Mask[Mask < 0.5] = 0
43 | (cDice, cIoU) = get_dice_ji(Mask, _y)
44 | TestDice_list.append(cDice)
45 | TestIoU_list.append(cIoU)
46 | Dice = np.mean(TestDice_list)
47 | IoU = np.mean(TestIoU_list)
48 | print((epoch, Dice, IoU))
49 |     if IoU > best and label == 'test':
50 |         best = IoU
51 |         torch.save(model, PATH1 + '/SEG_best.pt')
52 |         print('best IOU results: ' + str(IoU))
53 |     writer.add_scalar('Dice_' + label, Dice, global_step=epoch)
54 |     writer.add_scalar('IoU_' + label, IoU, global_step=epoch)
55 |     model.train()
56 |     return best, IoU
57 |
58 | def main(args, writer):
59 | PATH = r'results/' + args['task']
60 | model = Segmentation(args)
61 | model.train().to(device)
62 |
63 | criterion = nn.BCELoss()
64 | criterion2 = SoftDiceLoss()
65 | if args['opt'] == 'adam':
66 | optimizer = torch.optim.Adam(model.parameters(), lr=float(args['learning_rate']),
67 | weight_decay=float(args['WD']))
68 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args['D_rate']), gamma=0.3)
69 | elif args['opt'] == 'sgd':
70 | wd_params, non_wd_params = [], []
71 | for name, param in model.named_parameters():
72 | if param.dim() == 1:
73 | non_wd_params.append(param)
74 | elif param.dim() == 2 or param.dim() == 4:
75 | wd_params.append(param)
76 | params_list = [
77 | {'params': wd_params, },
78 | {'params': non_wd_params, 'weight_decay': 0},
79 | ]
80 | warmup_iters = 12
81 | optimizer = torch.optim.SGD(params_list,
82 | lr=float(args['learning_rate']),
83 | weight_decay=float(args['WD']),
84 | momentum=0.9)
85 | max_iter = int(args['D_rate'])
86 | scheduler = WarmupPolyLrScheduler(optimizer,
87 | power=0.9,
88 | max_iter=max_iter,
89 | warmup_iter=warmup_iters,
90 | warmup_ratio=0.001,
91 | warmup='exp',
92 | last_epoch=-1)
93 |
94 | if args['task'] == 'viah':
95 | trainset = viah_segmentation(ann='training', args=args)
96 | testset = viah_segmentation(ann='test', args=args)
97 | elif args['task'] == 'bing':
98 | trainset = bing_segmentation(ann='training', args=args)
99 | testset = bing_segmentation(ann='test', args=args)
100 |
101 | ds = torch.utils.data.DataLoader(trainset, batch_size=int(args['Batch_size']), shuffle=True,
102 | num_workers=int(args['nW']), drop_last=True)
103 | ds_val = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False,
104 | num_workers=0, drop_last=False)
105 | best = 0
106 | for epoch in range(1, int(args['epochs'])):
107 | loss_list = train(ds, model, optimizer, criterion, criterion2, scheduler, args)
108 | print('************************************************************************')
109 | print('Epoch: ' + str(epoch) + ' Mask mean loss: ' + str(np.mean(loss_list)) + ' Mask max loss: ' + str(
110 | np.max(loss_list)) + ' Mask min loss: ' + str(np.min(loss_list)))
111 | writer.add_scalar('MaskLoss', np.mean(loss_list), global_step=epoch)
112 | print('************************************************************************')
113 | if args['opt'] == 'adam':
114 | scheduler.step()
115 |
116 | if epoch % 3 == 1:
117 | best, _ = eval_ds(ds_val, model, writer, epoch, PATH, best, 'test', args)
118 |
119 | if __name__ == '__main__':
120 | args = get_args()
121 | save_args(args)
122 | writer = SummaryWriter()
123 | main(args, writer)
124 |
--------------------------------------------------------------------------------
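In the SGD branch above, `train_seg.py` keeps weight decay off the 1-D parameters (biases and BatchNorm affine terms) and applies it only to conv/linear weights. A self-contained sketch of the same split; the two-layer model here is purely illustrative:

```
import torch
import torch.nn as nn

# 1-D parameters (biases, BatchNorm scale/shift) get weight_decay=0;
# 2-D/4-D weights keep the global weight decay.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
wd_params = [p for p in model.parameters() if p.dim() in (2, 4)]
non_wd_params = [p for p in model.parameters() if p.dim() == 1]
optimizer = torch.optim.SGD(
    [{'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}],
    lr=0.02, momentum=0.9, weight_decay=5e-5)
```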
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/utils/__init__.py
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | get_tp_fp_fn, SoftDiceLoss, and DC_and_CE/TopK_loss are from https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/loss_functions
3 | """
4 |
5 | import torch
6 | # from ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss
7 | from torch import nn
8 | from torch.autograd import Variable
9 | from torch import einsum
10 | import numpy as np
11 |
12 |
13 | def softmax_helper(x):
14 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
15 | rpt = [1 for _ in range(len(x.size()))]
16 | rpt[1] = x.size(1)
17 | x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
18 | e_x = torch.exp(x - x_max)
19 | return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
20 |
21 |
22 | def sum_tensor(inp, axes, keepdim=False):
23 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/tensor_utilities.py
24 | axes = np.unique(axes).astype(int)
25 | if keepdim:
26 | for ax in axes:
27 | inp = inp.sum(int(ax), keepdim=True)
28 | else:
29 | for ax in sorted(axes, reverse=True):
30 | inp = inp.sum(int(ax))
31 | return inp
32 |
33 |
34 | def ignore_null(gt, shp_x, nl_cls=250):
35 | shp_x = list(shp_x)
36 | nC = shp_x[1]
37 | gt[gt == nl_cls] = nC + 1
38 | shp_x[1] = nC + 1
39 | return gt, shp_x
40 |
41 |
42 | def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
43 | """
44 | net_output must be (b, c, x, y(, z)))
45 | gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
46 | if mask is provided it must have shape (b, 1, x, y(, z)))
47 | :param net_output:
48 | :param gt:
49 | :param axes:
50 | :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
51 | :param square: if True then fp, tp and fn will be squared before summation
52 | :return:
53 | """
54 | if axes is None:
55 | axes = tuple(range(2, len(net_output.size())))
56 |
57 | shp_x = net_output.shape
58 | shp_y = gt.shape
59 |
60 | with torch.no_grad():
61 | if len(shp_x) != len(shp_y):
62 | gt = gt.view((shp_y[0], 1, *shp_y[1:]))
63 |
64 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
65 | # if this is the case then gt is probably already a one hot encoding
66 | y_onehot = gt
67 | else:
68 | gt = gt.long()
69 | gt, shp_x = ignore_null(gt, shp_x)
70 | y_onehot = torch.zeros(shp_x)
71 | if net_output.device.type == "cuda":
72 | y_onehot = y_onehot.cuda(net_output.device.index)
73 | y_onehot.scatter_(1, gt, 1)
74 | y_onehot = y_onehot[:, :-1, :, :]
75 | # input(y_onehot.shape)
76 | tp = net_output * y_onehot
77 | fp = net_output * (1 - y_onehot)
78 | fn = (1 - net_output) * y_onehot
79 |
80 | if mask is not None:
81 | tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
82 | fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
83 | fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
84 |
85 | if square:
86 | tp = tp ** 2
87 | fp = fp ** 2
88 | fn = fn ** 2
89 |
90 | tp = sum_tensor(tp, axes, keepdim=False)
91 | fp = sum_tensor(fp, axes, keepdim=False)
92 | fn = sum_tensor(fn, axes, keepdim=False)
93 |
94 | return tp, fp, fn
95 |
96 |
97 | class GDiceLoss(nn.Module):
98 | def __init__(self, apply_nonlin=None, smooth=1e-5):
99 | """
100 | Generalized Dice;
101 | Copy from: https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L29
102 | paper: https://arxiv.org/pdf/1707.03237.pdf
103 | tf code: https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py#L279
104 | """
105 | super(GDiceLoss, self).__init__()
106 |
107 | self.apply_nonlin = apply_nonlin
108 | self.smooth = smooth
109 |
110 | def forward(self, net_output, gt):
111 | shp_x = net_output.shape # (batch size,class_num,x,y,z)
112 | shp_y = gt.shape # (batch size,1,x,y,z)
113 | # one hot code for gt
114 | with torch.no_grad():
115 | if len(shp_x) != len(shp_y):
116 | gt = gt.view((shp_y[0], 1, *shp_y[1:]))
117 |
118 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
119 | # if this is the case then gt is probably already a one hot encoding
120 | y_onehot = gt
121 | else:
122 | gt = gt.long()
123 | y_onehot = torch.zeros(shp_x)
124 | if net_output.device.type == "cuda":
125 | y_onehot = y_onehot.cuda(net_output.device.index)
126 | y_onehot.scatter_(1, gt, 1)
127 |
128 | if self.apply_nonlin is not None:
129 | softmax_output = self.apply_nonlin(net_output)
130 |
131 | # copy from https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L29
132 | w: torch.Tensor = 1 / (einsum("bcxyz->bc", y_onehot).type(torch.float32) + 1e-10) ** 2
133 | intersection: torch.Tensor = w * einsum("bcxyz, bcxyz->bc", softmax_output, y_onehot)
134 | union: torch.Tensor = w * (einsum("bcxyz->bc", softmax_output) + einsum("bcxyz->bc", y_onehot))
135 | divided: torch.Tensor = 1 - 2 * (einsum("bc->b", intersection) + self.smooth) / (
136 | einsum("bc->b", union) + self.smooth)
137 | gdc = divided.mean()
138 |
139 | return gdc
140 |
141 |
142 | def flatten(tensor):
143 | """Flattens a given tensor such that the channel axis is first.
144 | The shapes are transformed as follows:
145 | (N, C, D, H, W) -> (C, N * D * H * W)
146 | """
147 | C = tensor.size(1)
148 | # new axis order
149 | axis_order = (1, 0) + tuple(range(2, tensor.dim()))
150 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
151 | transposed = tensor.permute(axis_order).contiguous()
152 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
153 | return transposed.view(C, -1)
154 |
155 |
156 | class GDiceLossV2(nn.Module):
157 | def __init__(self, apply_nonlin=None, smooth=1e-5):
158 | """
159 | Generalized Dice;
160 | Copy from: https://github.com/wolny/pytorch-3dunet/blob/6e5a24b6438f8c631289c10638a17dea14d42051/unet3d/losses.py#L75
161 | paper: https://arxiv.org/pdf/1707.03237.pdf
162 | tf code: https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py#L279
163 | """
164 | super(GDiceLossV2, self).__init__()
165 |
166 | self.apply_nonlin = apply_nonlin
167 | self.smooth = smooth
168 |
169 | def forward(self, net_output, gt):
170 | shp_x = net_output.shape # (batch size,class_num,x,y,z)
171 | shp_y = gt.shape # (batch size,1,x,y,z)
172 | # one hot code for gt
173 | with torch.no_grad():
174 | if len(shp_x) != len(shp_y):
175 | gt = gt.view((shp_y[0], 1, *shp_y[1:]))
176 |
177 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
178 | # if this is the case then gt is probably already a one hot encoding
179 | y_onehot = gt
180 | else:
181 | gt = gt.long()
182 | y_onehot = torch.zeros(shp_x)
183 | if net_output.device.type == "cuda":
184 | y_onehot = y_onehot.cuda(net_output.device.index)
185 | y_onehot.scatter_(1, gt, 1)
186 |
187 | if self.apply_nonlin is not None:
188 | softmax_output = self.apply_nonlin(net_output)
189 |
190 | input = flatten(softmax_output)
191 | target = flatten(y_onehot)
192 | target = target.float()
193 | target_sum = target.sum(-1)
194 | class_weights = Variable(1. / (target_sum * target_sum).clamp(min=self.smooth), requires_grad=False)
195 |
196 | intersect = (input * target).sum(-1) * class_weights
197 | intersect = intersect.sum()
198 |
199 | denominator = ((input + target).sum(-1) * class_weights).sum()
200 |
201 | return 1. - 2. * intersect / denominator.clamp(min=self.smooth)
202 |
203 |
204 | class SSLoss(nn.Module):
205 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
206 | square=False):
207 | """
208 | Sensitivity-Specifity loss
209 | paper: http://www.rogertam.ca/Brosch_MICCAI_2015.pdf
210 | tf code: https://github.com/NifTK/NiftyNet/blob/df0f86733357fdc92bbc191c8fec0dcf49aa5499/niftynet/layer/loss_segmentation.py#L392
211 | """
212 | super(SSLoss, self).__init__()
213 |
214 | self.square = square
215 | self.do_bg = do_bg
216 | self.batch_dice = batch_dice
217 | self.apply_nonlin = apply_nonlin
218 | self.smooth = smooth
219 | self.r = 0.1 # weight parameter in SS paper
220 |
221 | def forward(self, net_output, gt, loss_mask=None):
222 | shp_x = net_output.shape
223 | shp_y = gt.shape
224 | # class_num = shp_x[1]
225 |
226 | with torch.no_grad():
227 | if len(shp_x) != len(shp_y):
228 | gt = gt.view((shp_y[0], 1, *shp_y[1:]))
229 |
230 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
231 | # if this is the case then gt is probably already a one hot encoding
232 | y_onehot = gt
233 | else:
234 | gt = gt.long()
235 | y_onehot = torch.zeros(shp_x)
236 | if net_output.device.type == "cuda":
237 | y_onehot = y_onehot.cuda(net_output.device.index)
238 | y_onehot.scatter_(1, gt, 1)
239 |
240 | if self.batch_dice:
241 | axes = [0] + list(range(2, len(shp_x)))
242 | else:
243 | axes = list(range(2, len(shp_x)))
244 |
245 | if self.apply_nonlin is not None:
246 | softmax_output = self.apply_nonlin(net_output)
247 |
248 | # no object value
249 | bg_onehot = 1 - y_onehot
250 | squared_error = (y_onehot - softmax_output) ** 2
251 | specificity_part = sum_tensor(squared_error * y_onehot, axes) / (sum_tensor(y_onehot, axes) + self.smooth)
252 | sensitivity_part = sum_tensor(squared_error * bg_onehot, axes) / (sum_tensor(bg_onehot, axes) + self.smooth)
253 |
254 | ss = self.r * specificity_part + (1 - self.r) * sensitivity_part
255 |
256 | if not self.do_bg:
257 | if self.batch_dice:
258 | ss = ss[1:]
259 | else:
260 | ss = ss[:, 1:]
261 | ss = ss.mean()
262 |
263 | return ss
264 |
265 |
266 | class SoftDiceLoss(nn.Module):
267 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1e-6,
268 | square=False):
269 | """
270 | paper: https://arxiv.org/pdf/1606.04797.pdf
271 | """
272 | super(SoftDiceLoss, self).__init__()
273 |
274 | self.square = square
275 | self.do_bg = do_bg
276 | self.batch_dice = batch_dice
277 | self.apply_nonlin = apply_nonlin
278 | self.smooth = smooth
279 |
280 | def forward(self, x, y, loss_mask=None):
281 | shp_x = x.shape
282 |
283 | if self.batch_dice:
284 | axes = [0] + list(range(2, len(shp_x)))
285 | else:
286 | axes = list(range(2, len(shp_x)))
287 |
288 | if self.apply_nonlin is not None:
289 | x = self.apply_nonlin(x)
290 |
291 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
292 |
293 | dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
294 |
295 | if not self.do_bg:
296 | if self.batch_dice:
297 | dc = dc[1:]
298 | else:
299 | dc = dc[:, 1:]
300 | dc = dc.mean()
301 |
302 | return 1 - dc
303 |
304 |
305 | class IoULoss(nn.Module):
306 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
307 | square=False):
308 | """
309 | paper: https://link.springer.com/chapter/10.1007/978-3-319-50835-1_22
310 |
311 | """
312 | super(IoULoss, self).__init__()
313 |
314 | self.square = square
315 | self.do_bg = do_bg
316 | self.batch_dice = batch_dice
317 | self.apply_nonlin = apply_nonlin
318 | self.smooth = smooth
319 |
320 | def forward(self, x, y, loss_mask=None):
321 | shp_x = x.shape
322 |
323 | if self.batch_dice:
324 | axes = [0] + list(range(2, len(shp_x)))
325 | else:
326 | axes = list(range(2, len(shp_x)))
327 |
328 | if self.apply_nonlin is not None:
329 | x = self.apply_nonlin(x)
330 |
331 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
332 |
333 | iou = (tp + self.smooth) / (tp + fp + fn + self.smooth)
334 |
335 | if not self.do_bg:
336 | if self.batch_dice:
337 | iou = iou[1:]
338 | else:
339 | iou = iou[:, 1:]
340 | iou = iou.mean()
341 |
342 | return -iou
343 |
344 |
345 | class TverskyLoss(nn.Module):
346 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
347 | square=False):
348 | """
349 | paper: https://arxiv.org/pdf/1706.05721.pdf
350 | """
351 | super(TverskyLoss, self).__init__()
352 |
353 | self.square = square
354 | self.do_bg = do_bg
355 | self.batch_dice = batch_dice
356 | self.apply_nonlin = apply_nonlin
357 | self.smooth = smooth
358 | self.alpha = 0.3
359 | self.beta = 0.7
360 |
361 | def forward(self, x, y, loss_mask=None):
362 | shp_x = x.shape
363 |
364 | if self.batch_dice:
365 | axes = [0] + list(range(2, len(shp_x)))
366 | else:
367 | axes = list(range(2, len(shp_x)))
368 |
369 | if self.apply_nonlin is not None:
370 | x = self.apply_nonlin(x)
371 |
372 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
373 |
374 | tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
375 |
376 | if not self.do_bg:
377 | if self.batch_dice:
378 | tversky = tversky[1:]
379 | else:
380 | tversky = tversky[:, 1:]
381 | tversky = tversky.mean()
382 |
383 | return -tversky
384 |
385 |
386 | class FocalTversky_loss(nn.Module):
387 | """
388 | paper: https://arxiv.org/pdf/1810.07842.pdf
389 | author code: https://github.com/nabsabraham/focal-tversky-unet/blob/347d39117c24540400dfe80d106d2fb06d2b99e1/losses.py#L65
390 | """
391 |
392 | def __init__(self, tversky_kwargs, gamma=0.75):
393 | super(FocalTversky_loss, self).__init__()
394 | self.gamma = gamma
395 | self.tversky = TverskyLoss(**tversky_kwargs)
396 |
397 | def forward(self, net_output, target):
398 | tversky_loss = 1 + self.tversky(net_output, target) # = 1-tversky(net_output, target)
399 | focal_tversky = torch.pow(tversky_loss, self.gamma)
400 | return focal_tversky
401 |
402 |
403 | class AsymLoss(nn.Module):
404 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
405 | square=False):
406 | """
407 | paper: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8573779
408 | """
409 | super(AsymLoss, self).__init__()
410 |
411 | self.square = square
412 | self.do_bg = do_bg
413 | self.batch_dice = batch_dice
414 | self.apply_nonlin = apply_nonlin
415 | self.smooth = smooth
416 | self.beta = 1.5
417 |
418 | def forward(self, x, y, loss_mask=None):
419 | shp_x = x.shape
420 |
421 | if self.batch_dice:
422 | axes = [0] + list(range(2, len(shp_x)))
423 | else:
424 | axes = list(range(2, len(shp_x)))
425 |
426 | if self.apply_nonlin is not None:
427 | x = self.apply_nonlin(x)
428 |
429 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square) # shape: (batch size, class num)
430 | weight = (self.beta ** 2) / (1 + self.beta ** 2)
431 | asym = (tp + self.smooth) / (tp + weight * fn + (1 - weight) * fp + self.smooth)
432 |
433 | if not self.do_bg:
434 | if self.batch_dice:
435 | asym = asym[1:]
436 | else:
437 | asym = asym[:, 1:]
438 | asym = asym.mean()
439 |
440 | return -asym
441 |
442 |
443 | class DC_and_CE_loss(nn.Module):
444 | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
445 | super(DC_and_CE_loss, self).__init__()
446 | self.aggregate = aggregate
447 | self.ce = CrossentropyND(**ce_kwargs)
448 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
449 |
450 | def forward(self, net_output, target):
451 | dc_loss = self.dc(net_output, target)
452 | ce_loss = self.ce(net_output, target)
453 | if self.aggregate == "sum":
454 | result = ce_loss + dc_loss
455 | else:
456 | raise NotImplementedError("nah son") # reserved for other stuff (later)
457 | return result
458 |
459 |
460 | class PenaltyGDiceLoss(nn.Module):
461 | """
462 | paper: https://openreview.net/forum?id=H1lTh8unKN
463 | """
464 |
465 | def __init__(self, gdice_kwargs):
466 | super(PenaltyGDiceLoss, self).__init__()
467 | self.k = 2.5
468 | self.gdc = GDiceLoss(apply_nonlin=softmax_helper, **gdice_kwargs)
469 |
470 | def forward(self, net_output, target):
471 | gdc_loss = self.gdc(net_output, target)
472 | penalty_gdc = gdc_loss / (1 + self.k * (1 - gdc_loss))
473 |
474 | return penalty_gdc
475 |
476 |
477 | class DC_and_topk_loss(nn.Module):
478 | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
479 | super(DC_and_topk_loss, self).__init__()
480 | self.aggregate = aggregate
481 | self.ce = TopKLoss(**ce_kwargs)
482 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
483 |
484 | def forward(self, net_output, target):
485 | dc_loss = self.dc(net_output, target)
486 | ce_loss = self.ce(net_output, target)
487 | if self.aggregate == "sum":
488 | result = ce_loss + dc_loss
489 | else:
490 | raise NotImplementedError("nah son") # reserved for other stuff (later?)
491 | return result
492 |
493 |
494 | class ExpLog_loss(nn.Module):
495 | """
496 | paper: 3D Segmentation with Exponential Logarithmic Loss for Highly Unbalanced Object Sizes
497 | https://arxiv.org/pdf/1809.00076.pdf
498 | """
499 |
500 | def __init__(self, soft_dice_kwargs, wce_kwargs, gamma=0.3):
501 | super(ExpLog_loss, self).__init__()
502 | self.wce = WeightedCrossEntropyLoss(**wce_kwargs)
503 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
504 | self.gamma = gamma
505 |
506 | def forward(self, net_output, target):
507 | dc_loss = -self.dc(net_output, target) # weight=0.8
508 | wce_loss = self.wce(net_output, target) # weight=0.2
509 | # with torch.no_grad():
510 | # print('dc loss:', dc_loss.cpu().numpy(), 'ce loss:', ce_loss.cpu().numpy())
511 | # a = torch.pow(-torch.log(torch.clamp(dc_loss, 1e-6)), self.gamma)
512 | # b = torch.pow(-torch.log(torch.clamp(ce_loss, 1e-6)), self.gamma)
513 | # print('ExpLog dc loss:', a.cpu().numpy(), 'ExpLogce loss:', b.cpu().numpy())
514 | # print('*'*20)
515 | explog_loss = 0.8 * torch.pow(-torch.log(torch.clamp(dc_loss, 1e-6)), self.gamma) + \
516 | 0.2 * wce_loss
517 |
518 | return explog_loss
--------------------------------------------------------------------------------
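`SoftDiceLoss` accepts either a label map or a target that is already one-hot, per the `get_tp_fp_fn` docstring above. A toy call with same-shaped binary targets (random values, only to show the contract):

```
import torch
from utils.loss import SoftDiceLoss

# (b, c, x, y) probabilities against a same-shaped binary target;
# get_tp_fp_fn then uses the target directly as the one-hot encoding.
pred = torch.sigmoid(torch.randn(2, 1, 8, 8))
gt = (torch.rand(2, 1, 8, 8) > 0.5).float()
loss = SoftDiceLoss()(pred, gt)  # scalar; 0 when pred matches gt exactly
```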
/utils/snake_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Snakeloss:
5 | def __init__(self, criterion):
6 | self.criterion = criterion
7 |
8 | def ballon_loss(self, mask):
9 | return 1 - torch.mean(mask, dim=[2, 3]).cuda().mean()
10 |
11 | def avg_dis(self, P):
12 | nP = P.shape[2]
13 | even1 = P[:, :, 0:nP:2, :]
14 | odd1 = P[:, :, 1:nP:2, :]
15 | diff1 = torch.sum((even1-odd1)**2, dim=[2, 3])
16 | even2 = P[:, :, 2:nP:2, :]
17 | odd2 = P[:, :, 1:nP-1:2, :]
18 | diff2 = torch.sum((even2-odd2)**2, dim=[2, 3])
19 | diff3 = torch.sum((P[:, :, 0, :] - P[:, :, nP-1, :])**2, dim=[1, 2])
20 | return (1/3)*diff1**0.5+(1/3)*diff2**0.5+(1/3)*diff3**0.5
21 |
22 | def curvature_loss(self, P):
23 | Pf = P.roll(-1, dims=2)
24 | Pb = P.roll(1, dims=2)
25 | K = Pf + Pb - 2 * P
26 | return K.abs().mean()
27 |
28 | def snake_loss(self, num_of_it, net_out, gt):
29 | for it in range(num_of_it):
30 | if it == 0:
31 | loss = self.criterion(net_out[0][it], gt) + \
32 | 0.1*self.avg_dis(net_out[1][it]).mean()
33 | else:
34 | loss = loss + \
35 | self.criterion(net_out[0][it], gt) + \
36 | 0.1*self.avg_dis(net_out[1][it]).mean()
37 | return loss
38 |
39 |
--------------------------------------------------------------------------------
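`Snakeloss.snake_loss` accumulates, over snake iterations, a Dice term on the rasterized mask plus 0.1 times the mean spacing between neighbouring contour points, so every unrolled step is supervised. A CPU sketch with assumed shapes (`net_out[0]` holding one mask per iteration and `net_out[1]` the matching `(b, 1, nP, 2)` contour points, as `DeepACM` returns them):

```
import torch
from utils.loss import SoftDiceLoss
from utils.snake_loss import Snakeloss

masks = [torch.rand(2, 1, 64, 64) for _ in range(3)]   # one mask per iteration
points = [torch.rand(2, 1, 32, 2) for _ in range(3)]   # contour coordinates
gt = (torch.rand(2, 1, 64, 64) > 0.5).float()
loss = Snakeloss(SoftDiceLoss()).snake_loss(3, (masks, points), gt)
```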
/utils/utils_TB.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talshaharabany/DeepACM2D/db40f6682c5a86df3d99f8c6acb9c14ef5f58599/utils/utils_TB.py
--------------------------------------------------------------------------------
/utils/utils_args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_args():
4 |     parser = argparse.ArgumentParser(description='DeepACM2D training and evaluation arguments')
5 |     parser.add_argument('-it', '--DeepIt', default=3, help='number of snake iterations', required=False)
6 | parser.add_argument('-lr', '--learning_rate', default=0.001, help='learning_rate', required=False)
7 | parser.add_argument('-bs', '--Batch_size', default=20, help='batch_size', required=False)
8 |     parser.add_argument('-ep', '--epochs', default=25000, help='number of epochs', required=False)
9 | parser.add_argument('-drop', '--drop', default=0, help='dropout value', required=False)
10 | parser.add_argument('-R', '--Radius', default=64, help='initial guess circle radius', required=False)
11 | parser.add_argument('-D_rate', '--D_rate', default=50, help='opt. drop parameter', required=False)
12 | parser.add_argument('-opt', '--opt', default='adam', help='opt. type', required=False)
13 | parser.add_argument('-a', '--a', default=0.5, help='mask coeff with image', required=False)
14 |     parser.add_argument('-nW', '--nW', default=0, help='number of workers', required=False)
15 | parser.add_argument('-WD', '--WD', default=0.00005, help='weight decay', required=False)
16 | parser.add_argument('-nP', '--nP', default=32, help='number of contour points', required=False)
17 | parser.add_argument('-order', '--order', default=85, help='backbone dimension', required=False)
18 | parser.add_argument('-depth_wise', '--depth_wise', default=0, help='depth_wise hardnet', required=False)
19 | parser.add_argument('-outlayer', '--outlayer', default=3, help='number of hardnet blocks', required=False)
20 | parser.add_argument('-im_size', '--im_size', default=256, help='image size', required=False)
21 | parser.add_argument('-task', '--task', default='bing', help='which dataset to use?', required=False)
22 | args = vars(parser.parse_args())
23 | return args
24 |
25 |
26 | def save_args(args):
27 | path = r'results/' + args['task'] + '/params.csv'
28 |     with open(path, 'w') as f:
29 |         for key, val in args.items():
30 |             f.write(str(key) + ',' + str(val) + '\n')
31 |
--------------------------------------------------------------------------------
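Since none of the arguments above set `type=`, any value passed on the command line arrives as a string (only the defaults keep their Python type), which is why the training scripts cast on use:

```
from utils.utils_args import get_args

args = get_args()
batch_size = int(args['Batch_size'])  # '50' -> 50 when passed via -bs 50
lr = float(args['learning_rate'])     # '0.02' -> 0.02 when passed via -lr 0.02
```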
/utils/utils_eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.morphology import binary_dilation, disk
3 | import os
4 | import torch
5 | from utils.utils_vis import send_image_to_TB
6 | from utils.utils_train import norm_input
7 | import cv2
8 |
9 | def image_norm(img):
10 | return (img - img.min()) / (img.max() - img.min())
11 |
12 |
13 | def vis_ds(ds, model, segnet, PTrain, faces, args, num_of_ex=5):
14 | model.eval()
15 | for ix, (_x, _y) in enumerate(ds):
16 | if ix > num_of_ex: break
17 | _x = _x.float().cuda()
18 | img = image_norm(_x.squeeze(dim=0).detach().cpu().numpy().transpose(1, 2, 0))
19 | _p = PTrain.float().cuda().clone()
20 | _y = _y.float().cuda()
21 | seg_out = segnet(_x)
22 | _x = norm_input(_x, seg_out, float(args['a']))
23 | iter = int(args['DeepIt'])
24 | net_out = model(_x, _p, faces, iter)
25 | Mask = net_out[0][iter - 1]
26 | (_, cIoU) = get_dice_ji(Mask, _y)
27 | P = net_out[1][iter - 1]
28 | P_init = PTrain.squeeze().detach().cpu().numpy().transpose(1, 0)
29 | Mask = Mask.squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy()
30 | P = P.squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy().transpose(1, 0)
31 | P = np.concatenate((P, P[:, 0:1]), 1)
32 | Ix = net_out[2].squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy()
33 | Iy = net_out[3].squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy()
34 | GT = _y.squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy()
35 | im = cv2.cvtColor(send_image_to_TB(img, P_init, Mask, P, Ix, Iy, GT, cIoU), cv2.COLOR_RGBA2BGR)
36 | cv2.imwrite('out/' + str(ix) + '.jpg', im)
37 | model.train()
38 |
39 |
40 | def eval_ds(ds, model, segnet, PTrain, faces, args):
41 | model.eval()
42 | TestIoU_list = []
44 | with torch.no_grad():
45 | for ix, (_x, _y) in enumerate(ds):
46 | _x = _x.float().cuda()
47 | _p = PTrain.float().cuda().clone()
48 | _y = _y.float().cuda()
49 | seg_out = segnet(_x).detach()
50 | _x = norm_input(_x, seg_out, float(args['a']))
51 | iter = int(args['DeepIt'])
52 | net_out = model(_x, _p, faces, iter)
53 | Mask = net_out[0][iter-1]
54 | _, cIoU = get_dice_ji(Mask, _y)
55 | TestIoU_list.append(cIoU)
56 | IoU = np.mean(TestIoU_list)
57 | model.train()
58 | return IoU
59 |
60 |
61 | def get_dice_ji(predict, target):
62 | predict = predict.data.cpu().numpy() + 1
63 | target = target.data.cpu().numpy() + 1
64 | tp = np.sum(((predict == 2) * (target == 2)) * (target > 0))
65 | fp = np.sum(((predict == 2) * (target == 1)) * (target > 0))
66 | fn = np.sum(((predict == 1) * (target == 2)) * (target > 0))
67 | ji = float(np.nan_to_num(tp / (tp + fp + fn)))
68 | dice = float(np.nan_to_num(2 * tp / (2 * tp + fp + fn)))
69 | return dice, ji
70 |
71 |
72 | def update_net_list(PATH):
73 | a = os.listdir(PATH)
74 | PATH1_list = []
75 | for i in a:
76 | if i[0:5]=='model': PATH1_list.append(i)
77 | PATH1_list.sort()
78 | return PATH1_list
79 |
80 |
81 | def dice_metric(X, Y):
82 | return np.sum(X[Y==1])*2.0 / (np.sum(X) + np.sum(Y) + 1e-6)
83 |
84 |
85 | def IoU_metric(y_pred, y_true):
86 | intersection = np.sum(y_true * y_pred, axis=None)
87 | union = np.sum(y_true, axis=None) + np.sum(y_pred, axis=None) - intersection
88 | if float(union)==0: return 0.0
89 | else: return float(intersection) / float(union)
90 |
91 |
92 | def WCov_metric(X, Y):
93 | A1 = float(np.count_nonzero(X))
94 | A2 = float(np.count_nonzero(Y))
95 | if A1>=A2: return A2/A1
96 | if A2>A1: return A1/A2
97 |
98 |
99 | def FBound_metric(X, Y):
100 | tmp1 = db_eval_boundary(X,Y,1)[0]
101 | tmp2 = db_eval_boundary(X,Y,2)[0]
102 | tmp3 = db_eval_boundary(X,Y,3)[0]
103 | tmp4 = db_eval_boundary(X,Y,4)[0]
104 | tmp5 = db_eval_boundary(X,Y,5)[0]
105 | return (tmp1+tmp2+tmp3+tmp4+tmp5)/5.0
106 |
107 | def db_eval_boundary(foreground_mask, gt_mask, bound_th):
108 | """
109 | Compute mean,recall and decay from per-frame evaluation.
110 | Calculates precision/recall for boundaries between foreground_mask and
111 | gt_mask using morphological operators to speed it up.
112 | Arguments:
113 | foreground_mask (ndarray): binary segmentation image.
114 | gt_mask (ndarray): binary annotated image.
115 | Returns:
116 | F (float): boundaries F-measure
117 | P (float): boundaries precision
118 | R (float): boundaries recall
119 | """
120 | assert np.atleast_3d(foreground_mask).shape[2] == 1
121 |
122 | bound_pix = bound_th if bound_th >= 1 else \
123 | np.ceil(bound_th*np.linalg.norm(foreground_mask.shape))
124 |
125 | # Get the pixel boundaries of both masks
126 |     fg_boundary = seg2bmap(foreground_mask)
127 |     gt_boundary = seg2bmap(gt_mask)
128 |
129 | fg_dil = binary_dilation(fg_boundary,disk(bound_pix))
130 | gt_dil = binary_dilation(gt_boundary,disk(bound_pix))
131 |
132 | # Get the intersection
133 | gt_match = gt_boundary * fg_dil
134 | fg_match = fg_boundary * gt_dil
135 |
136 | # Area of the intersection
137 | n_fg = np.sum(fg_boundary)
138 | n_gt = np.sum(gt_boundary)
139 |
140 |     # Compute precision and recall
141 | if n_fg == 0 and n_gt > 0:
142 | precision = 1
143 | recall = 0
144 | elif n_fg > 0 and n_gt == 0:
145 | precision = 0
146 | recall = 1
147 | elif n_fg == 0 and n_gt == 0:
148 | precision = 1
149 | recall = 1
150 | else:
151 | precision = np.sum(fg_match)/float(n_fg)
152 | recall = np.sum(gt_match)/float(n_gt)
153 |
154 | # Compute F measure
155 | if precision + recall == 0:
156 | F = 0
157 | else:
158 |         F = 2 * precision * recall / (precision + recall)
159 |
160 | return F, precision, recall, np.sum(fg_match), n_fg, np.sum(gt_match), n_gt
161 |
162 |
163 | def seg2bmap(seg,width=None,height=None):
164 | """
165 | From a segmentation, compute a binary boundary map with 1 pixel wide
166 | boundaries. The boundary pixels are offset by 1/2 pixel towards the
167 | origin from the actual segment boundary.
168 | Arguments:
169 | seg : Segments labeled from 1..k.
170 | width : Width of desired bmap <= seg.shape[1]
171 | height : Height of desired bmap <= seg.shape[0]
172 | Returns:
173 | bmap (ndarray): Binary boundary map.
174 | David Martin
175 | January 2003
176 | """
177 |     seg = seg.astype(bool)
178 | seg[seg>0] = 1
179 |
180 | assert np.atleast_3d(seg).shape[2] == 1
181 |
182 | width = seg.shape[1] if width is None else width
183 | height = seg.shape[0] if height is None else height
184 |
185 | h,w = seg.shape[:2]
186 |
187 | ar1 = float(width) / float(height)
188 | ar2 = float(w) / float(h)
189 |
190 |     assert not (width > w or height > h or abs(ar1 - ar2) > 0.01), \
191 |         "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
192 |
193 | e = np.zeros_like(seg)
194 | s = np.zeros_like(seg)
195 | se = np.zeros_like(seg)
196 |
197 | e[:,:-1] = seg[:,1:]
198 | s[:-1,:] = seg[1:,:]
199 | se[:-1,:-1] = seg[1:,1:]
200 |
201 | b = seg^e | seg^s | seg^se
202 | b[-1,:] = seg[-1,:]^e[-1,:]
203 | b[:,-1] = seg[:,-1]^s[:,-1]
204 | b[-1,-1] = 0
205 |
206 | if w == width and h == height:
207 | bmap = b
208 | else:
209 | bmap = np.zeros((height,width))
210 | for x in range(w):
211 | for y in range(h):
212 | if b[y,x]:
213 |                     j = 1 + int(np.floor((y - 1) + height / h))
214 |                     i = 1 + int(np.floor((x - 1) + width / h))
215 |                     bmap[j, i] = 1
216 |
217 | return bmap
218 |
--------------------------------------------------------------------------------
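`get_dice_ji` shifts both tensors by +1 so foreground becomes 2 and background 1, then counts tp/fp/fn directly. A toy check with values small enough to verify by hand:

```
import torch
from utils.utils_eval import get_dice_ji

# Prediction and target overlap on 2 foreground pixels: tp=2, fp=1, fn=1,
# giving IoU = 2/4 = 0.5 and Dice = 4/6 ~ 0.667.
pred = torch.tensor([[1., 1., 1., 0.]])
gt = torch.tensor([[1., 1., 0., 1.]])
dice, ji = get_dice_ji(pred, gt)
```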
/utils/utils_lr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import math
5 | from bisect import bisect_right
6 | import torch
7 | import numpy as np
8 |
9 |
10 | class WarmupLrScheduler(torch.optim.lr_scheduler._LRScheduler):
11 |
12 | def __init__(
13 | self,
14 | optimizer,
15 | warmup_iter=500,
16 | warmup_ratio=5e-4,
17 | warmup='exp',
18 | last_epoch=-1,
19 | ):
20 | self.warmup_iter = warmup_iter
21 | self.warmup_ratio = warmup_ratio
22 | self.warmup = warmup
23 | super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
24 |
25 | def get_lr(self):
26 | ratio = self.get_lr_ratio()
27 | lrs = [ratio * lr for lr in self.base_lrs]
28 | return lrs
29 |
30 | def get_lr_ratio(self):
31 | if self.last_epoch < self.warmup_iter:
32 | ratio = self.get_warmup_ratio()
33 | else:
34 | ratio = self.get_main_ratio()
35 | return ratio
36 |
37 | def get_main_ratio(self):
38 | raise NotImplementedError
39 |
40 | def get_warmup_ratio(self):
41 | assert self.warmup in ('linear', 'exp')
42 | alpha = self.last_epoch / self.warmup_iter
43 | if self.warmup == 'linear':
44 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
45 | elif self.warmup == 'exp':
46 | ratio = self.warmup_ratio ** (1. - alpha)
47 | return ratio
48 |
49 |
50 | class WarmupPolyLrScheduler(WarmupLrScheduler):
51 |
52 | def __init__(
53 | self,
54 | optimizer,
55 | power,
56 | max_iter,
57 | warmup_iter=500,
58 | warmup_ratio=5e-4,
59 | warmup='exp',
60 | last_epoch=-1,
61 | ):
62 | self.power = power
63 | self.max_iter = max_iter
64 | super(WarmupPolyLrScheduler, self).__init__(
65 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
66 |
67 | def get_main_ratio(self):
68 | real_iter = self.last_epoch - self.warmup_iter
69 | real_max_iter = self.max_iter - self.warmup_iter
70 | alpha = real_iter / real_max_iter
71 | ratio = (1 - alpha) ** self.power
72 | return ratio
73 |
74 |
75 | class WarmupExpLrScheduler(WarmupLrScheduler):
76 |
77 | def __init__(
78 | self,
79 | optimizer,
80 | gamma,
81 | interval=1,
82 | warmup_iter=500,
83 | warmup_ratio=5e-4,
84 | warmup='exp',
85 | last_epoch=-1,
86 | ):
87 | self.gamma = gamma
88 | self.interval = interval
89 | super(WarmupExpLrScheduler, self).__init__(
90 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
91 |
92 | def get_main_ratio(self):
93 | real_iter = self.last_epoch - self.warmup_iter
94 | ratio = self.gamma ** (real_iter // self.interval)
95 | return ratio
96 |
97 |
98 | class WarmupCosineLrScheduler(WarmupLrScheduler):
99 |
100 | def __init__(
101 | self,
102 | optimizer,
103 | max_iter,
104 | eta_ratio=0,
105 | warmup_iter=500,
106 | warmup_ratio=5e-4,
107 | warmup='exp',
108 | last_epoch=-1,
109 | ):
110 | self.eta_ratio = eta_ratio
111 | self.max_iter = max_iter
112 | super(WarmupCosineLrScheduler, self).__init__(
113 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
114 |
115 | def get_main_ratio(self):
116 | real_iter = self.last_epoch - self.warmup_iter
117 | real_max_iter = self.max_iter - self.warmup_iter
118 |         return self.eta_ratio + (1 - self.eta_ratio) * (
119 |                 1 + math.cos(math.pi * real_iter / real_max_iter)) / 2
120 |
121 |
122 | class WarmupStepLrScheduler(WarmupLrScheduler):
123 |
124 | def __init__(
125 | self,
126 | optimizer,
127 | milestones: list,
128 | gamma=0.1,
129 | warmup_iter=500,
130 | warmup_ratio=5e-4,
131 | warmup='exp',
132 | last_epoch=-1,
133 | ):
134 | self.milestones = milestones
135 | self.gamma = gamma
136 | super(WarmupStepLrScheduler, self).__init__(
137 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
138 |
139 | def get_main_ratio(self):
140 | real_iter = self.last_epoch - self.warmup_iter
141 | ratio = self.gamma ** bisect_right(self.milestones, real_iter)
142 | return ratio
143 |
144 |
145 | if __name__ == "__main__":
146 | model = torch.nn.Conv2d(3, 16, 3, 1, 1)
147 | optim = torch.optim.SGD(model.parameters(), lr=5e-2)
148 |
149 | max_iter = 10000
150 | lr_scheduler = WarmupPolyLrScheduler(optim, 0.9, max_iter, 200, 0.001, 'exp', -1)
151 | lrs = []
152 | for _ in range(max_iter+1000):
153 | lr = lr_scheduler.get_lr()[0]
154 | print(lr)
155 | lrs.append(lr)
156 | lr_scheduler.step()
157 |
158 |
159 |
--------------------------------------------------------------------------------
/utils/utils_train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 |
5 |
6 | def norm_tensor(_x):
7 | bs = _x.shape[0]
8 | min_x = _x.view(bs, 3, -1).min(dim=2)[0].view(bs, 3, 1, 1)
9 | max_x = _x.view(bs, 3, -1).max(dim=2)[0].view(bs, 3, 1, 1)
10 | return (_x - min_x) / (max_x - min_x + 1e-2)
11 |
12 |
13 | def norm_input(_x, seg_out, a):
14 | _x = norm_tensor(_x)
15 | _x = a * _x + (1 - a) * seg_out
16 | return _x
17 |
18 |
19 | def train(ds, model, segnet, optimizer, snake_loss, PTrain, faces, args):
20 | loss_list = []
21 | for ix, (_x, _y) in tqdm(enumerate(ds)):
22 | _x = _x.float().cuda()
23 | _p = PTrain.float().cuda().clone()
24 | _y = _y.float().cuda()
25 | optimizer.zero_grad()
26 | seg_out = segnet(_x)
27 | _x = norm_input(_x, seg_out, float(args['a'])).detach()
28 | num_of_it = int(args['DeepIt'])
29 | net_out = model(_x, _p, faces, num_of_it)
30 | loss = snake_loss.snake_loss(num_of_it, net_out, _y)
31 | loss_list.append(loss.item()/num_of_it)
32 | loss.backward()
33 | optimizer.step()
34 | return loss_list
35 |
36 |
37 | def logger_train(epoch, writer, loss_list):
38 | print('************************************************************************')
39 | print('Epoch: ' + str(epoch) + ' Mask mean loss: ' + str(np.mean(loss_list)) + ' Mask max loss: ' + str(
40 | np.max(loss_list)) + ' Mask min loss: ' + str(np.min(loss_list)))
41 | writer.add_scalar('MaskLoss', np.mean(loss_list), global_step=epoch)
42 | print('************************************************************************')
43 |
44 |
45 |
--------------------------------------------------------------------------------
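`norm_input` min-max normalizes the image per channel and then convexly blends it with the segmentation prior: `a = 1` keeps only the image, `a = 0` only the mask. A small sketch with assumed shapes:

```
import torch
from utils.utils_train import norm_input

img = torch.rand(2, 3, 64, 64) * 255     # raw image batch
seg_out = torch.rand(2, 1, 64, 64)       # segmentation prior in [0, 1]
blended = norm_input(img, seg_out, 0.4)  # 0.4 * normalized image + 0.6 * prior
```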
/utils/utils_tri.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | from scipy.spatial import Delaunay
5 |
6 |
7 | def get_faces(P):
8 | N = P.shape[2]*2
9 | faces = torch.zeros(P.shape[0], N, 3)
10 | for i in range(P.shape[0]):
11 | cP = P[i, :, :, :].squeeze(dim=0).squeeze(dim=0)
12 | tri = Delaunay(cP.detach().cpu().numpy())
13 | tri = torch.tensor(tri.simplices.copy())
14 | nP = tri.shape[0]
15 | last = tri[nP-1, :].unsqueeze(dim=0)
16 | for j in range(N-nP):
17 | tri = torch.cat((tri, last), dim=0)
18 | faces[i, :, :] = tri.unsqueeze(dim=0)
19 | return faces.type(torch.int32)
20 |
21 |
22 | def get_poly(dim=128, n=16, R=16, xx=64, yy=64):
23 | half_dim = dim / 2
24 | P = [np.array([xx + math.floor(math.cos(2 * math.pi / n * x) * R),
25 | yy + math.floor(math.sin(2 * math.pi / n * x) * R)]) for x in range(0, n)]
26 | train_data = torch.zeros(1, 1, n, 2)
27 | for i in range(n):
28 | train_data[0, 0, i, 0] = torch.tensor((P[i][0] - half_dim) / half_dim).clone()
29 | train_data[0, 0, i, 1] = torch.tensor((P[i][1] - half_dim) / half_dim).clone()
30 | vertices = torch.ones((n, 3))
31 | tmp = train_data.squeeze(dim=0).squeeze(dim=0)
32 | vertices[:, 0] = tmp[:, 0]
33 | vertices[:, 1] = tmp[:, 1] * -1
34 | tri = Delaunay(vertices[:, 0:2].numpy())
35 | faces = torch.tensor(tri.simplices.copy())
36 | return train_data, faces
37 |
--------------------------------------------------------------------------------
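`get_poly` places `n` points on a circle of radius `R` centred at `(xx, yy)` inside a `dim x dim` image, normalizes the coordinates to `[-1, 1]`, and triangulates the polygon with Delaunay (a convex `n`-gon yields `n - 2` faces). Mirroring the call in `train.py`:

```
from utils.utils_tri import get_poly

P, faces = get_poly(dim=256, n=32, R=64, xx=128, yy=128)
# P: (1, 1, 32, 2) normalized contour points; faces: (30, 3) triangle indices
```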
/utils/utils_vis.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function
2 |
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import os
7 |
8 | import matplotlib
9 | matplotlib.use('agg')
10 | import matplotlib.pyplot as plt
11 |
12 | import importlib
13 | import cv2
14 | import xlrd
15 | # import dar_package.config
16 |
17 |
18 | def fig2img(fig):
19 | """
20 | Convert a Matplotlib figure to a PIL Image in RGBA format
21 | Copied from http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure
22 | """
23 | # put the figure pixmap into a np array
24 | buf = fig2data(fig)
25 | w, h, d = buf.shape
26 |     return Image.frombytes("RGBA", (w, h), buf.tobytes())
27 |
28 |
29 | def fig2data(fig):
30 | """
31 | Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
32 | Copied from http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure
33 | """
34 | # draw the renderer
35 |     fig.canvas.draw()
36 |
37 | # Get the RGBA buffer from the figure
38 | w,h = fig.canvas.get_width_height()
39 | buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
40 | buf.shape = (w, h, 4)
41 |
42 | # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
43 | buf = np.roll(buf, 3, axis=2)
44 | return buf
45 |
46 |
47 | def send_image_to_TB(img, P_init, Mask, P, Ix, Iy, GT, cIoU):
48 | dim = Mask.shape[0]
49 | dim2 = img.shape[0]
50 | fig, ax = plt.subplots(nrows=3, ncols=3, figsize=[10, 10])
51 | fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
52 | ax[0, 0].imshow(img)
53 | ax[0, 1].plot(P_init[0, :]*0.5*dim2+0.5*dim2, P_init[1, :]*0.5*dim2+0.5*dim2, 'r--', linewidth=2.0)
54 | ax[0, 1].plot(P[0, :]*0.5*dim2+0.5*dim2, P[1, :]*0.5*dim2+0.5*dim2, color=[0, 1, 0], linewidth=2.0, marker='*')
55 | ax[0, 1].imshow(img)
56 | ax[0, 2].imshow(img)
57 | ax[1, 0].imshow(GT)
58 | ax[1, 1].imshow(Mask)
59 | ax[1, 2].plot(P_init[0, :]*0.5*dim+0.5*dim, P_init[1, :]*0.5*dim+0.5*dim, 'ro')
60 | ax[1, 2].plot(P[0, :]*0.5*dim+0.5*dim, P[1, :]*0.5*dim+0.5*dim, color=[0, 1, 0], linewidth=2.0, marker='*')
61 | ax[1, 2].imshow(Mask)
62 | ax[2, 0].imshow(Ix)
63 | ax[2, 1].imshow(Iy)
64 | ax[2, 2].imshow((Ix**2+Iy**2)**0.5)
65 | ax[0, 0].axis('off')
66 | ax[0, 1].axis('off')
67 | ax[0, 2].axis('off')
68 | ax[1, 0].axis('off')
69 | ax[1, 1].axis('off')
70 | ax[1, 2].axis('off')
71 | ax[2, 0].axis('off')
72 | ax[2, 1].axis('off')
73 | ax[2, 2].axis('off')
74 | fig.suptitle('IoU: ' + str(cIoU))
75 | fig.tight_layout(pad=0, w_pad=0, h_pad=0)
76 | return np.asarray(fig2img(fig))
77 |
--------------------------------------------------------------------------------
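`fig2data` and `fig2img` render a figure off-screen and expose its RGBA buffer, which is what lets `send_image_to_TB` return a plain array for TensorBoard or `cv2.imwrite`. A minimal round-trip:

```
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from utils.utils_vis import fig2data, fig2img

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
buf = fig2data(fig)  # uint8 ndarray with RGBA channels last
img = fig2img(fig)   # PIL.Image in RGBA mode
```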