├── layers
│   ├── __init__.py
│   └── loss.py
├── dataloaders
│   ├── __init__.py
│   ├── pascal_map.npy
│   ├── combine_dbs.py
│   ├── sbd.py
│   ├── custom_transforms.py
│   ├── pascal.py
│   └── helpers.py
├── evaluation
│   ├── __init__.py
│   ├── evaluation.py
│   └── eval.py
├── networks
│   ├── __init__.py
│   └── deeplab_resnet.py
├── ims
│   ├── bear.jpg
│   └── dog-cat.jpg
├── doc
│   ├── dextr.png
│   └── github_teaser.gif
├── media
│   ├── out1.png
│   ├── out1_txt.png
│   └── out1.txt
├── requirements.txt
├── models
│   ├── download_pretrained_psp_model.sh
│   └── download_dextr_model.sh
├── mypath.py
├── eval_all.py
├── README.md
├── demo.py
└── train_pascal.py
/layers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ims/bear.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/ims/bear.jpg
--------------------------------------------------------------------------------
/doc/dextr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/doc/dextr.png
--------------------------------------------------------------------------------
/ims/dog-cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/ims/dog-cat.jpg
--------------------------------------------------------------------------------
/media/out1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/media/out1.png
--------------------------------------------------------------------------------
/media/out1_txt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/media/out1_txt.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | opencv-python
3 | pillow
4 | scikit-learn
5 | scikit-image
6 |
--------------------------------------------------------------------------------
/doc/github_teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/doc/github_teaser.gif
--------------------------------------------------------------------------------
/dataloaders/pascal_map.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/dataloaders/pascal_map.npy
--------------------------------------------------------------------------------
/models/download_pretrained_psp_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget https://data.vision.ee.ethz.ch/csergi/share/DEXTR/MS_DeepLab_resnet_trained_VOC.pth
3 |
--------------------------------------------------------------------------------
/models/download_dextr_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Model trained on PASCAL + SBD
4 | wget https://data.vision.ee.ethz.ch/csergi/share/DEXTR/dextr_pascal-sbd.pth
5 |
--------------------------------------------------------------------------------
/mypath.py:
--------------------------------------------------------------------------------
1 |
2 | class Path(object):
3 | @staticmethod
4 | def db_root_dir(database):
5 | if database == 'pascal':
6 | return '/path/to/PASCAL/VOC2012' # folder that contains VOCdevkit/.
7 |
8 | elif database == 'sbd':
9 | return '/path/to/SBD/' # folder with img/, inst/, cls/, etc.
10 | else:
11 | print('Database {} not available.'.format(database))
12 | raise NotImplementedError
13 |
14 | @staticmethod
15 | def models_dir():
16 | return 'models/'
17 |
--------------------------------------------------------------------------------
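
Note (not part of the repository): before training, the two return values above are the only paths that normally need editing. A hypothetical local configuration could look like the sketch below; the /home/user/data/... locations are placeholders, and the required sub-folders follow from how dataloaders/pascal.py and dataloaders/sbd.py join paths onto these roots.

    class Path(object):
        @staticmethod
        def db_root_dir(database):
            if database == 'pascal':
                # must contain VOCdevkit/VOC2012/ (dataloaders/pascal.py joins BASE_DIR onto this root)
                return '/home/user/data/pascal'
            elif database == 'sbd':
                # must contain benchmark_RELEASE/dataset/ (see dataloaders/sbd.py)
                return '/home/user/data/SBD'
            else:
                raise NotImplementedError('Database {} not available.'.format(database))

        @staticmethod
        def models_dir():
            return 'models/'
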
/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def jaccard(annotation, segmentation, void_pixels=None):
5 |
6 | assert(annotation.shape == segmentation.shape)
7 |
8 | if void_pixels is None:
9 | void_pixels = np.zeros_like(annotation)
10 | assert(void_pixels.shape == annotation.shape)
11 |
12 |     annotation = annotation.astype(bool)
13 |     segmentation = segmentation.astype(bool)
14 |     void_pixels = void_pixels.astype(bool)
15 | if np.isclose(np.sum(annotation & np.logical_not(void_pixels)), 0) and np.isclose(np.sum(segmentation & np.logical_not(void_pixels)), 0):
16 | return 1
17 | else:
18 | return np.sum(((annotation & segmentation) & np.logical_not(void_pixels))) / \
19 | np.sum(((annotation | segmentation) & np.logical_not(void_pixels)), dtype=np.float32)
20 |
21 |
--------------------------------------------------------------------------------
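
Note (not part of the repository): a minimal sanity check of jaccard() on toy masks, with hand-picked numbers, showing how void pixels are excluded from both the intersection and the union.

    import numpy as np
    from evaluation.evaluation import jaccard

    gt = np.zeros((4, 4), dtype=bool)
    gt[:2, :2] = True                    # 4 ground-truth pixels
    pred = np.zeros((4, 4), dtype=bool)
    pred[:2, :3] = True                  # 6 predicted pixels, 4 of them overlapping

    print(jaccard(gt, pred))             # 4 / 6 = 0.666...

    void = np.zeros((4, 4), dtype=bool)
    void[:, 2] = True                    # mark the third column as void
    print(jaccard(gt, pred, void))       # intersection 4, union 4 -> 1.0
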
/media/out1.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | : ;*~
7 | *XO*: ;:~QzC
8 | YuUbbpqbwXXb
9 | ~Zwbbpdbbppa
10 | :adQzZYUOvQbo,,:
11 | hqZJUr))tYCZwp~~*`
12 | :mzf}]]}{}1{\tJkh~`
13 | p(?_+_?{)|}]}vdboh.
14 | ^bqL/\\+_-?|-]umpkha~
15 | ;akbUvx]-?]]}{{f0wOutxq
16 | ,hkz{{1)/Cpdddk*,^mr)|U
17 | *Y|}]][|k U1)u
18 | ;aL{--1. o/{n
19 |
--------------------------------------------------------------------------------
/eval_all.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | from torch.utils.data import DataLoader
4 | from evaluation.eval import eval_one_result
5 | import dataloaders.pascal as pascal
6 |
7 | exp_root_dir = './'
8 |
9 | method_names = []
10 | method_names.append('run_0')
11 |
12 | if __name__ == '__main__':
13 |
14 | # Dataloader
15 | dataset = pascal.VOCSegmentation(transform=None, retname=True)
16 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
17 |
18 | # Iterate through all the different methods
19 | for method in method_names:
20 | results_folder = os.path.join(exp_root_dir, method, 'Results')
21 |
22 | filename = os.path.join(exp_root_dir, 'eval_results', method.replace('/', '-') + '.txt')
23 | if not os.path.exists(os.path.join(exp_root_dir, 'eval_results')):
24 | os.makedirs(os.path.join(exp_root_dir, 'eval_results'))
25 |
26 | if os.path.isfile(filename):
27 | with open(filename, 'r') as f:
28 | val = float(f.read())
29 | else:
30 | print("Evaluating method: {}".format(method))
31 | jaccards = eval_one_result(dataloader, results_folder, mask_thres=0.8)
32 | val = jaccards["all_jaccards"].mean()
33 |
34 | # Show mean and store result
35 | print("Result for {:<80}: {}".format(method, str.format("{0:.1f}", 100*val)))
36 | with open(filename, 'w') as f:
37 | f.write(str(val))
38 |
--------------------------------------------------------------------------------
/layers/loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | def class_balanced_cross_entropy_loss(output, label, size_average=True, batch_average=True, void_pixels=None):
8 | """Define the class balanced cross entropy loss to train the network
9 | Args:
10 | output: Output of the network
11 | label: Ground truth label
12 | size_average: return per-element (pixel) average loss
13 | batch_average: return per-batch average loss
14 | void_pixels: pixels to ignore from the loss
15 | Returns:
16 | Tensor that evaluates the loss
17 | """
18 | assert(output.size() == label.size())
19 |
20 | labels = torch.ge(label, 0.5).float()
21 |
22 | num_labels_pos = torch.sum(labels)
23 | num_labels_neg = torch.sum(1.0 - labels)
24 | num_total = num_labels_pos + num_labels_neg
25 |
26 | output_gt_zero = torch.ge(output, 0).float()
27 | loss_val = torch.mul(output, (labels - output_gt_zero)) - torch.log(
28 | 1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero)))
29 |
30 | loss_pos_pix = -torch.mul(labels, loss_val)
31 | loss_neg_pix = -torch.mul(1.0 - labels, loss_val)
32 |
33 | if void_pixels is not None:
34 | w_void = torch.le(void_pixels, 0.5).float()
35 | loss_pos_pix = torch.mul(w_void, loss_pos_pix)
36 | loss_neg_pix = torch.mul(w_void, loss_neg_pix)
37 | num_total = num_total - torch.ge(void_pixels, 0.5).float().sum()
38 |
39 | loss_pos = torch.sum(loss_pos_pix)
40 | loss_neg = torch.sum(loss_neg_pix)
41 |
42 | final_loss = num_labels_neg / num_total * loss_pos + num_labels_pos / num_total * loss_neg
43 |
44 | if size_average:
45 | final_loss /= np.prod(label.size())
46 | elif batch_average:
47 | final_loss /= label.size()[0]
48 |
49 | return final_loss
50 |
--------------------------------------------------------------------------------
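
Note (not part of the repository): a minimal sketch of calling this loss the way train_pascal.py does, but on random tensors; the shapes are illustrative only.

    import torch
    from layers.loss import class_balanced_cross_entropy_loss

    output = torch.randn(2, 1, 512, 512)                    # raw logits from the network
    label = (torch.rand(2, 1, 512, 512) > 0.5).float()      # binary ground-truth mask

    loss = class_balanced_cross_entropy_loss(output, label, size_average=False, batch_average=True)
    print(loss.item())
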
/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import cv2
4 | import numpy as np
5 | from PIL import Image
6 |
7 | import dataloaders.helpers as helpers
8 | import evaluation.evaluation as evaluation
9 |
10 |
11 | def eval_one_result(loader, folder, one_mask_per_image=False, mask_thres=0.5, use_void_pixels=True, custom_box=False):
12 | def mAPr(per_cat, thresholds):
13 | n_cat = len(per_cat)
14 | all_apr = np.zeros(len(thresholds))
15 | for ii, th in enumerate(thresholds):
16 | per_cat_recall = np.zeros(n_cat)
17 | for jj, categ in enumerate(per_cat.keys()):
18 | per_cat_recall[jj] = np.sum(np.array(per_cat[categ]) > th)/len(per_cat[categ])
19 |
20 | all_apr[ii] = per_cat_recall.mean()
21 |
22 | return all_apr.mean()
23 |
24 | # Allocate
25 | eval_result = dict()
26 | eval_result["all_jaccards"] = np.zeros(len(loader))
27 | eval_result["all_percent"] = np.zeros(len(loader))
28 | eval_result["meta"] = []
29 | eval_result["per_categ_jaccard"] = dict()
30 |
31 | # Iterate
32 | for i, sample in enumerate(loader):
33 |
34 | if i % 500 == 0:
35 | print('Evaluating: {} of {} objects'.format(i, len(loader)))
36 |
37 | # Load result
38 | if not one_mask_per_image:
39 | filename = os.path.join(folder,
40 | sample["meta"]["image"][0] + '-' + sample["meta"]["object"][0] + '.png')
41 | else:
42 | filename = os.path.join(folder,
43 | sample["meta"]["image"][0] + '.png')
44 | mask = np.array(Image.open(filename)).astype(np.float32) / 255.
45 | gt = np.squeeze(helpers.tens2image(sample["gt"]))
46 | if use_void_pixels:
47 | void_pixels = np.squeeze(helpers.tens2image(sample["void_pixels"]))
48 | if mask.shape != gt.shape:
49 | mask = cv2.resize(mask, gt.shape[::-1], interpolation=cv2.INTER_CUBIC)
50 |
51 | # Threshold
52 | mask = (mask > mask_thres)
53 | if use_void_pixels:
54 | void_pixels = (void_pixels > 0.5)
55 |
56 | # Evaluate
57 | if use_void_pixels:
58 | eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask, void_pixels)
59 | else:
60 | eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask)
61 |
62 | if custom_box:
63 | box = np.squeeze(helpers.tens2image(sample["box"]))
64 | bb = helpers.get_bbox(box)
65 | else:
66 | bb = helpers.get_bbox(gt)
67 |
68 | mask_crop = helpers.crop_from_bbox(mask, bb)
69 | if use_void_pixels:
70 | non_void_pixels_crop = helpers.crop_from_bbox(np.logical_not(void_pixels), bb)
71 | gt_crop = helpers.crop_from_bbox(gt, bb)
72 | if use_void_pixels:
73 | eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop) & non_void_pixels_crop)/np.sum(non_void_pixels_crop)
74 | else:
75 | eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop))/mask_crop.size
76 | # Store in per category
77 | if "category" in sample["meta"]:
78 | cat = sample["meta"]["category"][0]
79 | else:
80 | cat = 1
81 | if cat not in eval_result["per_categ_jaccard"]:
82 | eval_result["per_categ_jaccard"][cat] = []
83 | eval_result["per_categ_jaccard"][cat].append(eval_result["all_jaccards"][i])
84 |
85 | # Store meta
86 | eval_result["meta"].append(sample["meta"])
87 |
88 | # Compute some stats
89 | eval_result["mAPr0.5"] = mAPr(eval_result["per_categ_jaccard"], [0.5])
90 | eval_result["mAPr0.7"] = mAPr(eval_result["per_categ_jaccard"], [0.7])
91 | eval_result["mAPr-vol"] = mAPr(eval_result["per_categ_jaccard"], np.linspace(0.1, 0.9, 9))
92 |
93 | return eval_result
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/dataloaders/combine_dbs.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 |
4 | class CombineDBs(data.Dataset):
5 | def __init__(self, dataloaders, excluded=None):
6 | self.dataloaders = dataloaders
7 | self.excluded = excluded
8 | self.im_ids = []
9 |
10 | # Combine object lists
11 | for dl in dataloaders:
12 | for elem in dl.im_ids:
13 | if elem not in self.im_ids:
14 | self.im_ids.append(elem)
15 |
16 | # Exclude
17 | if excluded:
18 | for dl in excluded:
19 | for elem in dl.im_ids:
20 | if elem in self.im_ids:
21 | self.im_ids.remove(elem)
22 |
23 | # Get object pointers
24 | self.obj_list = []
25 | self.im_list = []
26 | new_im_ids = []
27 | obj_counter = 0
28 | num_images = 0
29 | for ii, dl in enumerate(dataloaders):
30 | for jj, curr_im_id in enumerate(dl.im_ids):
31 | if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids):
32 | flag = False
33 | new_im_ids.append(curr_im_id)
34 | for kk in range(len(dl.obj_dict[curr_im_id])):
35 | if dl.obj_dict[curr_im_id][kk] != -1:
36 | self.obj_list.append({'db_ii': ii, 'obj_ii': dl.obj_list.index([jj, kk])})
37 | flag = True
38 | obj_counter += 1
39 | self.im_list.append({'db_ii': ii, 'im_ii': jj})
40 | if flag:
41 | num_images += 1
42 |
43 | self.im_ids = new_im_ids
44 | print('Combined number of images: {:d}\nCombined number of objects: {:d}'.format(num_images, len(self.obj_list)))
45 |
46 | def __getitem__(self, index):
47 |
48 | _db_ii = self.obj_list[index]["db_ii"]
49 | _obj_ii = self.obj_list[index]['obj_ii']
50 | sample = self.dataloaders[_db_ii].__getitem__(_obj_ii)
51 |
52 | if 'meta' in sample.keys():
53 | sample['meta']['db'] = str(self.dataloaders[_db_ii])
54 |
55 | return sample
56 |
57 | def __len__(self):
58 | return len(self.obj_list)
59 |
60 | def __str__(self):
61 | include_db = [str(db) for db in self.dataloaders]
62 |         exclude_db = [str(db) for db in self.excluded] if self.excluded else []
63 | return 'Included datasets:'+str(include_db)+'\n'+'Excluded datasets:'+str(exclude_db)
64 |
65 |
66 | if __name__ == "__main__":
67 | import matplotlib.pyplot as plt
68 | import dataloaders.helpers as helpers
69 | import dataloaders.pascal as pascal
70 | import dataloaders.sbd as sbd
71 | import torch
72 | import numpy as np
73 | import dataloaders.custom_transforms as tr
74 | from torchvision import transforms
75 |
76 | transform = transforms.Compose([tr.FixedResize({'image': (512, 512), 'gt': (512, 512)}),
77 | tr.ToTensor()])
78 |
79 | pascal_voc_val = pascal.VOCSegmentation(split='val', transform=transform, retname=True)
80 | sbd = sbd.SBDSegmentation(split=['train', 'val'], transform=transform, retname=True)
81 | pascal_voc_train = pascal.VOCSegmentation(split='train', transform=transform, retname=True)
82 |
83 | dataset = CombineDBs([pascal_voc_train, sbd], excluded=[pascal_voc_val])
84 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
85 |
86 | for ii, sample in enumerate(dataloader):
87 | for jj in range(sample["image"].size()[0]):
88 | plt.figure()
89 | max_img = np.max(sample["image"][jj].numpy())
90 | overlay = helpers.overlay_mask(helpers.tens2image(sample["image"][jj])/max_img, helpers.tens2image(sample["gt"][jj]))
91 | plt.imshow(overlay)
92 | plt.title(sample["meta"])
93 | if ii == 5:
94 | break
95 |
96 | plt.show(block=True)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ### etos-deepcut
3 |
4 | A tool for object segmentation from extreme points.
5 |
6 | ### About Deep Extreme Cut (DEXTR)
7 |
8 | Visit the [project page](http://www.vision.ee.ethz.ch/~cvlsegmentation/dextr) to access the paper and the pre-computed results.
9 |
10 | 
11 |
12 | This is based on the implementation of `Deep Extreme Cut (DEXTR)` for object segmentation from extreme points.
13 |
14 | ### Installation
15 |
16 | 0. Clone the repo:
17 | ```Shell
18 | git clone https://github.com/etosworld/etos-deepcut
19 | cd etos-deepcut
20 | ```
21 |
22 | 1. Install dependencies:
23 | ```Shell
24 | pip install -r requirements.txt
25 | ```
26 |
27 | 2. Download the model by running the script inside ```models/```:
28 | ```Shell
29 | cd models/
30 | chmod +x download_dextr_model.sh
31 | ./download_dextr_model.sh
32 | cd ..
33 | ```
34 | The default model is trained on PASCAL VOC Segmentation train + SBD (10582 images). To download models trained on PASCAL VOC Segmentation train or COCO, please visit the [project page](http://www.vision.ee.ethz.ch/~cvlsegmentation/dextr/#downloads), or keep scrolling to the end of this README.
35 |
36 | 3. To try the demo version of etos-deepcut, please run:
37 | ```Shell
38 | python demo.py
39 | ```
40 | If installed correctly, the result should look like this:
41 |

42 |
43 | ### image2txt
44 |
45 | We have implemented the [image2txt](https://github.com/etosworld/etos-img2txt) function: for each segmented object, an image and a text file are saved. Enjoy!
46 |
47 |  
48 |
49 |
50 | ### Training
51 |
52 | To train and evaluate etos-deepcut on PASCAL (or PASCAL + SBD), please follow these additional steps:
53 |
54 | 4. Install tensorboard (integrated with PyTorch).
55 | ```Shell
56 | pip install tensorboard tensorboardx
57 | ```
58 |
59 | 5. Download the pre-trained PSPNet model for semantic segmentation, taken from [this](https://github.com/isht7/pytorch-deeplab-resnet) repository.
60 | ```Shell
61 | cd models/
62 | chmod +x download_pretrained_psp_model.sh
63 | ./download_pretrained_psp_model.sh
64 | cd ..
65 | ```
66 | 6. Set the paths in ```mypath.py``` so that they point to the location of the PASCAL/SBD datasets.
67 |
68 | 7. Run ```python train_pascal.py``` after changing the default parameters if necessary (e.g. gpu_id).
69 |
70 | Enjoy!!
71 |
72 | ### Pre-trained models
73 |
74 | You can download the following DEXTR models, pre-trained on:
75 | * [PASCAL + SBD](https://data.vision.ee.ethz.ch/kmaninis/share/DEXTR/Downloads/models/dextr_pascal-sbd.pth), trained on PASCAL VOC Segmentation train + SBD (10582 images). Achieves mIoU of 91.5% on PASCAL VOC Segmentation val.
76 | * [PASCAL](https://data.vision.ee.ethz.ch/kmaninis/share/DEXTR/Downloads/models/dextr_pascal.pth), trained on PASCAL VOC Segmentation train (1464 images). Achieves mIoU of 90.5% on PASCAL VOC Segmentation val.
77 | * [COCO](https://data.vision.ee.ethz.ch/kmaninis/share/DEXTR/Downloads/models/dextr_coco.pth), trained on COCO train 2014 (82783 images). Achieves mIoU of 87.8% on PASCAL VOC Segmentation val.
78 |
79 | ### TODO
80 |
81 | - Support deep extreme video cut
82 |
83 | ### Citation
84 |
85 | @InProceedings{Man+18,
86 | Title = {Deep Extreme Cut: From Extreme Points to Object Segmentation},
87 | Author = {K.K. Maninis and S. Caelles and J. Pont-Tuset and L. {Van Gool}},
88 | Booktitle = {Computer Vision and Pattern Recognition (CVPR)},
89 | Year = {2018}
90 | }
91 |
92 | @InProceedings{Pap+17,
93 | Title = {Extreme clicking for efficient object annotation},
94 | Author = {D.P. Papadopoulos and J. Uijlings and F. Keller and V. Ferrari},
95 | Booktitle = {ICCV},
96 | Year = {2017}
97 | }
98 |
99 |
100 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from PIL import Image
5 | import numpy as np
6 | from matplotlib import pyplot as plt
7 |
8 | from torch.nn.functional import upsample
9 |
10 | import networks.deeplab_resnet as resnet
11 | from mypath import Path
12 | from dataloaders import helpers as helpers
13 |
14 | modelName = 'dextr_pascal-sbd'
15 | pad = 50
16 | thres = 0.8
17 | gpu_id = 0
18 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
19 |
20 | #img2txt
21 | CHAR_LIST = " ;:,^`'.*~oahkbdpqwmZO0QLCJUYXzcvunxrjft/\|()1{}[]?-_+\<>i!lI$@B%8&WM#"
22 | num_chars = len(CHAR_LIST)
23 | num_cols = 50
24 |
25 | # Create the network and load the weights
26 | net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
27 | print("Initializing weights from: {}".format(os.path.join(Path.models_dir(), modelName + '.pth')))
28 | state_dict_checkpoint = torch.load(os.path.join(Path.models_dir(), modelName + '.pth'),
29 | map_location=lambda storage, loc: storage)
30 | # Remove the prefix .module from the model when it is trained using DataParallel
31 | if 'module.' in list(state_dict_checkpoint.keys())[0]:
32 | new_state_dict = OrderedDict()
33 | for k, v in state_dict_checkpoint.items():
34 | name = k[7:] # remove `module.` from multi-gpu training
35 | new_state_dict[name] = v
36 | else:
37 | new_state_dict = state_dict_checkpoint
38 | net.load_state_dict(new_state_dict)
39 | net.eval()
40 | net.to(device)
41 |
42 | # Read image and click the points
43 | image = np.array(Image.open('ims/dog-cat.jpg'))
44 | plt.ion()
45 | plt.axis('off')
46 | plt.imshow(image)
47 | plt.title('Click the four extreme points of the objects\nHit enter when done (do not close the window)')
48 |
49 | results = []
50 |
51 | idx = 0
52 |
53 | with torch.no_grad():
54 | while 1:
55 |
56 |         extreme_points_ori = np.array(plt.ginput(4, timeout=0)).astype(int)
57 |
58 | # Crop image to the bounding box from the extreme points and resize
59 | bbox = helpers.get_bbox(image, points=extreme_points_ori, pad=pad, zero_pad=True)
60 | crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
61 | resize_image = helpers.fixed_resize(crop_image, (512, 512)).astype(np.float32)
62 |
63 | # Generate extreme point heat map normalized to image values
64 | extreme_points = extreme_points_ori - [np.min(extreme_points_ori[:, 0]), np.min(extreme_points_ori[:, 1])] + [pad,
65 | pad]
66 |         extreme_points = (512 * extreme_points * [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
67 | extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10)
68 | extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)
69 |
70 | # Concatenate inputs and convert to tensor
71 | input_dextr = np.concatenate((resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
72 | inputs = torch.from_numpy(input_dextr.transpose((2, 0, 1))[np.newaxis, ...])
73 |
74 | # Run a forward pass
75 | inputs = inputs.to(device)
76 | outputs = net.forward(inputs)
77 | outputs = upsample(outputs, size=(512, 512), mode='bilinear', align_corners=True)
78 | outputs = outputs.to(torch.device('cpu'))
79 |
80 | pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
81 | pred = 1 / (1 + np.exp(-pred))
82 | pred = np.squeeze(pred)
83 | result = helpers.crop2fullmask(pred, bbox, im_size=image.shape[:2], zero_pad=True, relax=pad) > thres
84 |
85 |         # save cut image and image2txt file
86 |         #cut_mask = result.astype(int)
87 |         height, width, _ = image.shape
88 |         out_img = np.zeros((height, width, 3), np.uint8)
89 |         out_img[result] = image[result]
90 |         #print(out_img)
91 |         out_gray = np.dot(out_img[..., :3], [0.299, 0.587, 0.114]).astype(np.uint8)
92 | cell_width = width / num_cols
93 | cell_height = 2 * cell_width
94 | num_rows = int(height / cell_height)
95 | if num_cols > width or num_rows > height:
96 | print("Too many columns or rows. Use default setting")
97 | cell_width = 6
98 | cell_height = 12
99 | num_cols = int(width / cell_width)
100 | num_rows = int(height / cell_height)
101 |
102 | idx = idx + 1
103 | output_file = open('out%d.txt'%idx, 'w')
104 | for i in range(num_rows):
105 | for j in range(num_cols):
106 | output_file.write(
107 | CHAR_LIST[min(int(np.mean(out_gray[int(i * cell_height):min(int((i + 1) * cell_height), height),
108 | int(j * cell_width):min(int((j + 1) * cell_width),
109 | width)]) * num_chars / 255), num_chars - 1)])
110 | output_file.write("\n")
111 | output_file.close()
112 |
113 |
114 | im = Image.fromarray(out_img)
115 | #im = Image.fromarray(out_gray)
116 |
117 | im.save("out%d.png"%idx)
118 |
119 | results.append(result)
120 |
121 | # Plot the results
122 | plt.imshow(helpers.overlay_masks(image / 255, results))
123 | plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1], 'gx')
124 |
--------------------------------------------------------------------------------
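
Note (not part of the repository): the image2txt block in demo.py averages the grayscale cut-out over fixed-size cells and maps each mean brightness to a character index. The standalone sketch below restates that mapping; to_ascii is a hypothetical helper, not a function of this repo.

    import numpy as np

    CHAR_LIST = " ;:,^`'.*~oahkbdpqwmZO0QLCJUYXzcvunxrjft/\\|()1{}[]?-_+\\<>i!lI$@B%8&WM#"

    def to_ascii(gray, num_cols=50):
        """Convert a uint8 grayscale image into lines of characters, one per cell."""
        height, width = gray.shape
        cell_w = width / num_cols
        cell_h = 2 * cell_w                              # cells are twice as tall as they are wide
        num_rows = int(height / cell_h)
        # demo.py additionally falls back to 6x12-pixel cells when the image is too small for 50 columns
        lines = []
        for i in range(num_rows):
            row = ""
            for j in range(num_cols):
                cell = gray[int(i * cell_h):min(int((i + 1) * cell_h), height),
                            int(j * cell_w):min(int((j + 1) * cell_w), width)]
                # brighter cells are mapped to characters further along CHAR_LIST
                idx = min(int(cell.mean() * len(CHAR_LIST) / 255), len(CHAR_LIST) - 1)
                row += CHAR_LIST[idx]
            lines.append(row)
        return "\n".join(lines)
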
/dataloaders/sbd.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import torch, cv2
4 | import errno
5 | import hashlib
6 | import json
7 | import os
8 | import sys
9 | import tarfile
10 |
11 | import numpy as np
12 | import scipy.io
13 | import torch.utils.data as data
14 | from PIL import Image
15 | from six.moves import urllib
16 | from mypath import Path
17 |
18 |
19 | class SBDSegmentation(data.Dataset):
20 |
21 | URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
22 | FILE = "benchmark.tgz"
23 | MD5 = '82b4d87ceb2ed10f6038a1cba92111cb'
24 |
25 | def __init__(self,
26 | root=Path.db_root_dir('sbd'),
27 | split='val',
28 | transform=None,
29 | download=False,
30 | preprocess=False,
31 | area_thres=0,
32 | retname=True):
33 |
34 | # Store parameters
35 | self.root = root
36 | self.transform = transform
37 | if isinstance(split, str):
38 | self.split = [split]
39 | else:
40 | split.sort()
41 | self.split = split
42 | self.area_thres = area_thres
43 | self.retname = retname
44 |
45 | # Where to find things according to the author's structure
46 | self.dataset_dir = os.path.join(self.root, 'benchmark_RELEASE', 'dataset')
47 | _mask_dir = os.path.join(self.dataset_dir, 'inst')
48 | _image_dir = os.path.join(self.dataset_dir, 'img')
49 |
50 | if self.area_thres != 0:
51 | self.obj_list_file = os.path.join(self.dataset_dir, '_'.join(self.split) + '_instances_area_thres-' +
52 | str(area_thres) + '.txt')
53 | else:
54 | self.obj_list_file = os.path.join(self.dataset_dir, '_'.join(self.split) + '_instances' + '.txt')
55 |
56 | # Download dataset?
57 | if download:
58 | self._download()
59 | if not self._check_integrity():
60 | raise RuntimeError('Dataset file downloaded is corrupted.')
61 |
62 | # Get list of all images from the split and check that the files exist
63 | self.im_ids = []
64 | self.images = []
65 | self.masks = []
66 | for splt in self.split:
67 | with open(os.path.join(self.dataset_dir, splt+'.txt'), "r") as f:
68 | lines = f.read().splitlines()
69 |
70 | for line in lines:
71 | _image = os.path.join(_image_dir, line + ".jpg")
72 | _mask = os.path.join(_mask_dir, line + ".mat")
73 | assert os.path.isfile(_image)
74 | assert os.path.isfile(_mask)
75 | self.im_ids.append(line)
76 | self.images.append(_image)
77 | self.masks.append(_mask)
78 |
79 | assert (len(self.images) == len(self.masks))
80 |
81 | # Precompute the list of objects and their categories for each image
82 | if (not self._check_preprocess()) or preprocess:
83 | print('Preprocessing SBD dataset, this will take long, but it will be done only once.')
84 | self._preprocess()
85 |
86 | # Build the list of objects
87 | self.obj_list = []
88 | num_images = 0
89 | for ii in range(len(self.im_ids)):
90 | if self.im_ids[ii] in self.obj_dict.keys():
91 | flag = False
92 | for jj in range(len(self.obj_dict[self.im_ids[ii]])):
93 | if self.obj_dict[self.im_ids[ii]][jj] != -1:
94 | self.obj_list.append([ii, jj])
95 | flag = True
96 | if flag:
97 | num_images += 1
98 |
99 | # Display stats
100 | print('Number of images: {:d}\nNumber of objects: {:d}'.format(num_images, len(self.obj_list)))
101 |
102 | def __getitem__(self, index):
103 |
104 | _img, _target = self._make_img_gt_point_pair(index)
105 | _void_pixels = (_target == 255).astype(np.float32)
106 | sample = {'image': _img, 'gt': _target, 'void_pixels': _void_pixels}
107 |
108 | if self.retname:
109 | _im_ii = self.obj_list[index][0]
110 | _obj_ii = self.obj_list[index][1]
111 | sample['meta'] = {'image': str(self.im_ids[_im_ii]),
112 | 'object': str(_obj_ii),
113 | 'im_size': (_img.shape[0], _img.shape[1]),
114 | 'category': self.obj_dict[self.im_ids[_im_ii]][_obj_ii]}
115 |
116 | if self.transform is not None:
117 | sample = self.transform(sample)
118 |
119 | return sample
120 |
121 | def __len__(self):
122 | return len(self.obj_list)
123 |
124 | def _check_integrity(self):
125 | _fpath = os.path.join(self.root, self.FILE)
126 | if not os.path.isfile(_fpath):
127 | print("{} does not exist".format(_fpath))
128 | return False
129 | _md5c = hashlib.md5(open(_fpath, 'rb').read()).hexdigest()
130 | if _md5c != self.MD5:
131 | print(" MD5({}) did not match MD5({}) expected for {}".format(
132 | _md5c, self.MD5, _fpath))
133 | return False
134 | return True
135 |
136 | def _check_preprocess(self):
137 | # Check that the file with categories is there and with correct size
138 | _obj_list_file = self.obj_list_file
139 | if not os.path.isfile(_obj_list_file):
140 | return False
141 | else:
142 | self.obj_dict = json.load(open(_obj_list_file, 'r'))
143 | return list(np.sort([str(x) for x in self.obj_dict.keys()])) == list(np.sort(self.im_ids))
144 |
145 | def _preprocess(self):
146 | # Get all object instances and their category
147 | self.obj_dict = {}
148 | obj_counter = 0
149 | for ii in range(len(self.im_ids)):
150 | # Read object masks and get number of objects
151 | tmp = scipy.io.loadmat(self.masks[ii])
152 | _mask = tmp["GTinst"][0]["Segmentation"][0]
153 | _cat_ids = tmp["GTinst"][0]["Categories"][0].astype(int)
154 |
155 | _mask_ids = np.unique(_mask)
156 | n_obj = _mask_ids[-1]
157 | assert(n_obj == len(_cat_ids))
158 |
159 | for jj in range(n_obj):
160 | temp = np.where(_mask == jj + 1)
161 | obj_area = len(temp[0])
162 | if obj_area < self.area_thres:
163 | _cat_ids[jj] = -1
164 | obj_counter += 1
165 |
166 | self.obj_dict[self.im_ids[ii]] = np.squeeze(_cat_ids, 1).tolist()
167 |
168 | # Save it to file for future reference
169 | with open(self.obj_list_file, 'w') as outfile:
170 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.im_ids[0], json.dumps(self.obj_dict[self.im_ids[0]])))
171 | for ii in range(1, len(self.im_ids)):
172 | outfile.write(',\n\t"{:s}": {:s}'.format(self.im_ids[ii], json.dumps(self.obj_dict[self.im_ids[ii]])))
173 | outfile.write('\n}\n')
174 |
175 | print('Pre-processing finished')
176 |
177 | def _download(self):
178 | _fpath = os.path.join(self.root, self.FILE)
179 |
180 | try:
181 | os.makedirs(self.root)
182 | except OSError as e:
183 | if e.errno == errno.EEXIST:
184 | pass
185 | else:
186 | raise
187 |
188 | if self._check_integrity():
189 | print('Files already downloaded and verified')
190 | return
191 | else:
192 | print('Downloading ' + self.URL + ' to ' + _fpath)
193 |
194 | def _progress(count, block_size, total_size):
195 | sys.stdout.write('\r>> %s %.1f%%' %
196 | (_fpath, float(count * block_size) /
197 | float(total_size) * 100.0))
198 | sys.stdout.flush()
199 |
200 | urllib.request.urlretrieve(self.URL, _fpath, _progress)
201 |
202 | # extract file
203 | cwd = os.getcwd()
204 | print('Extracting tar file')
205 | tar = tarfile.open(_fpath)
206 | os.chdir(self.root)
207 | tar.extractall()
208 | tar.close()
209 | os.chdir(cwd)
210 | print('Done!')
211 |
212 | def _make_img_gt_point_pair(self, index):
213 | _im_ii = self.obj_list[index][0]
214 | _obj_ii = self.obj_list[index][1]
215 |
216 | # Read Image
217 | _img = np.array(Image.open(self.images[_im_ii]).convert('RGB')).astype(np.float32)
218 |
219 |         # Read target object
220 | _tmp = scipy.io.loadmat(self.masks[_im_ii])["GTinst"][0]["Segmentation"][0]
221 | _target = (_tmp == (_obj_ii + 1)).astype(np.float32)
222 |
223 | return _img, _target
224 |
225 | def __str__(self):
226 | return 'SBDSegmentation(split='+str(self.split)+', area_thres='+str(self.area_thres)+')'
227 |
228 |
229 | if __name__ == '__main__':
230 | import matplotlib.pyplot as plt
231 | import dataloaders.helpers as helpers
232 | import torch
233 | import torchvision.transforms as transforms
234 | import dataloaders.custom_transforms as tr
235 |
236 | transform = transforms.Compose([tr.ToTensor()])
237 | dataset = SBDSegmentation(transform=transform)
238 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
239 |
240 | for i, data in enumerate(dataloader):
241 | plt.figure()
242 | plt.imshow(helpers.tens2image(data['image'])/255)
243 | plt.figure()
244 | plt.imshow(helpers.tens2image(data['gt'])[:, :, 0])
245 | if i == 10:
246 | break
247 |
248 | plt.show(block=True)
249 |
--------------------------------------------------------------------------------
/dataloaders/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import torch, cv2
2 |
3 | import numpy.random as random
4 | import numpy as np
5 | import dataloaders.helpers as helpers
6 |
7 |
8 | class ScaleNRotate(object):
9 | """Scale (zoom-in, zoom-out) and Rotate the image and the ground truth.
10 | Args:
11 | two possibilities:
12 | 1. rots (tuple): (minimum, maximum) rotation angle
13 | scales (tuple): (minimum, maximum) scale
14 | 2. rots [list]: list of fixed possible rotation angles
15 | scales [list]: list of fixed possible scales
16 | """
17 | def __init__(self, rots=(-30, 30), scales=(.75, 1.25), semseg=False):
18 | assert (isinstance(rots, type(scales)))
19 | self.rots = rots
20 | self.scales = scales
21 | self.semseg = semseg
22 |
23 | def __call__(self, sample):
24 |
25 | if type(self.rots) == tuple:
26 | # Continuous range of scales and rotations
27 | rot = (self.rots[1] - self.rots[0]) * random.random() - \
28 | (self.rots[1] - self.rots[0])/2
29 |
30 | sc = (self.scales[1] - self.scales[0]) * random.random() - \
31 | (self.scales[1] - self.scales[0]) / 2 + 1
32 | elif type(self.rots) == list:
33 | # Fixed range of scales and rotations
34 | rot = self.rots[random.randint(0, len(self.rots))]
35 | sc = self.scales[random.randint(0, len(self.scales))]
36 |
37 | for elem in sample.keys():
38 | if 'meta' in elem:
39 | continue
40 |
41 | tmp = sample[elem]
42 |
43 | h, w = tmp.shape[:2]
44 | center = (w / 2, h / 2)
45 | assert(center != 0) # Strange behaviour warpAffine
46 | M = cv2.getRotationMatrix2D(center, rot, sc)
47 |
48 | if ((tmp == 0) | (tmp == 1)).all():
49 | flagval = cv2.INTER_NEAREST
50 | elif 'gt' in elem and self.semseg:
51 | flagval = cv2.INTER_NEAREST
52 | else:
53 | flagval = cv2.INTER_CUBIC
54 | tmp = cv2.warpAffine(tmp, M, (w, h), flags=flagval)
55 |
56 | sample[elem] = tmp
57 |
58 | return sample
59 |
60 | def __str__(self):
61 | return 'ScaleNRotate:(rot='+str(self.rots)+',scale='+str(self.scales)+')'
62 |
63 |
64 | class FixedResize(object):
65 | """Resize the image and the ground truth to specified resolution.
66 | Args:
67 | resolutions (dict): the list of resolutions
68 | """
69 | def __init__(self, resolutions=None, flagvals=None):
70 | self.resolutions = resolutions
71 | self.flagvals = flagvals
72 | if self.flagvals is not None:
73 | assert(len(self.resolutions) == len(self.flagvals))
74 |
75 | def __call__(self, sample):
76 |
77 | # Fixed range of scales
78 | if self.resolutions is None:
79 | return sample
80 |
81 | elems = list(sample.keys())
82 |
83 | for elem in elems:
84 |
85 | if 'meta' in elem or 'bbox' in elem or ('extreme_points_coord' in elem and elem not in self.resolutions):
86 | continue
87 | if 'extreme_points_coord' in elem and elem in self.resolutions:
88 | bbox = sample['bbox']
89 | crop_size = np.array([bbox[3]-bbox[1]+1, bbox[4]-bbox[2]+1])
90 | res = np.array(self.resolutions[elem]).astype(np.float32)
91 |                 sample[elem] = np.round(sample[elem]*res/crop_size).astype(int)
92 | continue
93 | if elem in self.resolutions:
94 | if self.resolutions[elem] is None:
95 | continue
96 | if isinstance(sample[elem], list):
97 | if sample[elem][0].ndim == 3:
98 | output_size = np.append(self.resolutions[elem], [3, len(sample[elem])])
99 | else:
100 | output_size = np.append(self.resolutions[elem], len(sample[elem]))
101 | tmp = sample[elem]
102 | sample[elem] = np.zeros(output_size, dtype=np.float32)
103 | for ii, crop in enumerate(tmp):
104 | if self.flagvals is None:
105 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem])
106 | else:
107 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem], flagval=self.flagvals[elem])
108 | else:
109 | if self.flagvals is None:
110 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem])
111 | else:
112 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem], flagval=self.flagvals[elem])
113 | else:
114 | del sample[elem]
115 |
116 | return sample
117 |
118 | def __str__(self):
119 | return 'FixedResize:'+str(self.resolutions)
120 |
121 |
122 | class RandomHorizontalFlip(object):
123 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5."""
124 |
125 | def __call__(self, sample):
126 |
127 | if random.random() < 0.5:
128 | for elem in sample.keys():
129 | if 'meta' in elem:
130 | continue
131 | tmp = sample[elem]
132 | tmp = cv2.flip(tmp, flipCode=1)
133 | sample[elem] = tmp
134 |
135 | return sample
136 |
137 | def __str__(self):
138 | return 'RandomHorizontalFlip'
139 |
140 |
141 | class ExtremePoints(object):
142 | """
143 | Returns the four extreme points (left, right, top, bottom) (with some random perturbation) in a given binary mask
144 | sigma: sigma of Gaussian to create a heatmap from a point
145 |     pert: number of pixels of the maximum perturbation
146 | elem: which element of the sample to choose as the binary mask
147 | """
148 | def __init__(self, sigma=10, pert=0, elem='gt'):
149 | self.sigma = sigma
150 | self.pert = pert
151 | self.elem = elem
152 |
153 | def __call__(self, sample):
154 | if sample[self.elem].ndim == 3:
155 |             raise ValueError('ExtremePoints not implemented for multiple objects per image.')
156 | _target = sample[self.elem]
157 | if np.max(_target) == 0:
158 | sample['extreme_points'] = np.zeros(_target.shape, dtype=_target.dtype) # TODO: handle one_mask_per_point case
159 | else:
160 | _points = helpers.extreme_points(_target, self.pert)
161 | sample['extreme_points'] = helpers.make_gt(_target, _points, sigma=self.sigma, one_mask_per_point=False)
162 |
163 | return sample
164 |
165 | def __str__(self):
166 | return 'ExtremePoints:(sigma='+str(self.sigma)+', pert='+str(self.pert)+', elem='+str(self.elem)+')'
167 |
168 |
169 | class ConcatInputs(object):
170 |
171 | def __init__(self, elems=('image', 'point')):
172 | self.elems = elems
173 |
174 | def __call__(self, sample):
175 |
176 | res = sample[self.elems[0]]
177 |
178 | for elem in self.elems[1:]:
179 | assert(sample[self.elems[0]].shape[:2] == sample[elem].shape[:2])
180 |
181 | # Check if third dimension is missing
182 | tmp = sample[elem]
183 | if tmp.ndim == 2:
184 | tmp = tmp[:, :, np.newaxis]
185 |
186 | res = np.concatenate((res, tmp), axis=2)
187 |
188 | sample['concat'] = res
189 |
190 | return sample
191 |
192 | def __str__(self):
193 |         return 'ConcatInputs:'+str(self.elems)
194 |
195 |
196 | class CropFromMask(object):
197 | """
198 | Returns image cropped in bounding box from a given mask
199 | """
200 | def __init__(self, crop_elems=('image', 'gt'),
201 | mask_elem='gt',
202 | relax=0,
203 | zero_pad=False):
204 |
205 | self.crop_elems = crop_elems
206 | self.mask_elem = mask_elem
207 | self.relax = relax
208 | self.zero_pad = zero_pad
209 |
210 | def __call__(self, sample):
211 | _target = sample[self.mask_elem]
212 | if _target.ndim == 2:
213 | _target = np.expand_dims(_target, axis=-1)
214 | for elem in self.crop_elems:
215 | _img = sample[elem]
216 | _crop = []
217 | if self.mask_elem == elem:
218 | if _img.ndim == 2:
219 | _img = np.expand_dims(_img, axis=-1)
220 | for k in range(0, _target.shape[-1]):
221 | _tmp_img = _img[..., k]
222 | _tmp_target = _target[..., k]
223 | if np.max(_target[..., k]) == 0:
224 | _crop.append(np.zeros(_tmp_img.shape, dtype=_img.dtype))
225 | else:
226 | _crop.append(helpers.crop_from_mask(_tmp_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad))
227 | else:
228 | for k in range(0, _target.shape[-1]):
229 | if np.max(_target[..., k]) == 0:
230 | _crop.append(np.zeros(_img.shape, dtype=_img.dtype))
231 | else:
232 | _tmp_target = _target[..., k]
233 | _crop.append(helpers.crop_from_mask(_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad))
234 | if len(_crop) == 1:
235 | sample['crop_' + elem] = _crop[0]
236 | else:
237 | sample['crop_' + elem] = _crop
238 | return sample
239 |
240 | def __str__(self):
241 | return 'CropFromMask:(crop_elems='+str(self.crop_elems)+', mask_elem='+str(self.mask_elem)+\
242 | ', relax='+str(self.relax)+',zero_pad='+str(self.zero_pad)+')'
243 |
244 |
245 | class ToImage(object):
246 | """
247 |     Scale the given elements to the range [0, 255]
248 | """
249 | def __init__(self, norm_elem='image', custom_max=255.):
250 | self.norm_elem = norm_elem
251 | self.custom_max = custom_max
252 |
253 | def __call__(self, sample):
254 | if isinstance(self.norm_elem, tuple):
255 | for elem in self.norm_elem:
256 | tmp = sample[elem]
257 | sample[elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10)
258 | else:
259 | tmp = sample[self.norm_elem]
260 | sample[self.norm_elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10)
261 | return sample
262 |
263 | def __str__(self):
264 | return 'NormalizeImage'
265 |
266 |
267 | class ToTensor(object):
268 | """Convert ndarrays in sample to Tensors."""
269 |
270 | def __call__(self, sample):
271 |
272 | for elem in sample.keys():
273 | if 'meta' in elem:
274 | continue
275 | elif 'bbox' in elem:
276 | tmp = sample[elem]
277 | sample[elem] = torch.from_numpy(tmp)
278 | continue
279 |
280 | tmp = sample[elem]
281 |
282 | if tmp.ndim == 2:
283 | tmp = tmp[:, :, np.newaxis]
284 |
285 | # swap color axis because
286 | # numpy image: H x W x C
287 | # torch image: C X H X W
288 | tmp = tmp.transpose((2, 0, 1))
289 | sample[elem] = torch.from_numpy(tmp)
290 |
291 | return sample
292 |
293 | def __str__(self):
294 | return 'ToTensor'
295 |
--------------------------------------------------------------------------------
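
Note (not part of the repository): a minimal sketch of chaining these transforms on a single sample dict, mirroring the composition used in train_pascal.py; the synthetic image and mask below are placeholders.

    import numpy as np
    from torchvision import transforms
    import dataloaders.custom_transforms as tr

    # Synthetic sample: an RGB image and a binary mask containing one square object.
    image = (np.random.rand(300, 400, 3) * 255).astype(np.float32)
    gt = np.zeros((300, 400), dtype=np.float32)
    gt[100:200, 150:250] = 1

    composed = transforms.Compose([
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=50, zero_pad=True),
        tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
        tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()])

    sample = composed({'image': image, 'gt': gt})
    print(sample['concat'].shape)   # torch.Size([4, 512, 512]): RGB + extreme-point heatmap
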
/train_pascal.py:
--------------------------------------------------------------------------------
1 | import socket
2 | import timeit
3 | from datetime import datetime
4 | import scipy.misc as sm
5 | from collections import OrderedDict
6 | import glob
7 |
8 | # PyTorch includes
9 | import torch.optim as optim
10 | from torchvision import transforms
11 | from torch.utils.data import DataLoader
12 | from torch.nn.functional import upsample
13 |
14 | # Tensorboard include
15 | from tensorboardX import SummaryWriter
16 |
17 | # Custom includes
18 | from dataloaders.combine_dbs import CombineDBs as combine_dbs
19 | import dataloaders.pascal as pascal
20 | import dataloaders.sbd as sbd
21 | from dataloaders import custom_transforms as tr
22 | import networks.deeplab_resnet as resnet
23 | from layers.loss import class_balanced_cross_entropy_loss
24 | from dataloaders.helpers import *
25 |
26 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu
27 | gpu_id = 0
28 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
29 | if torch.cuda.is_available():
30 | print('Using GPU: {} '.format(gpu_id))
31 |
32 | # Setting parameters
33 | use_sbd = False
34 | nEpochs = 100 # Number of epochs for training
35 | resume_epoch = 0 # Default is 0, change if want to resume
36 |
37 | p = OrderedDict() # Parameters to include in report
38 | classifier = 'psp' # Head classifier to use
39 | p['trainBatch'] = 5 # Training batch size
40 | testBatch = 5 # Testing batch size
41 | useTest = 1 # See evolution of the test set when training?
42 | nTestInterval = 10 # Run on test set every nTestInterval epochs
43 | snapshot = 20 # Store a model every snapshot epochs
44 | relax_crop = 50 # Enlarge the bounding box by relax_crop pixels
45 | nInputChannels = 4 # Number of input channels (RGB + heatmap of extreme points)
46 | zero_pad_crop = True # Insert zero padding when cropping the image
47 | p['nAveGrad'] = 1 # Average the gradient of several iterations
48 | p['lr'] = 1e-8 # Learning rate
49 | p['wd'] = 0.0005 # Weight decay
50 | p['momentum'] = 0.9 # Momentum
51 |
52 | # Results and model directories (a new directory is generated for every run)
53 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
54 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
55 | if resume_epoch == 0:
56 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*')))
57 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
58 | else:
59 | run_id = 0
60 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id))
61 | if not os.path.exists(os.path.join(save_dir, 'models')):
62 | os.makedirs(os.path.join(save_dir, 'models'))
63 |
64 | # Network definition
65 | modelName = 'dextr_pascal'
66 | net = resnet.resnet101(1, pretrained=True, nInputChannels=nInputChannels, classifier=classifier)
67 | if resume_epoch == 0:
68 | print("Initializing from pretrained Deeplab-v2 model")
69 | else:
70 | print("Initializing weights from: {}".format(
71 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')))
72 | net.load_state_dict(
73 | torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'),
74 | map_location=lambda storage, loc: storage))
75 | train_params = [{'params': resnet.get_1x_lr_params(net), 'lr': p['lr']},
76 | {'params': resnet.get_10x_lr_params(net), 'lr': p['lr'] * 10}]
77 |
78 | net.to(device)
79 |
80 | # Training the network
81 | if resume_epoch != nEpochs:
82 | # Logging into Tensorboard
83 | log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
84 | writer = SummaryWriter(log_dir=log_dir)
85 |
86 | # Use the following optimizer
87 | optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
88 | p['optimizer'] = str(optimizer)
89 |
90 | # Preparation of the data loaders
91 | composed_transforms_tr = transforms.Compose([
92 | tr.RandomHorizontalFlip(),
93 | tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
94 | tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
95 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
96 | tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt'),
97 | tr.ToImage(norm_elem='extreme_points'),
98 | tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
99 | tr.ToTensor()])
100 | composed_transforms_ts = transforms.Compose([
101 | tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
102 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
103 | tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
104 | tr.ToImage(norm_elem='extreme_points'),
105 | tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
106 | tr.ToTensor()])
107 |
108 | voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr)
109 | voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
110 |
111 | if use_sbd:
112 | sbd = sbd.SBDSegmentation(split=['train', 'val'], transform=composed_transforms_tr, retname=True)
113 | db_train = combine_dbs([voc_train, sbd], excluded=[voc_val])
114 | else:
115 | db_train = voc_train
116 |
117 | p['dataset_train'] = str(db_train)
118 | p['transformations_train'] = [str(tran) for tran in composed_transforms_tr.transforms]
119 | p['dataset_test'] = str(db_train)
120 | p['transformations_test'] = [str(tran) for tran in composed_transforms_ts.transforms]
121 |
122 | trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=2)
123 | testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=2)
124 |
125 | generate_param_report(os.path.join(save_dir, exp_name + '.txt'), p)
126 |
127 | # Train variables
128 | num_img_tr = len(trainloader)
129 | num_img_ts = len(testloader)
130 | running_loss_tr = 0.0
131 | running_loss_ts = 0.0
132 | aveGrad = 0
133 | print("Training Network")
134 | # Main Training and Testing Loop
135 | for epoch in range(resume_epoch, nEpochs):
136 | start_time = timeit.default_timer()
137 |
138 | net.train()
139 | for ii, sample_batched in enumerate(trainloader):
140 |
141 | inputs, gts = sample_batched['concat'], sample_batched['crop_gt']
142 |
143 | # Forward-Backward of the mini-batch
144 | inputs.requires_grad_()
145 | inputs, gts = inputs.to(device), gts.to(device)
146 |
147 | output = net.forward(inputs)
148 | output = upsample(output, size=(512, 512), mode='bilinear', align_corners=True)
149 |
150 | # Compute the losses, side outputs and fuse
151 | loss = class_balanced_cross_entropy_loss(output, gts, size_average=False, batch_average=True)
152 | running_loss_tr += loss.item()
153 |
154 | # Print stuff
155 | if ii % num_img_tr == num_img_tr - 1:
156 | running_loss_tr = running_loss_tr / num_img_tr
157 | writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
158 | print('[Epoch: %d, numImages: %5d]' % (epoch, ii*p['trainBatch']+inputs.data.shape[0]))
159 | print('Loss: %f' % running_loss_tr)
160 | running_loss_tr = 0
161 | stop_time = timeit.default_timer()
162 | print("Execution time: " + str(stop_time - start_time)+"\n")
163 |
164 | # Backward the averaged gradient
165 | loss /= p['nAveGrad']
166 | loss.backward()
167 | aveGrad += 1
168 |
169 | # Update the weights once in p['nAveGrad'] forward passes
170 | if aveGrad % p['nAveGrad'] == 0:
171 | writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch)
172 | optimizer.step()
173 | optimizer.zero_grad()
174 | aveGrad = 0
175 |
176 | # Save the model
177 | if (epoch % snapshot) == snapshot - 1 and epoch != 0:
178 | torch.save(net.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
179 |
180 | # One testing epoch
181 | if useTest and epoch % nTestInterval == (nTestInterval - 1):
182 | net.eval()
183 | with torch.no_grad():
184 | for ii, sample_batched in enumerate(testloader):
185 | inputs, gts = sample_batched['concat'], sample_batched['crop_gt']
186 |
187 | # Forward pass of the mini-batch
188 | inputs, gts = inputs.to(device), gts.to(device)
189 |
190 | output = net.forward(inputs)
191 | output = upsample(output, size=(512, 512), mode='bilinear', align_corners=True)
192 |
193 | # Compute the losses, side outputs and fuse
194 | loss = class_balanced_cross_entropy_loss(output, gts, size_average=False)
195 | running_loss_ts += loss.item()
196 |
197 | # Print stuff
198 | if ii % num_img_ts == num_img_ts - 1:
199 | running_loss_ts = running_loss_ts / num_img_ts
200 | print('[Epoch: %d, numImages: %5d]' % (epoch, ii*testBatch+inputs.data.shape[0]))
201 | writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch)
202 | print('Loss: %f' % running_loss_ts)
203 | running_loss_ts = 0
204 |
205 | writer.close()
206 |
207 | # Generate result of the validation images
208 | net.eval()
209 | composed_transforms_ts = transforms.Compose([
210 | tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
211 | tr.FixedResize(resolutions={'gt': None, 'crop_image': (512, 512), 'crop_gt': (512, 512)}),
212 | tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
213 | tr.ToImage(norm_elem='extreme_points'),
214 | tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
215 | tr.ToTensor()])
216 | db_test = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, retname=True)
217 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
218 |
219 | save_dir_res = os.path.join(save_dir, 'Results')
220 | if not os.path.exists(save_dir_res):
221 | os.makedirs(save_dir_res)
222 |
223 | print('Testing Network')
224 | with torch.no_grad():
225 | # Main Testing Loop
226 | for ii, sample_batched in enumerate(testloader):
227 |
228 | inputs, gts, metas = sample_batched['concat'], sample_batched['gt'], sample_batched['meta']
229 |
230 | # Forward of the mini-batch
231 | inputs = inputs.to(device)
232 |
233 | outputs = net.forward(inputs)
234 | outputs = upsample(outputs, size=(512, 512), mode='bilinear', align_corners=True)
235 | outputs = outputs.to(torch.device('cpu'))
236 |
237 | for jj in range(int(inputs.size()[0])):
238 | pred = np.transpose(outputs.data.numpy()[jj, :, :, :], (1, 2, 0))
239 | pred = 1 / (1 + np.exp(-pred))
240 | pred = np.squeeze(pred)
241 | gt = tens2image(gts[jj, :, :, :])
242 | bbox = get_bbox(gt, pad=relax_crop, zero_pad=zero_pad_crop)
243 | result = crop2fullmask(pred, bbox, gt, zero_pad=zero_pad_crop, relax=relax_crop)
244 |
245 | # Save the result, attention to the index jj
246 | sm.imsave(os.path.join(save_dir_res, metas['image'][jj] + '-' + metas['object'][jj] + '.png'), result)
247 |
--------------------------------------------------------------------------------
/dataloaders/pascal.py:
--------------------------------------------------------------------------------
1 | import torch, cv2
2 | import errno
3 | import hashlib
4 | import os
5 | import sys
6 | import tarfile
7 | import numpy as np
8 |
9 | import torch.utils.data as data
10 | from PIL import Image
11 | from six.moves import urllib
12 | import json
13 | from mypath import Path
14 |
15 |
16 | class VOCSegmentation(data.Dataset):
17 |
18 | URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
19 | FILE = "VOCtrainval_11-May-2012.tar"
20 | MD5 = '6cd6e144f989b92b3379bac3b3de84fd'
21 | BASE_DIR = 'VOCdevkit/VOC2012'
22 |
23 | category_names = ['background',
24 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
25 | 'bus', 'car', 'cat', 'chair', 'cow',
26 | 'diningtable', 'dog', 'horse', 'motorbike', 'person',
27 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
28 |
29 | def __init__(self,
30 | root=Path.db_root_dir('pascal'),
31 | split='val',
32 | transform=None,
33 | download=False,
34 | preprocess=False,
35 | area_thres=0,
36 | retname=True,
37 | suppress_void_pixels=True,
38 | default=False):
39 |
40 | self.root = root
41 | _voc_root = os.path.join(self.root, self.BASE_DIR)
42 | _mask_dir = os.path.join(_voc_root, 'SegmentationObject')
43 | _cat_dir = os.path.join(_voc_root, 'SegmentationClass')
44 | _image_dir = os.path.join(_voc_root, 'JPEGImages')
45 | self.transform = transform
46 | if isinstance(split, str):
47 | self.split = [split]
48 | else:
49 | split.sort()
50 | self.split = split
51 | self.area_thres = area_thres
52 | self.retname = retname
53 | self.suppress_void_pixels = suppress_void_pixels
54 | self.default = default
55 |
56 | # Build the ids file
57 | area_th_str = ""
58 | if self.area_thres != 0:
59 | area_th_str = '_area_thres-' + str(area_thres)
60 |
61 | self.obj_list_file = os.path.join(self.root, self.BASE_DIR, 'ImageSets', 'Segmentation',
62 | '_'.join(self.split) + '_instances' + area_th_str + '.txt')
63 |
64 | if download:
65 | self._download()
66 |
67 | if not self._check_integrity():
68 | raise RuntimeError('Dataset not found or corrupted.' +
69 | ' You can use download=True to download it')
70 |
71 | # train/val/test splits are pre-cut
72 | _splits_dir = os.path.join(_voc_root, 'ImageSets', 'Segmentation')
73 |
74 | self.im_ids = []
75 | self.images = []
76 | self.categories = []
77 | self.masks = []
78 |
79 | for splt in self.split:
80 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f:
81 | lines = f.read().splitlines()
82 |
83 | for ii, line in enumerate(lines):
84 | _image = os.path.join(_image_dir, line + ".jpg")
85 | _cat = os.path.join(_cat_dir, line + ".png")
86 | _mask = os.path.join(_mask_dir, line + ".png")
87 | assert os.path.isfile(_image)
88 | assert os.path.isfile(_cat)
89 | assert os.path.isfile(_mask)
90 | self.im_ids.append(line.rstrip('\n'))
91 | self.images.append(_image)
92 | self.categories.append(_cat)
93 | self.masks.append(_mask)
94 |
95 | assert (len(self.images) == len(self.masks))
96 | assert (len(self.images) == len(self.categories))
97 |
98 | # Precompute the list of objects and their categories for each image
99 | if (not self._check_preprocess()) or preprocess:
100 |             print('Preprocessing the PASCAL VOC dataset; this will take a while, but it only needs to be done once.')
101 | self._preprocess()
102 |
103 | # Build the list of objects
104 | self.obj_list = []
105 | num_images = 0
106 | for ii in range(len(self.im_ids)):
107 | flag = False
108 | for jj in range(len(self.obj_dict[self.im_ids[ii]])):
109 | if self.obj_dict[self.im_ids[ii]][jj] != -1:
110 | self.obj_list.append([ii, jj])
111 | flag = True
112 | if flag:
113 | num_images += 1
114 |
115 | # Display stats
116 | print('Number of images: {:d}\nNumber of objects: {:d}'.format(num_images, len(self.obj_list)))
117 |
118 | def __getitem__(self, index):
119 | _img, _target, _void_pixels, _, _, _ = self._make_img_gt_point_pair(index)
120 | sample = {'image': _img, 'gt': _target, 'void_pixels': _void_pixels}
121 |
122 | if self.retname:
123 | _im_ii = self.obj_list[index][0]
124 | _obj_ii = self.obj_list[index][1]
125 | sample['meta'] = {'image': str(self.im_ids[_im_ii]),
126 | 'object': str(_obj_ii),
127 | 'category': self.obj_dict[self.im_ids[_im_ii]][_obj_ii],
128 | 'im_size': (_img.shape[0], _img.shape[1])}
129 |
130 | if self.transform is not None:
131 | sample = self.transform(sample)
132 |
133 | return sample
134 |
135 | def __len__(self):
136 | return len(self.obj_list)
137 |
138 | def _check_integrity(self):
139 | _fpath = os.path.join(self.root, self.FILE)
140 | if not os.path.isfile(_fpath):
141 | print("{} does not exist".format(_fpath))
142 | return False
143 | _md5c = hashlib.md5(open(_fpath, 'rb').read()).hexdigest()
144 | if _md5c != self.MD5:
145 | print(" MD5({}) did not match MD5({}) expected for {}".format(
146 | _md5c, self.MD5, _fpath))
147 | return False
148 | return True
149 |
150 | def _check_preprocess(self):
151 | _obj_list_file = self.obj_list_file
152 | if not os.path.isfile(_obj_list_file):
153 | return False
154 | else:
155 | self.obj_dict = json.load(open(_obj_list_file, 'r'))
156 |
157 | return list(np.sort([str(x) for x in self.obj_dict.keys()])) == list(np.sort(self.im_ids))
158 |
159 | def _preprocess(self):
160 | self.obj_dict = {}
161 | obj_counter = 0
162 | for ii in range(len(self.im_ids)):
163 | # Read object masks and get number of objects
164 | _mask = np.array(Image.open(self.masks[ii]))
165 | _mask_ids = np.unique(_mask)
166 | if _mask_ids[-1] == 255:
167 | n_obj = _mask_ids[-2]
168 | else:
169 | n_obj = _mask_ids[-1]
170 |
171 | # Get the categories from these objects
172 | _cats = np.array(Image.open(self.categories[ii]))
173 | _cat_ids = []
174 | for jj in range(n_obj):
175 | tmp = np.where(_mask == jj + 1)
176 | obj_area = len(tmp[0])
177 | if obj_area > self.area_thres:
178 | _cat_ids.append(int(_cats[tmp[0][0], tmp[1][0]]))
179 | else:
180 | _cat_ids.append(-1)
181 | obj_counter += 1
182 |
183 | self.obj_dict[self.im_ids[ii]] = _cat_ids
184 |
185 | with open(self.obj_list_file, 'w') as outfile:
186 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.im_ids[0], json.dumps(self.obj_dict[self.im_ids[0]])))
187 | for ii in range(1, len(self.im_ids)):
188 | outfile.write(',\n\t"{:s}": {:s}'.format(self.im_ids[ii], json.dumps(self.obj_dict[self.im_ids[ii]])))
189 | outfile.write('\n}\n')
190 |
191 | print('Preprocessing finished')
192 |
193 | def _download(self):
194 | _fpath = os.path.join(self.root, self.FILE)
195 |
196 | try:
197 | os.makedirs(self.root)
198 | except OSError as e:
199 | if e.errno == errno.EEXIST:
200 | pass
201 | else:
202 | raise
203 |
204 | if self._check_integrity():
205 | print('Files already downloaded and verified')
206 | return
207 | else:
208 | print('Downloading ' + self.URL + ' to ' + _fpath)
209 |
210 | def _progress(count, block_size, total_size):
211 | sys.stdout.write('\r>> %s %.1f%%' %
212 | (_fpath, float(count * block_size) /
213 | float(total_size) * 100.0))
214 | sys.stdout.flush()
215 |
216 | urllib.request.urlretrieve(self.URL, _fpath, _progress)
217 |
218 | # extract file
219 | cwd = os.getcwd()
220 | print('Extracting tar file')
221 | tar = tarfile.open(_fpath)
222 | os.chdir(self.root)
223 | tar.extractall()
224 | tar.close()
225 | os.chdir(cwd)
226 | print('Done!')
227 |
228 | def _make_img_gt_point_pair(self, index):
229 | _im_ii = self.obj_list[index][0]
230 | _obj_ii = self.obj_list[index][1]
231 |
232 | # Read Image
233 | _img = np.array(Image.open(self.images[_im_ii]).convert('RGB')).astype(np.float32)
234 |
235 | # Read Target object
236 | _tmp = (np.array(Image.open(self.masks[_im_ii]))).astype(np.float32)
237 | _void_pixels = (_tmp == 255)
238 | _tmp[_void_pixels] = 0
239 |
240 | _other_same_class = np.zeros(_tmp.shape)
241 | _other_classes = np.zeros(_tmp.shape)
242 |
243 | if self.default:
244 | _target = _tmp
245 | _background = np.logical_and(_tmp == 0, ~_void_pixels)
246 | else:
247 | _target = (_tmp == (_obj_ii + 1)).astype(np.float32)
248 | _background = np.logical_and(_tmp == 0, ~_void_pixels)
249 | obj_cat = self.obj_dict[self.im_ids[_im_ii]][_obj_ii]
250 |             for ii in range(1, int(np.max(_tmp)) + 1):  # int() avoids the deprecated np.int alias
251 | ii_cat = self.obj_dict[self.im_ids[_im_ii]][ii-1]
252 | if obj_cat == ii_cat and ii != _obj_ii+1:
253 | _other_same_class = np.logical_or(_other_same_class, _tmp == ii)
254 | elif ii != _obj_ii+1:
255 | _other_classes = np.logical_or(_other_classes, _tmp == ii)
256 |
257 | return _img, _target, _void_pixels.astype(np.float32), \
258 | _other_classes.astype(np.float32), _other_same_class.astype(np.float32), \
259 | _background.astype(np.float32)
260 |
261 | def __str__(self):
262 | return 'VOC2012(split=' + str(self.split) + ',area_thres=' + str(self.area_thres) + ')'
263 |
264 |
265 | if __name__ == '__main__':
266 | import matplotlib.pyplot as plt
267 | import dataloaders.helpers as helpers
268 | import torch
269 | import dataloaders.custom_transforms as tr
270 | from torchvision import transforms
271 |
272 | transform = transforms.Compose([tr.ToTensor()])
273 |
274 | dataset = VOCSegmentation(split=['train', 'val'], transform=transform, retname=True)
275 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
276 |
277 | for i, sample in enumerate(dataloader):
278 | plt.figure()
279 | overlay = helpers.overlay_mask(helpers.tens2image(sample["image"]) / 255.,
280 | np.squeeze(helpers.tens2image(sample["gt"])))
281 | plt.imshow(overlay)
282 | plt.title(dataset.category_names[sample["meta"]["category"][0]])
283 | if i == 3:
284 | break
285 |
286 | plt.show(block=True)
287 |
--------------------------------------------------------------------------------
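Note on the cached instances file: `_preprocess` above writes a JSON dictionary mapping each image id to the list of per-object category ids, with `-1` marking objects discarded by `area_thres`; `_check_preprocess` simply reloads it into `self.obj_dict`. A minimal sketch of inspecting that file (the path prefix and the example values are hypothetical):

```python
import json
import os

# Hypothetical PASCAL root; the real location comes from Path.db_root_dir('pascal').
root = '/path/to/PASCAL'
instances_file = os.path.join(root, 'VOCdevkit/VOC2012/ImageSets/Segmentation',
                              'train_val_instances.txt')  # split=['train', 'val'], area_thres=0

with open(instances_file) as f:
    obj_dict = json.load(f)

# Each key is an image id, each value lists one category id per object instance:
#   obj_dict['<image_id>'] -> [1, 1, 15]   # hypothetical: two 'aeroplane' objects and one 'person'
# A value of -1 means the object's area was below area_thres, and VOCSegmentation
# skips it when building obj_list.
```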
/dataloaders/helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch, cv2
4 | import random
5 | import numpy as np
6 |
7 |
8 | def tens2image(im):
9 | if im.size()[0] == 1:
10 | tmp = np.squeeze(im.numpy(), axis=0)
11 | else:
12 | tmp = im.numpy()
13 | if tmp.ndim == 2:
14 | return tmp
15 | else:
16 | return tmp.transpose((1, 2, 0))
17 |
18 |
19 | def crop2fullmask(crop_mask, bbox, im=None, im_size=None, zero_pad=False, relax=0, mask_relax=True,
20 | interpolation=cv2.INTER_CUBIC, scikit=False):
21 | if scikit:
22 | from skimage.transform import resize as sk_resize
23 | assert(not(im is None and im_size is None)), 'You have to provide an image or the image size'
24 | if im is None:
25 | im_si = im_size
26 | else:
27 | im_si = im.shape
28 |     # Borders of the image
29 | bounds = (0, 0, im_si[1] - 1, im_si[0] - 1)
30 |
31 | # Valid bounding box locations as (x_min, y_min, x_max, y_max)
32 | bbox_valid = (max(bbox[0], bounds[0]),
33 | max(bbox[1], bounds[1]),
34 | min(bbox[2], bounds[2]),
35 | min(bbox[3], bounds[3]))
36 |
37 | # Bounding box of initial mask
38 | bbox_init = (bbox[0] + relax,
39 | bbox[1] + relax,
40 | bbox[2] - relax,
41 | bbox[3] - relax)
42 |
43 | if zero_pad:
44 | # Offsets for x and y
45 | offsets = (-bbox[0], -bbox[1])
46 | else:
47 |         assert(bbox == bbox_valid)
48 | offsets = (-bbox_valid[0], -bbox_valid[1])
49 |
50 | # Simple per element addition in the tuple
51 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets)))
52 |
53 | if scikit:
54 | crop_mask = sk_resize(crop_mask, (bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), order=0, mode='constant').astype(crop_mask.dtype)
55 | else:
56 | crop_mask = cv2.resize(crop_mask, (bbox[2] - bbox[0] + 1, bbox[3] - bbox[1] + 1), interpolation=interpolation)
57 | result_ = np.zeros(im_si)
58 | result_[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1] = \
59 | crop_mask[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1]
60 |
61 | result = np.zeros(im_si)
62 | if mask_relax:
63 | result[bbox_init[1]:bbox_init[3]+1, bbox_init[0]:bbox_init[2]+1] = \
64 | result_[bbox_init[1]:bbox_init[3]+1, bbox_init[0]:bbox_init[2]+1]
65 | else:
66 | result = result_
67 |
68 | return result
69 |
70 |
71 | def overlay_mask(im, ma, colors=None, alpha=0.5):
72 | assert np.max(im) <= 1.0
73 | if colors is None:
74 | colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy'))/255.
75 | else:
76 |         colors = np.append([[0., 0., 0.]], colors, axis=0)
77 |
78 | if ma.ndim == 3:
79 | assert len(colors) >= ma.shape[0], 'Not enough colors'
80 |     ma = ma.astype(bool)  # built-in bool; np.bool is deprecated
81 | im = im.astype(np.float32)
82 |
83 | if ma.ndim == 2:
84 | fg = im * alpha+np.ones(im.shape) * (1 - alpha) * colors[1, :3] # np.array([0,0,255])/255.0
85 | else:
86 | fg = []
87 |         for n in range(ma.shape[0]):  # one overlay colour per mask channel
88 | fg.append(im * alpha + np.ones(im.shape) * (1 - alpha) * colors[1+n, :3])
89 | # Whiten background
90 | bg = im.copy()
91 | if ma.ndim == 2:
92 | bg[ma == 0] = im[ma == 0]
93 | bg[ma == 1] = fg[ma == 1]
94 | total_ma = ma
95 | else:
96 | total_ma = np.zeros([ma.shape[1], ma.shape[2]])
97 | for n in range(ma.shape[0]):
98 | tmp_ma = ma[n, :, :]
99 | total_ma = np.logical_or(tmp_ma, total_ma)
100 | tmp_fg = fg[n]
101 | bg[tmp_ma == 1] = tmp_fg[tmp_ma == 1]
102 | bg[total_ma == 0] = im[total_ma == 0]
103 |
104 |     # [-2:] is a trick to stay compatible with both OpenCV 2 and 3
105 | contours = cv2.findContours(total_ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
106 | cv2.drawContours(bg, contours[0], -1, (0.0, 0.0, 0.0), 1)
107 |
108 | return bg
109 |
110 | def overlay_masks(im, masks, alpha=0.5):
111 | colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy'))/255.
112 |
113 | if isinstance(masks, np.ndarray):
114 | masks = [masks]
115 |
116 | assert len(colors) >= len(masks), 'Not enough colors'
117 |
118 | ov = im.copy()
119 | im = im.astype(np.float32)
120 | total_ma = np.zeros([im.shape[0], im.shape[1]])
121 | i = 1
122 | for ma in masks:
123 |         ma = ma.astype(bool)  # built-in bool; np.bool is deprecated
124 | fg = im * alpha+np.ones(im.shape) * (1 - alpha) * colors[i, :3] # np.array([0,0,255])/255.0
125 | i = i + 1
126 | ov[ma == 1] = fg[ma == 1]
127 | total_ma += ma
128 |
129 |         # [-2:] is a trick to stay compatible with both OpenCV 2 and 3
130 | contours = cv2.findContours(ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
131 | cv2.drawContours(ov, contours[0], -1, (0.0, 0.0, 0.0), 1)
132 | ov[total_ma == 0] = im[total_ma == 0]
133 |
134 | return ov
135 |
136 |
137 | def extreme_points(mask, pert):
138 | def find_point(id_x, id_y, ids):
139 | sel_id = ids[0][random.randint(0, len(ids[0]) - 1)]
140 | return [id_x[sel_id], id_y[sel_id]]
141 |
142 | # List of coordinates of the mask
143 | inds_y, inds_x = np.where(mask > 0.5)
144 |
145 | # Find extreme points
146 | return np.array([find_point(inds_x, inds_y, np.where(inds_x <= np.min(inds_x)+pert)), # left
147 | find_point(inds_x, inds_y, np.where(inds_x >= np.max(inds_x)-pert)), # right
148 | find_point(inds_x, inds_y, np.where(inds_y <= np.min(inds_y)+pert)), # top
149 | find_point(inds_x, inds_y, np.where(inds_y >= np.max(inds_y)-pert)) # bottom
150 | ])
151 |
152 |
153 | def get_bbox(mask, points=None, pad=0, zero_pad=False):
154 | if points is not None:
155 | inds = np.flip(points.transpose(), axis=0)
156 | else:
157 | inds = np.where(mask > 0)
158 |
159 | if inds[0].shape[0] == 0:
160 | return None
161 |
162 | if zero_pad:
163 | x_min_bound = -np.inf
164 | y_min_bound = -np.inf
165 | x_max_bound = np.inf
166 | y_max_bound = np.inf
167 | else:
168 | x_min_bound = 0
169 | y_min_bound = 0
170 | x_max_bound = mask.shape[1] - 1
171 | y_max_bound = mask.shape[0] - 1
172 |
173 | x_min = max(inds[1].min() - pad, x_min_bound)
174 | y_min = max(inds[0].min() - pad, y_min_bound)
175 | x_max = min(inds[1].max() + pad, x_max_bound)
176 | y_max = min(inds[0].max() + pad, y_max_bound)
177 |
178 | return x_min, y_min, x_max, y_max
179 |
180 |
181 | def crop_from_bbox(img, bbox, zero_pad=False):
182 | # Borders of image
183 | bounds = (0, 0, img.shape[1] - 1, img.shape[0] - 1)
184 |
185 | # Valid bounding box locations as (x_min, y_min, x_max, y_max)
186 | bbox_valid = (max(bbox[0], bounds[0]),
187 | max(bbox[1], bounds[1]),
188 | min(bbox[2], bounds[2]),
189 | min(bbox[3], bounds[3]))
190 |
191 | if zero_pad:
192 | # Initialize crop size (first 2 dimensions)
193 | crop = np.zeros((bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), dtype=img.dtype)
194 |
195 | # Offsets for x and y
196 | offsets = (-bbox[0], -bbox[1])
197 |
198 | else:
199 | assert(bbox == bbox_valid)
200 | crop = np.zeros((bbox_valid[3] - bbox_valid[1] + 1, bbox_valid[2] - bbox_valid[0] + 1), dtype=img.dtype)
201 | offsets = (-bbox_valid[0], -bbox_valid[1])
202 |
203 | # Simple per element addition in the tuple
204 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets)))
205 |
206 | img = np.squeeze(img)
207 | if img.ndim == 2:
208 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1] = \
209 | img[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1]
210 | else:
211 | crop = np.tile(crop[:, :, np.newaxis], [1, 1, 3]) # Add 3 RGB Channels
212 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1, :] = \
213 | img[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1, :]
214 |
215 | return crop
216 |
217 |
218 | def fixed_resize(sample, resolution, flagval=None):
219 |
220 | if flagval is None:
221 | if ((sample == 0) | (sample == 1)).all():
222 | flagval = cv2.INTER_NEAREST
223 | else:
224 | flagval = cv2.INTER_CUBIC
225 |
226 | if isinstance(resolution, int):
227 | tmp = [resolution, resolution]
228 | tmp[np.argmax(sample.shape[:2])] = int(round(float(resolution)/np.min(sample.shape[:2])*np.max(sample.shape[:2])))
229 | resolution = tuple(tmp)
230 |
231 | if sample.ndim == 2 or (sample.ndim == 3 and sample.shape[2] == 3):
232 | sample = cv2.resize(sample, resolution[::-1], interpolation=flagval)
233 | else:
234 | tmp = sample
235 | sample = np.zeros(np.append(resolution, tmp.shape[2]), dtype=np.float32)
236 | for ii in range(sample.shape[2]):
237 | sample[:, :, ii] = cv2.resize(tmp[:, :, ii], resolution[::-1], interpolation=flagval)
238 | return sample
239 |
240 |
241 | def crop_from_mask(img, mask, relax=0, zero_pad=False):
242 | if mask.shape[:2] != img.shape[:2]:
243 | mask = cv2.resize(mask, dsize=tuple(reversed(img.shape[:2])), interpolation=cv2.INTER_NEAREST)
244 |
245 | assert(mask.shape[:2] == img.shape[:2])
246 |
247 | bbox = get_bbox(mask, pad=relax, zero_pad=zero_pad)
248 |
249 | if bbox is None:
250 | return None
251 |
252 | crop = crop_from_bbox(img, bbox, zero_pad)
253 |
254 | return crop
255 |
256 |
257 | def make_gaussian(size, sigma=10, center=None, d_type=np.float64):
258 | """ Make a square gaussian kernel.
259 |     size: (h, w) dimensions of the output Gaussian
260 |     sigma: full width at half maximum, which
261 |          can be thought of as an effective radius.
262 | """
263 |
264 | x = np.arange(0, size[1], 1, float)
265 | y = np.arange(0, size[0], 1, float)
266 | y = y[:, np.newaxis]
267 |
268 | if center is None:
269 | x0 = y0 = size[0] // 2
270 | else:
271 | x0 = center[0]
272 | y0 = center[1]
273 |
274 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2).astype(d_type)
275 |
276 |
277 | def make_gt(img, labels, sigma=10, one_mask_per_point=False):
278 | """ Make the ground-truth for landmark.
279 | img: the original color image
280 | labels: label with the Gaussian center(s) [[x0, y0],[x1, y1],...]
281 | sigma: sigma of the Gaussian.
282 | one_mask_per_point: masks for each point in different channels?
283 | """
284 | h, w = img.shape[:2]
285 | if labels is None:
286 | gt = make_gaussian((h, w), center=(h//2, w//2), sigma=sigma)
287 | else:
288 | labels = np.array(labels)
289 | if labels.ndim == 1:
290 | labels = labels[np.newaxis]
291 | if one_mask_per_point:
292 | gt = np.zeros(shape=(h, w, labels.shape[0]))
293 | for ii in range(labels.shape[0]):
294 | gt[:, :, ii] = make_gaussian((h, w), center=labels[ii, :], sigma=sigma)
295 | else:
296 | gt = np.zeros(shape=(h, w), dtype=np.float64)
297 | for ii in range(labels.shape[0]):
298 | gt = np.maximum(gt, make_gaussian((h, w), center=labels[ii, :], sigma=sigma))
299 |
300 | gt = gt.astype(dtype=img.dtype)
301 |
302 | return gt
303 |
304 |
305 | def cstm_normalize(im, max_value):
306 | """
307 | Normalize image to range 0 - max_value
308 | """
309 | imn = max_value*(im - im.min()) / max((im.max() - im.min()), 1e-8)
310 | return imn
311 |
312 |
313 | def generate_param_report(logfile, param):
314 | log_file = open(logfile, 'w')
315 | for key, val in param.items():
316 | log_file.write(key+':'+str(val)+'\n')
317 | log_file.close()
318 |
--------------------------------------------------------------------------------
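The helpers above cover the full DEXTR-style input pipeline: `extreme_points` samples the four extreme points of a mask, `get_bbox`/`crop_from_bbox` cut a relaxed crop around the object, `make_gt` turns the points into a Gaussian heatmap, and `crop2fullmask` pastes a prediction made on the crop back into the full image (as in the evaluation loop earlier). A minimal sketch of how they fit together, assuming an RGB image `img` and a binary object mask `mask` as float32 arrays; `build_input` is an illustrative name, not a function of this repo, and the real training pipeline applies these steps through the dataset transforms:

```python
import numpy as np
import dataloaders.helpers as helpers

def build_input(img, mask, pert=0, relax=50, zero_pad=True, resolution=512, sigma=10):
    # Relaxed bounding box around the object and the matching image/mask crops
    bbox = helpers.get_bbox(mask, pad=relax, zero_pad=zero_pad)
    crop_img = helpers.crop_from_bbox(img, bbox, zero_pad).astype(np.float32)
    crop_mask = helpers.crop_from_bbox(mask, bbox, zero_pad)

    # Extreme points are taken on the cropped mask, so they are already in crop coordinates
    points = helpers.extreme_points(crop_mask, pert)

    # Gaussian heatmap centred on the four extreme points, normalized to [0, 255]
    heatmap = helpers.make_gt(crop_img, points, sigma=sigma)
    heatmap = helpers.cstm_normalize(heatmap, 255)

    # Stack RGB + heatmap into a 4-channel input and resize to the network resolution
    inp = np.concatenate([crop_img, heatmap[..., np.newaxis]], axis=2)
    return helpers.fixed_resize(inp, (resolution, resolution)).astype(np.float32)
```

At test time, the sigmoid output produced on this crop can be mapped back to full resolution with `crop2fullmask(pred, bbox, mask, zero_pad=zero_pad, relax=relax)`, mirroring the evaluation snippet shown earlier.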
/networks/deeplab_resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torchvision.models.resnet as resnet
3 | import torch
4 | import numpy as np
5 | from copy import deepcopy
6 | import os
7 | from torch.nn import functional as F
8 | from mypath import Path
9 |
10 | affine_par = True
11 |
12 |
13 | def outS(i):
14 | i = int(i)
15 |     i = (i+1)//2  # integer division keeps the original Python 2 behaviour and an int result
16 |     i = int(np.ceil((i+1)/2.0))
17 |     i = (i+1)//2
18 | return i
19 |
20 |
21 | class Bottleneck(nn.Module):
22 | expansion = 4
23 |
24 | def __init__(self, inplanes, planes, stride=1, dilation_=1, downsample=None):
25 | super(Bottleneck, self).__init__()
26 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
27 | self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
28 | for i in self.bn1.parameters():
29 | i.requires_grad = False
30 | padding = 1
31 | if dilation_ == 2:
32 | padding = 2
33 | elif dilation_ == 4:
34 | padding = 4
35 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
36 | padding=padding, bias=False, dilation=dilation_)
37 | self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
38 | for i in self.bn2.parameters():
39 | i.requires_grad = False
40 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
41 | self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
42 | for i in self.bn3.parameters():
43 | i.requires_grad = False
44 | self.relu = nn.ReLU(inplace=True)
45 | self.downsample = downsample
46 | self.stride = stride
47 |
48 | def forward(self, x):
49 | residual = x
50 |
51 | out = self.conv1(x)
52 | out = self.bn1(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv2(out)
56 | out = self.bn2(out)
57 | out = self.relu(out)
58 |
59 | out = self.conv3(out)
60 | out = self.bn3(out)
61 |
62 | if self.downsample is not None:
63 | residual = self.downsample(x)
64 |
65 | out += residual
66 | out = self.relu(out)
67 |
68 | return out
69 |
70 |
71 | class ClassifierModule(nn.Module):
72 |
73 | def __init__(self, dilation_series, padding_series, n_classes):
74 | super(ClassifierModule, self).__init__()
75 | self.conv2d_list = nn.ModuleList()
76 | for dilation, padding in zip(dilation_series, padding_series):
77 | self.conv2d_list.append(nn.Conv2d(2048, n_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True))
78 |
79 | for m in self.conv2d_list:
80 | m.weight.data.normal_(0, 0.01)
81 |
82 | def forward(self, x):
83 | out = self.conv2d_list[0](x)
84 | for i in range(len(self.conv2d_list)-1):
85 | out += self.conv2d_list[i+1](x)
86 | return out
87 |
88 |
89 | class PSPModule(nn.Module):
90 | """
91 | Pyramid Scene Parsing module
92 | """
93 | def __init__(self, in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=1):
94 | super(PSPModule, self).__init__()
95 | self.stages = []
96 | self.stages = nn.ModuleList([self._make_stage_1(in_features, size) for size in sizes])
97 | self.bottleneck = self._make_stage_2(in_features * (len(sizes)//4 + 1), out_features)
98 | self.relu = nn.ReLU()
99 | self.final = nn.Conv2d(out_features, n_classes, kernel_size=1)
100 |
101 | def _make_stage_1(self, in_features, size):
102 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
103 | conv = nn.Conv2d(in_features, in_features//4, kernel_size=1, bias=False)
104 | bn = nn.BatchNorm2d(in_features//4, affine=affine_par)
105 | relu = nn.ReLU(inplace=True)
106 |
107 | return nn.Sequential(prior, conv, bn, relu)
108 |
109 | def _make_stage_2(self, in_features, out_features):
110 | conv = nn.Conv2d(in_features, out_features, kernel_size=1, bias=False)
111 | bn = nn.BatchNorm2d(out_features, affine=affine_par)
112 | relu = nn.ReLU(inplace=True)
113 |
114 | return nn.Sequential(conv, bn, relu)
115 |
116 | def forward(self, feats):
117 | h, w = feats.size(2), feats.size(3)
118 |         priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages]  # F.upsample is deprecated
119 | priors.append(feats)
120 | bottle = self.relu(self.bottleneck(torch.cat(priors, 1)))
121 | out = self.final(bottle)
122 |
123 | return out
124 |
125 |
126 | class ResNet(nn.Module):
127 | def __init__(self, block, layers, n_classes, nInputChannels=3, classifier="atrous",
128 | dilations=(2, 4), strides=(2, 2, 2, 1, 1), _print=False):
129 | if _print:
130 | print("Constructing ResNet model...")
131 | print("Dilations: {}".format(dilations))
132 | print("Number of classes: {}".format(n_classes))
133 | print("Number of Input Channels: {}".format(nInputChannels))
134 | self.inplanes = 64
135 | self.classifier = classifier
136 | super(ResNet, self).__init__()
137 | self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=strides[0], padding=3,
138 | bias=False)
139 | self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
140 | for i in self.bn1.parameters():
141 | i.requires_grad = False
142 | self.relu = nn.ReLU(inplace=True)
143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=strides[1], padding=1, ceil_mode=False)
144 | self.layer1 = self._make_layer(block, 64, layers[0])
145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[2])
146 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[3], dilation__=dilations[0])
147 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[4], dilation__=dilations[1])
148 |
149 | if classifier == "atrous":
150 | if _print:
151 | print('Initializing classifier: A-trous pyramid')
152 | self.layer5 = self._make_pred_layer(ClassifierModule, [6, 12, 18, 24], [6, 12, 18, 24], n_classes=n_classes)
153 | elif classifier == "psp":
154 | if _print:
155 | print('Initializing classifier: PSP')
156 | self.layer5 = PSPModule(in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=n_classes)
157 | else:
158 | self.layer5 = None
159 | for m in self.modules():
160 | if isinstance(m, nn.Conv2d):
161 | m.weight.data.normal_(0, 0.01)
162 | elif isinstance(m, nn.BatchNorm2d):
163 | m.weight.data.fill_(1)
164 | m.bias.data.zero_()
165 |
166 | def _make_layer(self, block, planes, blocks, stride=1, dilation__=1):
167 | downsample = None
168 | if stride != 1 or self.inplanes != planes * block.expansion or dilation__ == 2 or dilation__ == 4:
169 | downsample = nn.Sequential(
170 | nn.Conv2d(self.inplanes, planes * block.expansion,
171 | kernel_size=1, stride=stride, bias=False),
172 | nn.BatchNorm2d(planes * block.expansion, affine=affine_par),
173 | )
174 | for i in downsample._modules['1'].parameters():
175 | i.requires_grad = False
176 | layers = [block(self.inplanes, planes, stride, dilation_=dilation__, downsample=downsample)]
177 | self.inplanes = planes * block.expansion
178 | for i in range(1, blocks):
179 | layers.append(block(self.inplanes, planes, dilation_=dilation__))
180 |
181 | return nn.Sequential(*layers)
182 |
183 | def _make_pred_layer(self, block, dilation_series, padding_series, n_classes):
184 | return block(dilation_series, padding_series, n_classes)
185 |
186 | def forward(self, x, bbox=None):
187 | x = self.conv1(x)
188 | x = self.bn1(x)
189 | x = self.relu(x)
190 | x = self.maxpool(x)
191 | x = self.layer1(x)
192 | x = self.layer2(x)
193 | x = self.layer3(x)
194 | x = self.layer4(x)
195 | if self.layer5 is not None:
196 | x = self.layer5(x)
197 | return x
198 |
199 | def load_pretrained_ms(self, base_network, nInputChannels=3):
200 | flag = 0
201 | for module, module_ori in zip(self.modules(), base_network.Scale.modules()):
202 | if isinstance(module, nn.Conv2d) and isinstance(module_ori, nn.Conv2d):
203 | if not flag and nInputChannels != 3:
204 | module.weight[:, :3, :, :].data = deepcopy(module_ori.weight.data)
205 | module.bias = deepcopy(module_ori.bias)
206 | for i in range(3, int(module.weight.data.shape[1])):
207 | module.weight[:, i, :, :].data = deepcopy(module_ori.weight[:, -1, :, :][:, np.newaxis, :, :].data)
208 | flag = 1
209 | elif module.weight.data.shape == module_ori.weight.data.shape:
210 | module.weight.data = deepcopy(module_ori.weight.data)
211 | module.bias = deepcopy(module_ori.bias)
212 | else:
213 | print('Skipping Conv layer with size: {} and target size: {}'
214 | .format(module.weight.data.shape, module_ori.weight.data.shape))
215 | elif isinstance(module, nn.BatchNorm2d) and isinstance(module_ori, nn.BatchNorm2d) \
216 | and module.weight.data.shape == module_ori.weight.data.shape:
217 | module.weight.data = deepcopy(module_ori.weight.data)
218 | module.bias.data = deepcopy(module_ori.bias.data)
219 |
220 |
221 | def resnet101(n_classes, pretrained=False, nInputChannels=3, classifier="atrous",
222 | dilations=(2, 4), strides=(2, 2, 2, 1, 1)):
223 | """Constructs a ResNet-101 model.
224 | """
225 | model = ResNet(Bottleneck, [3, 4, 23, 3], n_classes, nInputChannels=nInputChannels,
226 | classifier=classifier, dilations=dilations, strides=strides, _print=True)
227 | if pretrained:
228 | model_full = Res_Deeplab(n_classes, pretrained=pretrained)
229 | model.load_pretrained_ms(model_full, nInputChannels=nInputChannels)
230 | return model
231 |
232 |
233 | class MS_Deeplab(nn.Module):
234 | def __init__(self, block, NoLabels, nInputChannels=3):
235 | super(MS_Deeplab, self).__init__()
236 | self.Scale = ResNet(block, [3, 4, 23, 3], NoLabels, nInputChannels=nInputChannels)
237 |
238 | def forward(self, x):
239 | input_size = x.size()[2]
240 | self.interp1 = nn.Upsample(size=(int(input_size*0.75)+1, int(input_size*0.75)+1), mode='bilinear', align_corners=True)
241 | self.interp2 = nn.Upsample(size=(int(input_size*0.5)+1, int(input_size*0.5)+1), mode='bilinear', align_corners=True)
242 | self.interp3 = nn.Upsample(size=(outS(input_size), outS(input_size)), mode='bilinear', align_corners=True)
243 | out = []
244 | x2 = self.interp1(x)
245 | x3 = self.interp2(x)
246 | out.append(self.Scale(x)) # for original scale
247 | out.append(self.interp3(self.Scale(x2))) # for 0.75x scale
248 | out.append(self.Scale(x3)) # for 0.5x scale
249 |
250 | x2Out_interp = out[1]
251 | x3Out_interp = self.interp3(out[2])
252 | temp1 = torch.max(out[0], x2Out_interp)
253 | out.append(torch.max(temp1, x3Out_interp))
254 | return out[-1]
255 |
256 |
257 | def Res_Deeplab(n_classes=21, pretrained=False):
258 | model = MS_Deeplab(Bottleneck, n_classes)
259 | if pretrained:
260 | pth_model = 'MS_DeepLab_resnet_trained_VOC.pth'
261 | saved_state_dict = torch.load(os.path.join(Path.models_dir(), pth_model),
262 | map_location=lambda storage, loc: storage)
263 | if n_classes != 21:
264 | for i in saved_state_dict:
265 | i_parts = i.split('.')
266 | if i_parts[1] == 'layer5':
267 | saved_state_dict[i] = model.state_dict()[i]
268 | model.load_state_dict(saved_state_dict)
269 | return model
270 |
271 |
272 | def get_lr_params(model):
273 | """
274 |     This generator returns all the parameters of the net, including the
275 |     last classification layer. Note that requires_grad is set to False for
276 |     every batchnorm layer in this file, so this function does not return
277 |     any batchnorm parameters.
278 | """
279 | b = [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3, model.layer4, model.layer5]
280 | for i in range(len(b)):
281 | for k in b[i].parameters():
282 | if k.requires_grad:
283 | yield k
284 |
285 |
286 | def get_1x_lr_params(model):
287 | """
288 |     This generator returns all the parameters of the net except for the
289 |     last classification layer. Note that requires_grad is set to False for
290 |     every batchnorm layer in this file, so this function does not return
291 |     any batchnorm parameters.
292 | """
293 | b = [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3, model.layer4]
294 | for i in range(len(b)):
295 | for k in b[i].parameters():
296 | if k.requires_grad:
297 | yield k
298 |
299 |
300 | def get_10x_lr_params(model):
301 | """
302 | This generator returns all the parameters for the last layer of the net,
303 |     which performs the pixel-wise classification.
304 | """
305 | b = [model.layer5]
306 | for j in range(len(b)):
307 | for k in b[j].parameters():
308 | if k.requires_grad:
309 | yield k
310 |
311 |
312 | def lr_poly(base_lr, iter_, max_iter=100, power=0.9):
313 | return base_lr*((1-float(iter_)/max_iter)**power)
314 |
--------------------------------------------------------------------------------
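For orientation, a minimal sketch of wiring the pieces of this file together: build the backbone with a 4-channel input (RGB plus the extreme-point heatmap) and a PSP head, put backbone and classifier into separate optimizer groups via `get_1x_lr_params`/`get_10x_lr_params`, and decay the rate with `lr_poly`. The learning-rate values below are illustrative assumptions, not the repo's training settings:

```python
import torch
import networks.deeplab_resnet as resnet

# 4 input channels (RGB + heatmap), single foreground class, PSP classifier head.
net = resnet.resnet101(n_classes=1, pretrained=False, nInputChannels=4, classifier='psp')

base_lr = 1e-7  # assumed value, for illustration only
optimizer = torch.optim.SGD([
    {'params': resnet.get_1x_lr_params(net), 'lr': base_lr},        # backbone
    {'params': resnet.get_10x_lr_params(net), 'lr': 10 * base_lr},  # classifier head
], momentum=0.9, weight_decay=5e-4)

# Polynomial learning-rate decay over max_iter steps, applied per parameter group.
def adjust_lr(optimizer, iter_, max_iter):
    lr = resnet.lr_poly(base_lr, iter_, max_iter=max_iter)
    optimizer.param_groups[0]['lr'] = lr
    optimizer.param_groups[1]['lr'] = 10 * lr
```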