├── layers ├── __init__.py └── loss.py ├── dataloaders ├── __init__.py ├── pascal_map.npy ├── combine_dbs.py ├── sbd.py ├── custom_transforms.py ├── pascal.py └── helpers.py ├── evaluation ├── __init__.py ├── evaluation.py └── eval.py ├── networks ├── __init__.py └── deeplab_resnet.py ├── ims ├── bear.jpg └── dog-cat.jpg ├── doc ├── dextr.png └── github_teaser.gif ├── media ├── out1.png ├── out1_txt.png └── out1.txt ├── requirements.txt ├── models ├── download_pretrained_psp_model.sh └── download_dextr_model.sh ├── mypath.py ├── eval_all.py ├── README.md ├── demo.py └── train_pascal.py /layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ims/bear.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/ims/bear.jpg -------------------------------------------------------------------------------- /doc/dextr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/doc/dextr.png -------------------------------------------------------------------------------- /ims/dog-cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/ims/dog-cat.jpg -------------------------------------------------------------------------------- /media/out1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/media/out1.png -------------------------------------------------------------------------------- /media/out1_txt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/media/out1_txt.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | opencv-python 3 | pillow 4 | scikit-learn 5 | scikit-image 6 | -------------------------------------------------------------------------------- /doc/github_teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/doc/github_teaser.gif -------------------------------------------------------------------------------- /dataloaders/pascal_map.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/etosworld/etos-deepcut/HEAD/dataloaders/pascal_map.npy -------------------------------------------------------------------------------- /models/download_pretrained_psp_model.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://data.vision.ee.ethz.ch/csergi/share/DEXTR/MS_DeepLab_resnet_trained_VOC.pth 3 | -------------------------------------------------------------------------------- /models/download_dextr_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Model trained on PASCAL + SBD 4 | wget https://data.vision.ee.ethz.ch/csergi/share/DEXTR/dextr_pascal-sbd.pth 5 | -------------------------------------------------------------------------------- /mypath.py: -------------------------------------------------------------------------------- 1 | 2 | class Path(object): 3 | @staticmethod 4 | def db_root_dir(database): 5 | if database == 'pascal': 6 | return '/path/to/PASCAL/VOC2012' # folder that contains VOCdevkit/. 7 | 8 | elif database == 'sbd': 9 | return '/path/to/SBD/' # folder with img/, inst/, cls/, etc. 10 | else: 11 | print('Database {} not available.'.format(database)) 12 | raise NotImplementedError 13 | 14 | @staticmethod 15 | def models_dir(): 16 | return 'models/' 17 | -------------------------------------------------------------------------------- /evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def jaccard(annotation, segmentation, void_pixels=None): 5 | 6 | assert(annotation.shape == segmentation.shape) 7 | 8 | if void_pixels is None: 9 | void_pixels = np.zeros_like(annotation) 10 | assert(void_pixels.shape == annotation.shape) 11 | 12 | annotation = annotation.astype(np.bool) 13 | segmentation = segmentation.astype(np.bool) 14 | void_pixels = void_pixels.astype(np.bool) 15 | if np.isclose(np.sum(annotation & np.logical_not(void_pixels)), 0) and np.isclose(np.sum(segmentation & np.logical_not(void_pixels)), 0): 16 | return 1 17 | else: 18 | return np.sum(((annotation & segmentation) & np.logical_not(void_pixels))) / \ 19 | np.sum(((annotation | segmentation) & np.logical_not(void_pixels)), dtype=np.float32) 20 | 21 | -------------------------------------------------------------------------------- /media/out1.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | : ;*~ 7 | *XO*: ;:~QzC 8 | YuUbbpqbwXXb 9 | ~Zwbbpdbbppa 10 | :adQzZYUOvQbo,,: 11 | hqZJUr))tYCZwp~~*` 12 | :mzf}]]}{}1{\tJkh~` 13 | p(?_+_?{)|}]}vdboh. 14 | ^bqL/\\+_-?|-]umpkha~ 15 | ;akbUvx]-?]]}{{f0wOutxq 16 | ,hkz{{1)/Cpdddk*,^mr)|U 17 | *Y|}]][|k U1)u 18 | ;aL{--1. 
o/{n 19 | -------------------------------------------------------------------------------- /eval_all.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from torch.utils.data import DataLoader 4 | from evaluation.eval import eval_one_result 5 | import dataloaders.pascal as pascal 6 | 7 | exp_root_dir = './' 8 | 9 | method_names = [] 10 | method_names.append('run_0') 11 | 12 | if __name__ == '__main__': 13 | 14 | # Dataloader 15 | dataset = pascal.VOCSegmentation(transform=None, retname=True) 16 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) 17 | 18 | # Iterate through all the different methods 19 | for method in method_names: 20 | results_folder = os.path.join(exp_root_dir, method, 'Results') 21 | 22 | filename = os.path.join(exp_root_dir, 'eval_results', method.replace('/', '-') + '.txt') 23 | if not os.path.exists(os.path.join(exp_root_dir, 'eval_results')): 24 | os.makedirs(os.path.join(exp_root_dir, 'eval_results')) 25 | 26 | if os.path.isfile(filename): 27 | with open(filename, 'r') as f: 28 | val = float(f.read()) 29 | else: 30 | print("Evaluating method: {}".format(method)) 31 | jaccards = eval_one_result(dataloader, results_folder, mask_thres=0.8) 32 | val = jaccards["all_jaccards"].mean() 33 | 34 | # Show mean and store result 35 | print("Result for {:<80}: {}".format(method, str.format("{0:.1f}", 100*val))) 36 | with open(filename, 'w') as f: 37 | f.write(str(val)) 38 | -------------------------------------------------------------------------------- /layers/loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def class_balanced_cross_entropy_loss(output, label, size_average=True, batch_average=True, void_pixels=None): 8 | """Define the class balanced cross entropy loss to train the network 9 | Args: 10 | output: Output of the network 11 | label: Ground truth label 12 | size_average: return per-element (pixel) average loss 13 | batch_average: return per-batch average loss 14 | void_pixels: pixels to ignore from the loss 15 | Returns: 16 | Tensor that evaluates the loss 17 | """ 18 | assert(output.size() == label.size()) 19 | 20 | labels = torch.ge(label, 0.5).float() 21 | 22 | num_labels_pos = torch.sum(labels) 23 | num_labels_neg = torch.sum(1.0 - labels) 24 | num_total = num_labels_pos + num_labels_neg 25 | 26 | output_gt_zero = torch.ge(output, 0).float() 27 | loss_val = torch.mul(output, (labels - output_gt_zero)) - torch.log( 28 | 1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero))) 29 | 30 | loss_pos_pix = -torch.mul(labels, loss_val) 31 | loss_neg_pix = -torch.mul(1.0 - labels, loss_val) 32 | 33 | if void_pixels is not None: 34 | w_void = torch.le(void_pixels, 0.5).float() 35 | loss_pos_pix = torch.mul(w_void, loss_pos_pix) 36 | loss_neg_pix = torch.mul(w_void, loss_neg_pix) 37 | num_total = num_total - torch.ge(void_pixels, 0.5).float().sum() 38 | 39 | loss_pos = torch.sum(loss_pos_pix) 40 | loss_neg = torch.sum(loss_neg_pix) 41 | 42 | final_loss = num_labels_neg / num_total * loss_pos + num_labels_pos / num_total * loss_neg 43 | 44 | if size_average: 45 | final_loss /= np.prod(label.size()) 46 | elif batch_average: 47 | final_loss /= label.size()[0] 48 | 49 | return final_loss 50 | -------------------------------------------------------------------------------- /evaluation/eval.py: 
-------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import dataloaders.helpers as helpers 8 | import evaluation.evaluation as evaluation 9 | 10 | 11 | def eval_one_result(loader, folder, one_mask_per_image=False, mask_thres=0.5, use_void_pixels=True, custom_box=False): 12 | def mAPr(per_cat, thresholds): 13 | n_cat = len(per_cat) 14 | all_apr = np.zeros(len(thresholds)) 15 | for ii, th in enumerate(thresholds): 16 | per_cat_recall = np.zeros(n_cat) 17 | for jj, categ in enumerate(per_cat.keys()): 18 | per_cat_recall[jj] = np.sum(np.array(per_cat[categ]) > th)/len(per_cat[categ]) 19 | 20 | all_apr[ii] = per_cat_recall.mean() 21 | 22 | return all_apr.mean() 23 | 24 | # Allocate 25 | eval_result = dict() 26 | eval_result["all_jaccards"] = np.zeros(len(loader)) 27 | eval_result["all_percent"] = np.zeros(len(loader)) 28 | eval_result["meta"] = [] 29 | eval_result["per_categ_jaccard"] = dict() 30 | 31 | # Iterate 32 | for i, sample in enumerate(loader): 33 | 34 | if i % 500 == 0: 35 | print('Evaluating: {} of {} objects'.format(i, len(loader))) 36 | 37 | # Load result 38 | if not one_mask_per_image: 39 | filename = os.path.join(folder, 40 | sample["meta"]["image"][0] + '-' + sample["meta"]["object"][0] + '.png') 41 | else: 42 | filename = os.path.join(folder, 43 | sample["meta"]["image"][0] + '.png') 44 | mask = np.array(Image.open(filename)).astype(np.float32) / 255. 45 | gt = np.squeeze(helpers.tens2image(sample["gt"])) 46 | if use_void_pixels: 47 | void_pixels = np.squeeze(helpers.tens2image(sample["void_pixels"])) 48 | if mask.shape != gt.shape: 49 | mask = cv2.resize(mask, gt.shape[::-1], interpolation=cv2.INTER_CUBIC) 50 | 51 | # Threshold 52 | mask = (mask > mask_thres) 53 | if use_void_pixels: 54 | void_pixels = (void_pixels > 0.5) 55 | 56 | # Evaluate 57 | if use_void_pixels: 58 | eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask, void_pixels) 59 | else: 60 | eval_result["all_jaccards"][i] = evaluation.jaccard(gt, mask) 61 | 62 | if custom_box: 63 | box = np.squeeze(helpers.tens2image(sample["box"])) 64 | bb = helpers.get_bbox(box) 65 | else: 66 | bb = helpers.get_bbox(gt) 67 | 68 | mask_crop = helpers.crop_from_bbox(mask, bb) 69 | if use_void_pixels: 70 | non_void_pixels_crop = helpers.crop_from_bbox(np.logical_not(void_pixels), bb) 71 | gt_crop = helpers.crop_from_bbox(gt, bb) 72 | if use_void_pixels: 73 | eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop) & non_void_pixels_crop)/np.sum(non_void_pixels_crop) 74 | else: 75 | eval_result["all_percent"][i] = np.sum((gt_crop != mask_crop))/mask_crop.size 76 | # Store in per category 77 | if "category" in sample["meta"]: 78 | cat = sample["meta"]["category"][0] 79 | else: 80 | cat = 1 81 | if cat not in eval_result["per_categ_jaccard"]: 82 | eval_result["per_categ_jaccard"][cat] = [] 83 | eval_result["per_categ_jaccard"][cat].append(eval_result["all_jaccards"][i]) 84 | 85 | # Store meta 86 | eval_result["meta"].append(sample["meta"]) 87 | 88 | # Compute some stats 89 | eval_result["mAPr0.5"] = mAPr(eval_result["per_categ_jaccard"], [0.5]) 90 | eval_result["mAPr0.7"] = mAPr(eval_result["per_categ_jaccard"], [0.7]) 91 | eval_result["mAPr-vol"] = mAPr(eval_result["per_categ_jaccard"], np.linspace(0.1, 0.9, 9)) 92 | 93 | return eval_result 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /dataloaders/combine_dbs.py: 
-------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | 4 | class CombineDBs(data.Dataset): 5 | def __init__(self, dataloaders, excluded=None): 6 | self.dataloaders = dataloaders 7 | self.excluded = excluded 8 | self.im_ids = [] 9 | 10 | # Combine object lists 11 | for dl in dataloaders: 12 | for elem in dl.im_ids: 13 | if elem not in self.im_ids: 14 | self.im_ids.append(elem) 15 | 16 | # Exclude 17 | if excluded: 18 | for dl in excluded: 19 | for elem in dl.im_ids: 20 | if elem in self.im_ids: 21 | self.im_ids.remove(elem) 22 | 23 | # Get object pointers 24 | self.obj_list = [] 25 | self.im_list = [] 26 | new_im_ids = [] 27 | obj_counter = 0 28 | num_images = 0 29 | for ii, dl in enumerate(dataloaders): 30 | for jj, curr_im_id in enumerate(dl.im_ids): 31 | if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids): 32 | flag = False 33 | new_im_ids.append(curr_im_id) 34 | for kk in range(len(dl.obj_dict[curr_im_id])): 35 | if dl.obj_dict[curr_im_id][kk] != -1: 36 | self.obj_list.append({'db_ii': ii, 'obj_ii': dl.obj_list.index([jj, kk])}) 37 | flag = True 38 | obj_counter += 1 39 | self.im_list.append({'db_ii': ii, 'im_ii': jj}) 40 | if flag: 41 | num_images += 1 42 | 43 | self.im_ids = new_im_ids 44 | print('Combined number of images: {:d}\nCombined number of objects: {:d}'.format(num_images, len(self.obj_list))) 45 | 46 | def __getitem__(self, index): 47 | 48 | _db_ii = self.obj_list[index]["db_ii"] 49 | _obj_ii = self.obj_list[index]['obj_ii'] 50 | sample = self.dataloaders[_db_ii].__getitem__(_obj_ii) 51 | 52 | if 'meta' in sample.keys(): 53 | sample['meta']['db'] = str(self.dataloaders[_db_ii]) 54 | 55 | return sample 56 | 57 | def __len__(self): 58 | return len(self.obj_list) 59 | 60 | def __str__(self): 61 | include_db = [str(db) for db in self.dataloaders] 62 | exclude_db = [str(db) for db in self.excluded] 63 | return 'Included datasets:'+str(include_db)+'\n'+'Excluded datasets:'+str(exclude_db) 64 | 65 | 66 | if __name__ == "__main__": 67 | import matplotlib.pyplot as plt 68 | import dataloaders.helpers as helpers 69 | import dataloaders.pascal as pascal 70 | import dataloaders.sbd as sbd 71 | import torch 72 | import numpy as np 73 | import dataloaders.custom_transforms as tr 74 | from torchvision import transforms 75 | 76 | transform = transforms.Compose([tr.FixedResize({'image': (512, 512), 'gt': (512, 512)}), 77 | tr.ToTensor()]) 78 | 79 | pascal_voc_val = pascal.VOCSegmentation(split='val', transform=transform, retname=True) 80 | sbd = sbd.SBDSegmentation(split=['train', 'val'], transform=transform, retname=True) 81 | pascal_voc_train = pascal.VOCSegmentation(split='train', transform=transform, retname=True) 82 | 83 | dataset = CombineDBs([pascal_voc_train, sbd], excluded=[pascal_voc_val]) 84 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0) 85 | 86 | for ii, sample in enumerate(dataloader): 87 | for jj in range(sample["image"].size()[0]): 88 | plt.figure() 89 | max_img = np.max(sample["image"][jj].numpy()) 90 | overlay = helpers.overlay_mask(helpers.tens2image(sample["image"][jj])/max_img, helpers.tens2image(sample["gt"][jj])) 91 | plt.imshow(overlay) 92 | plt.title(sample["meta"]) 93 | if ii == 5: 94 | break 95 | 96 | plt.show(block=True) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ### etos-deepcut 3 | 4 | A tool for 
object segmentation from extreme points. 5 | 6 | ### About Deep Extreme Cut (DEXTR) 7 | 8 | Visit the [project page](http://www.vision.ee.ethz.ch/~cvlsegmentation/dextr) to access the paper and the pre-computed results. 9 | 10 | ![DEXTR](doc/dextr.png) 11 | 12 | This is based on the implementation of `Deep Extreme Cut (DEXTR)` for object segmentation from extreme points. 13 | 14 | ### Installation 15 | 16 | 0. Clone the repo: 17 | ```Shell 18 | git clone https://github.com/etosworld/etos-deepcut 19 | cd etos-deepcut 20 | ``` 21 | 22 | 1. Install dependencies: 23 | ```Shell 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | 2. Download the model by running the script inside ```models/```: 28 | ```Shell 29 | cd models/ 30 | chmod +x download_dextr_model.sh 31 | ./download_dextr_model.sh 32 | cd .. 33 | ``` 34 | The default model is trained on PASCAL VOC Segmentation train + SBD (10582 images). To download models trained on PASCAL VOC Segmentation train or COCO, please visit the [project page](http://www.vision.ee.ethz.ch/~cvlsegmentation/dextr/#downloads), or keep scrolling to the end of this README. 35 | 36 | 3. To try the demo version of etos-deepcut, please run: 37 | ```Shell 38 | python demo.py 39 | ``` 40 | If installed correctly, the result should look like this:

42 | 43 | ### image2txt 44 | 45 | We have implemented the [image2txt](https://github.com/etosworld/etos-img2txt) function: for each segmented object, an image and a text file are saved. Enjoy! 46 | 47 | ![](media/out1.png) ![](media/out1_txt.png) 48 | 49 | 50 | ### Training 51 | 52 | To train and evaluate etos-deepcut on PASCAL (or PASCAL + SBD), please follow these additional steps: 53 | 54 | 4. Install tensorboard (integrated with PyTorch). 55 | ```Shell 56 | pip install tensorboard tensorboardx 57 | ``` 58 | 59 | 5. Download the pre-trained PSPNet model for semantic segmentation, taken from [this](https://github.com/isht7/pytorch-deeplab-resnet) repository. 60 | ```Shell 61 | cd models/ 62 | chmod +x download_pretrained_psp_model.sh 63 | ./download_pretrained_psp_model.sh 64 | cd .. 65 | ``` 66 | 6. Set the paths in ```mypath.py```, so that they point to the location of the PASCAL/SBD datasets. 67 | 68 | 7. Run ```python train_pascal.py```, after changing the default parameters if necessary (e.g. gpu_id). 69 | 70 | Enjoy!! 71 | 72 | ### Pre-trained models 73 | 74 | You can download the following DEXTR models, pre-trained on: 75 | * [PASCAL + SBD](https://data.vision.ee.ethz.ch/kmaninis/share/DEXTR/Downloads/models/dextr_pascal-sbd.pth), trained on PASCAL VOC Segmentation train + SBD (10582 images). Achieves mIoU of 91.5% on PASCAL VOC Segmentation val. 76 | * [PASCAL](https://data.vision.ee.ethz.ch/kmaninis/share/DEXTR/Downloads/models/dextr_pascal.pth), trained on PASCAL VOC Segmentation train (1464 images). Achieves mIoU of 90.5% on PASCAL VOC Segmentation val. 77 | * [COCO](https://data.vision.ee.ethz.ch/kmaninis/share/DEXTR/Downloads/models/dextr_coco.pth), trained on COCO train 2014 (82783 images). Achieves mIoU of 87.8% on PASCAL VOC Segmentation val. 78 | 79 | ### TODO 80 | 81 | - To support deep extreme video cut 82 | 83 | ### Citation 84 | 85 | @Inproceedings{Man+18, 86 | Title = {Deep Extreme Cut: From Extreme Points to Object Segmentation}, 87 | Author = {K.K. Maninis and S. Caelles and J. Pont-Tuset and L. {Van Gool}}, 88 | Booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 89 | Year = {2018} 90 | } 91 | 92 | @InProceedings{Pap+17, 93 | Title = {Extreme clicking for efficient object annotation}, 94 | Author = {D.P. Papadopoulos and J. Uijlings and F. Keller and V.
Ferrari}, 95 | Booktitle = {ICCV}, 96 | Year = {2017} 97 | } 98 | 99 | 100 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from PIL import Image 5 | import numpy as np 6 | from matplotlib import pyplot as plt 7 | 8 | from torch.nn.functional import upsample 9 | 10 | import networks.deeplab_resnet as resnet 11 | from mypath import Path 12 | from dataloaders import helpers as helpers 13 | 14 | modelName = 'dextr_pascal-sbd' 15 | pad = 50 16 | thres = 0.8 17 | gpu_id = 0 18 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") 19 | 20 | #img2txt 21 | CHAR_LIST = " ;:,^`'.*~oahkbdpqwmZO0QLCJUYXzcvunxrjft/\|()1{}[]?-_+\<>i!lI$@B%8&WM#" 22 | num_chars = len(CHAR_LIST) 23 | num_cols = 50 24 | 25 | # Create the network and load the weights 26 | net = resnet.resnet101(1, nInputChannels=4, classifier='psp') 27 | print("Initializing weights from: {}".format(os.path.join(Path.models_dir(), modelName + '.pth'))) 28 | state_dict_checkpoint = torch.load(os.path.join(Path.models_dir(), modelName + '.pth'), 29 | map_location=lambda storage, loc: storage) 30 | # Remove the prefix .module from the model when it is trained using DataParallel 31 | if 'module.' in list(state_dict_checkpoint.keys())[0]: 32 | new_state_dict = OrderedDict() 33 | for k, v in state_dict_checkpoint.items(): 34 | name = k[7:] # remove `module.` from multi-gpu training 35 | new_state_dict[name] = v 36 | else: 37 | new_state_dict = state_dict_checkpoint 38 | net.load_state_dict(new_state_dict) 39 | net.eval() 40 | net.to(device) 41 | 42 | # Read image and click the points 43 | image = np.array(Image.open('ims/dog-cat.jpg')) 44 | plt.ion() 45 | plt.axis('off') 46 | plt.imshow(image) 47 | plt.title('Click the four extreme points of the objects\nHit enter when done (do not close the window)') 48 | 49 | results = [] 50 | 51 | idx = 0 52 | 53 | with torch.no_grad(): 54 | while 1: 55 | 56 | extreme_points_ori = np.array(plt.ginput(4, timeout=0)).astype(np.int) 57 | 58 | # Crop image to the bounding box from the extreme points and resize 59 | bbox = helpers.get_bbox(image, points=extreme_points_ori, pad=pad, zero_pad=True) 60 | crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True) 61 | resize_image = helpers.fixed_resize(crop_image, (512, 512)).astype(np.float32) 62 | 63 | # Generate extreme point heat map normalized to image values 64 | extreme_points = extreme_points_ori - [np.min(extreme_points_ori[:, 0]), np.min(extreme_points_ori[:, 1])] + [pad, 65 | pad] 66 | extreme_points = (512 * extreme_points * [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(np.int) 67 | extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10) 68 | extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255) 69 | 70 | # Concatenate inputs and convert to tensor 71 | input_dextr = np.concatenate((resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2) 72 | inputs = torch.from_numpy(input_dextr.transpose((2, 0, 1))[np.newaxis, ...]) 73 | 74 | # Run a forward pass 75 | inputs = inputs.to(device) 76 | outputs = net.forward(inputs) 77 | outputs = upsample(outputs, size=(512, 512), mode='bilinear', align_corners=True) 78 | outputs = outputs.to(torch.device('cpu')) 79 | 80 | pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0)) 81 | pred = 1 / (1 + np.exp(-pred)) 82 | pred = np.squeeze(pred) 83 | 
result = helpers.crop2fullmask(pred, bbox, im_size=image.shape[:2], zero_pad=True, relax=pad) > thres 84 | 85 | #save cut image and image2txt file 86 | #cut_mask = result.astype(int) 87 | height, width,_ = image.shape 88 | out_img = np.zeros((height, width,3),np.uint8) 89 | out_img[result]=image[result] 90 | #print(out_img) 91 | out_gray=np.dot(out_img[...,:3], [0.299, 0.587, 0.114]).astype(np.uint8) 92 | cell_width = width / num_cols 93 | cell_height = 2 * cell_width 94 | num_rows = int(height / cell_height) 95 | if num_cols > width or num_rows > height: 96 | print("Too many columns or rows. Use default setting") 97 | cell_width = 6 98 | cell_height = 12 99 | num_cols = int(width / cell_width) 100 | num_rows = int(height / cell_height) 101 | 102 | idx = idx + 1 103 | output_file = open('out%d.txt'%idx, 'w') 104 | for i in range(num_rows): 105 | for j in range(num_cols): 106 | output_file.write( 107 | CHAR_LIST[min(int(np.mean(out_gray[int(i * cell_height):min(int((i + 1) * cell_height), height), 108 | int(j * cell_width):min(int((j + 1) * cell_width), 109 | width)]) * num_chars / 255), num_chars - 1)]) 110 | output_file.write("\n") 111 | output_file.close() 112 | 113 | 114 | im = Image.fromarray(out_img) 115 | #im = Image.fromarray(out_gray) 116 | 117 | im.save("out%d.png"%idx) 118 | 119 | results.append(result) 120 | 121 | # Plot the results 122 | plt.imshow(helpers.overlay_masks(image / 255, results)) 123 | plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1], 'gx') 124 | -------------------------------------------------------------------------------- /dataloaders/sbd.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import torch, cv2 4 | import errno 5 | import hashlib 6 | import json 7 | import os 8 | import sys 9 | import tarfile 10 | 11 | import numpy as np 12 | import scipy.io 13 | import torch.utils.data as data 14 | from PIL import Image 15 | from six.moves import urllib 16 | from mypath import Path 17 | 18 | 19 | class SBDSegmentation(data.Dataset): 20 | 21 | URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz" 22 | FILE = "benchmark.tgz" 23 | MD5 = '82b4d87ceb2ed10f6038a1cba92111cb' 24 | 25 | def __init__(self, 26 | root=Path.db_root_dir('sbd'), 27 | split='val', 28 | transform=None, 29 | download=False, 30 | preprocess=False, 31 | area_thres=0, 32 | retname=True): 33 | 34 | # Store parameters 35 | self.root = root 36 | self.transform = transform 37 | if isinstance(split, str): 38 | self.split = [split] 39 | else: 40 | split.sort() 41 | self.split = split 42 | self.area_thres = area_thres 43 | self.retname = retname 44 | 45 | # Where to find things according to the author's structure 46 | self.dataset_dir = os.path.join(self.root, 'benchmark_RELEASE', 'dataset') 47 | _mask_dir = os.path.join(self.dataset_dir, 'inst') 48 | _image_dir = os.path.join(self.dataset_dir, 'img') 49 | 50 | if self.area_thres != 0: 51 | self.obj_list_file = os.path.join(self.dataset_dir, '_'.join(self.split) + '_instances_area_thres-' + 52 | str(area_thres) + '.txt') 53 | else: 54 | self.obj_list_file = os.path.join(self.dataset_dir, '_'.join(self.split) + '_instances' + '.txt') 55 | 56 | # Download dataset? 
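# (With download=True, self._download() below skips the fetch if a verified copy of
# benchmark.tgz already exists, otherwise downloads it from self.URL and extracts it;
# the archive is then checked against self.MD5 via self._check_integrity().)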
57 | if download: 58 | self._download() 59 | if not self._check_integrity(): 60 | raise RuntimeError('Dataset file downloaded is corrupted.') 61 | 62 | # Get list of all images from the split and check that the files exist 63 | self.im_ids = [] 64 | self.images = [] 65 | self.masks = [] 66 | for splt in self.split: 67 | with open(os.path.join(self.dataset_dir, splt+'.txt'), "r") as f: 68 | lines = f.read().splitlines() 69 | 70 | for line in lines: 71 | _image = os.path.join(_image_dir, line + ".jpg") 72 | _mask = os.path.join(_mask_dir, line + ".mat") 73 | assert os.path.isfile(_image) 74 | assert os.path.isfile(_mask) 75 | self.im_ids.append(line) 76 | self.images.append(_image) 77 | self.masks.append(_mask) 78 | 79 | assert (len(self.images) == len(self.masks)) 80 | 81 | # Precompute the list of objects and their categories for each image 82 | if (not self._check_preprocess()) or preprocess: 83 | print('Preprocessing SBD dataset, this will take long, but it will be done only once.') 84 | self._preprocess() 85 | 86 | # Build the list of objects 87 | self.obj_list = [] 88 | num_images = 0 89 | for ii in range(len(self.im_ids)): 90 | if self.im_ids[ii] in self.obj_dict.keys(): 91 | flag = False 92 | for jj in range(len(self.obj_dict[self.im_ids[ii]])): 93 | if self.obj_dict[self.im_ids[ii]][jj] != -1: 94 | self.obj_list.append([ii, jj]) 95 | flag = True 96 | if flag: 97 | num_images += 1 98 | 99 | # Display stats 100 | print('Number of images: {:d}\nNumber of objects: {:d}'.format(num_images, len(self.obj_list))) 101 | 102 | def __getitem__(self, index): 103 | 104 | _img, _target = self._make_img_gt_point_pair(index) 105 | _void_pixels = (_target == 255).astype(np.float32) 106 | sample = {'image': _img, 'gt': _target, 'void_pixels': _void_pixels} 107 | 108 | if self.retname: 109 | _im_ii = self.obj_list[index][0] 110 | _obj_ii = self.obj_list[index][1] 111 | sample['meta'] = {'image': str(self.im_ids[_im_ii]), 112 | 'object': str(_obj_ii), 113 | 'im_size': (_img.shape[0], _img.shape[1]), 114 | 'category': self.obj_dict[self.im_ids[_im_ii]][_obj_ii]} 115 | 116 | if self.transform is not None: 117 | sample = self.transform(sample) 118 | 119 | return sample 120 | 121 | def __len__(self): 122 | return len(self.obj_list) 123 | 124 | def _check_integrity(self): 125 | _fpath = os.path.join(self.root, self.FILE) 126 | if not os.path.isfile(_fpath): 127 | print("{} does not exist".format(_fpath)) 128 | return False 129 | _md5c = hashlib.md5(open(_fpath, 'rb').read()).hexdigest() 130 | if _md5c != self.MD5: 131 | print(" MD5({}) did not match MD5({}) expected for {}".format( 132 | _md5c, self.MD5, _fpath)) 133 | return False 134 | return True 135 | 136 | def _check_preprocess(self): 137 | # Check that the file with categories is there and with correct size 138 | _obj_list_file = self.obj_list_file 139 | if not os.path.isfile(_obj_list_file): 140 | return False 141 | else: 142 | self.obj_dict = json.load(open(_obj_list_file, 'r')) 143 | return list(np.sort([str(x) for x in self.obj_dict.keys()])) == list(np.sort(self.im_ids)) 144 | 145 | def _preprocess(self): 146 | # Get all object instances and their category 147 | self.obj_dict = {} 148 | obj_counter = 0 149 | for ii in range(len(self.im_ids)): 150 | # Read object masks and get number of objects 151 | tmp = scipy.io.loadmat(self.masks[ii]) 152 | _mask = tmp["GTinst"][0]["Segmentation"][0] 153 | _cat_ids = tmp["GTinst"][0]["Categories"][0].astype(int) 154 | 155 | _mask_ids = np.unique(_mask) 156 | n_obj = _mask_ids[-1] 157 | assert(n_obj == 
len(_cat_ids)) 158 | 159 | for jj in range(n_obj): 160 | temp = np.where(_mask == jj + 1) 161 | obj_area = len(temp[0]) 162 | if obj_area < self.area_thres: 163 | _cat_ids[jj] = -1 164 | obj_counter += 1 165 | 166 | self.obj_dict[self.im_ids[ii]] = np.squeeze(_cat_ids, 1).tolist() 167 | 168 | # Save it to file for future reference 169 | with open(self.obj_list_file, 'w') as outfile: 170 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.im_ids[0], json.dumps(self.obj_dict[self.im_ids[0]]))) 171 | for ii in range(1, len(self.im_ids)): 172 | outfile.write(',\n\t"{:s}": {:s}'.format(self.im_ids[ii], json.dumps(self.obj_dict[self.im_ids[ii]]))) 173 | outfile.write('\n}\n') 174 | 175 | print('Pre-processing finished') 176 | 177 | def _download(self): 178 | _fpath = os.path.join(self.root, self.FILE) 179 | 180 | try: 181 | os.makedirs(self.root) 182 | except OSError as e: 183 | if e.errno == errno.EEXIST: 184 | pass 185 | else: 186 | raise 187 | 188 | if self._check_integrity(): 189 | print('Files already downloaded and verified') 190 | return 191 | else: 192 | print('Downloading ' + self.URL + ' to ' + _fpath) 193 | 194 | def _progress(count, block_size, total_size): 195 | sys.stdout.write('\r>> %s %.1f%%' % 196 | (_fpath, float(count * block_size) / 197 | float(total_size) * 100.0)) 198 | sys.stdout.flush() 199 | 200 | urllib.request.urlretrieve(self.URL, _fpath, _progress) 201 | 202 | # extract file 203 | cwd = os.getcwd() 204 | print('Extracting tar file') 205 | tar = tarfile.open(_fpath) 206 | os.chdir(self.root) 207 | tar.extractall() 208 | tar.close() 209 | os.chdir(cwd) 210 | print('Done!') 211 | 212 | def _make_img_gt_point_pair(self, index): 213 | _im_ii = self.obj_list[index][0] 214 | _obj_ii = self.obj_list[index][1] 215 | 216 | # Read Image 217 | _img = np.array(Image.open(self.images[_im_ii]).convert('RGB')).astype(np.float32) 218 | 219 | # Read Taret object 220 | _tmp = scipy.io.loadmat(self.masks[_im_ii])["GTinst"][0]["Segmentation"][0] 221 | _target = (_tmp == (_obj_ii + 1)).astype(np.float32) 222 | 223 | return _img, _target 224 | 225 | def __str__(self): 226 | return 'SBDSegmentation(split='+str(self.split)+', area_thres='+str(self.area_thres)+')' 227 | 228 | 229 | if __name__ == '__main__': 230 | import matplotlib.pyplot as plt 231 | import dataloaders.helpers as helpers 232 | import torch 233 | import torchvision.transforms as transforms 234 | import dataloaders.custom_transforms as tr 235 | 236 | transform = transforms.Compose([tr.ToTensor()]) 237 | dataset = SBDSegmentation(transform=transform) 238 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 239 | 240 | for i, data in enumerate(dataloader): 241 | plt.figure() 242 | plt.imshow(helpers.tens2image(data['image'])/255) 243 | plt.figure() 244 | plt.imshow(helpers.tens2image(data['gt'])[:, :, 0]) 245 | if i == 10: 246 | break 247 | 248 | plt.show(block=True) 249 | -------------------------------------------------------------------------------- /dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch, cv2 2 | 3 | import numpy.random as random 4 | import numpy as np 5 | import dataloaders.helpers as helpers 6 | 7 | 8 | class ScaleNRotate(object): 9 | """Scale (zoom-in, zoom-out) and Rotate the image and the ground truth. 10 | Args: 11 | two possibilities: 12 | 1. rots (tuple): (minimum, maximum) rotation angle 13 | scales (tuple): (minimum, maximum) scale 14 | 2. 
rots [list]: list of fixed possible rotation angles 15 | scales [list]: list of fixed possible scales 16 | """ 17 | def __init__(self, rots=(-30, 30), scales=(.75, 1.25), semseg=False): 18 | assert (isinstance(rots, type(scales))) 19 | self.rots = rots 20 | self.scales = scales 21 | self.semseg = semseg 22 | 23 | def __call__(self, sample): 24 | 25 | if type(self.rots) == tuple: 26 | # Continuous range of scales and rotations 27 | rot = (self.rots[1] - self.rots[0]) * random.random() - \ 28 | (self.rots[1] - self.rots[0])/2 29 | 30 | sc = (self.scales[1] - self.scales[0]) * random.random() - \ 31 | (self.scales[1] - self.scales[0]) / 2 + 1 32 | elif type(self.rots) == list: 33 | # Fixed range of scales and rotations 34 | rot = self.rots[random.randint(0, len(self.rots))] 35 | sc = self.scales[random.randint(0, len(self.scales))] 36 | 37 | for elem in sample.keys(): 38 | if 'meta' in elem: 39 | continue 40 | 41 | tmp = sample[elem] 42 | 43 | h, w = tmp.shape[:2] 44 | center = (w / 2, h / 2) 45 | assert(center != 0) # Strange behaviour warpAffine 46 | M = cv2.getRotationMatrix2D(center, rot, sc) 47 | 48 | if ((tmp == 0) | (tmp == 1)).all(): 49 | flagval = cv2.INTER_NEAREST 50 | elif 'gt' in elem and self.semseg: 51 | flagval = cv2.INTER_NEAREST 52 | else: 53 | flagval = cv2.INTER_CUBIC 54 | tmp = cv2.warpAffine(tmp, M, (w, h), flags=flagval) 55 | 56 | sample[elem] = tmp 57 | 58 | return sample 59 | 60 | def __str__(self): 61 | return 'ScaleNRotate:(rot='+str(self.rots)+',scale='+str(self.scales)+')' 62 | 63 | 64 | class FixedResize(object): 65 | """Resize the image and the ground truth to specified resolution. 66 | Args: 67 | resolutions (dict): the list of resolutions 68 | """ 69 | def __init__(self, resolutions=None, flagvals=None): 70 | self.resolutions = resolutions 71 | self.flagvals = flagvals 72 | if self.flagvals is not None: 73 | assert(len(self.resolutions) == len(self.flagvals)) 74 | 75 | def __call__(self, sample): 76 | 77 | # Fixed range of scales 78 | if self.resolutions is None: 79 | return sample 80 | 81 | elems = list(sample.keys()) 82 | 83 | for elem in elems: 84 | 85 | if 'meta' in elem or 'bbox' in elem or ('extreme_points_coord' in elem and elem not in self.resolutions): 86 | continue 87 | if 'extreme_points_coord' in elem and elem in self.resolutions: 88 | bbox = sample['bbox'] 89 | crop_size = np.array([bbox[3]-bbox[1]+1, bbox[4]-bbox[2]+1]) 90 | res = np.array(self.resolutions[elem]).astype(np.float32) 91 | sample[elem] = np.round(sample[elem]*res/crop_size).astype(np.int) 92 | continue 93 | if elem in self.resolutions: 94 | if self.resolutions[elem] is None: 95 | continue 96 | if isinstance(sample[elem], list): 97 | if sample[elem][0].ndim == 3: 98 | output_size = np.append(self.resolutions[elem], [3, len(sample[elem])]) 99 | else: 100 | output_size = np.append(self.resolutions[elem], len(sample[elem])) 101 | tmp = sample[elem] 102 | sample[elem] = np.zeros(output_size, dtype=np.float32) 103 | for ii, crop in enumerate(tmp): 104 | if self.flagvals is None: 105 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem]) 106 | else: 107 | sample[elem][..., ii] = helpers.fixed_resize(crop, self.resolutions[elem], flagval=self.flagvals[elem]) 108 | else: 109 | if self.flagvals is None: 110 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem]) 111 | else: 112 | sample[elem] = helpers.fixed_resize(sample[elem], self.resolutions[elem], flagval=self.flagvals[elem]) 113 | else: 114 | del sample[elem] 115 | 116 | return sample 117 | 118 | 
def __str__(self): 119 | return 'FixedResize:'+str(self.resolutions) 120 | 121 | 122 | class RandomHorizontalFlip(object): 123 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5.""" 124 | 125 | def __call__(self, sample): 126 | 127 | if random.random() < 0.5: 128 | for elem in sample.keys(): 129 | if 'meta' in elem: 130 | continue 131 | tmp = sample[elem] 132 | tmp = cv2.flip(tmp, flipCode=1) 133 | sample[elem] = tmp 134 | 135 | return sample 136 | 137 | def __str__(self): 138 | return 'RandomHorizontalFlip' 139 | 140 | 141 | class ExtremePoints(object): 142 | """ 143 | Returns the four extreme points (left, right, top, bottom) (with some random perturbation) in a given binary mask 144 | sigma: sigma of Gaussian to create a heatmap from a point 145 | pert: number of pixels fo the maximum perturbation 146 | elem: which element of the sample to choose as the binary mask 147 | """ 148 | def __init__(self, sigma=10, pert=0, elem='gt'): 149 | self.sigma = sigma 150 | self.pert = pert 151 | self.elem = elem 152 | 153 | def __call__(self, sample): 154 | if sample[self.elem].ndim == 3: 155 | raise ValueError('ExtremePoints not implemented for multiple object per image.') 156 | _target = sample[self.elem] 157 | if np.max(_target) == 0: 158 | sample['extreme_points'] = np.zeros(_target.shape, dtype=_target.dtype) # TODO: handle one_mask_per_point case 159 | else: 160 | _points = helpers.extreme_points(_target, self.pert) 161 | sample['extreme_points'] = helpers.make_gt(_target, _points, sigma=self.sigma, one_mask_per_point=False) 162 | 163 | return sample 164 | 165 | def __str__(self): 166 | return 'ExtremePoints:(sigma='+str(self.sigma)+', pert='+str(self.pert)+', elem='+str(self.elem)+')' 167 | 168 | 169 | class ConcatInputs(object): 170 | 171 | def __init__(self, elems=('image', 'point')): 172 | self.elems = elems 173 | 174 | def __call__(self, sample): 175 | 176 | res = sample[self.elems[0]] 177 | 178 | for elem in self.elems[1:]: 179 | assert(sample[self.elems[0]].shape[:2] == sample[elem].shape[:2]) 180 | 181 | # Check if third dimension is missing 182 | tmp = sample[elem] 183 | if tmp.ndim == 2: 184 | tmp = tmp[:, :, np.newaxis] 185 | 186 | res = np.concatenate((res, tmp), axis=2) 187 | 188 | sample['concat'] = res 189 | 190 | return sample 191 | 192 | def __str__(self): 193 | return 'ExtremePoints:'+str(self.elems) 194 | 195 | 196 | class CropFromMask(object): 197 | """ 198 | Returns image cropped in bounding box from a given mask 199 | """ 200 | def __init__(self, crop_elems=('image', 'gt'), 201 | mask_elem='gt', 202 | relax=0, 203 | zero_pad=False): 204 | 205 | self.crop_elems = crop_elems 206 | self.mask_elem = mask_elem 207 | self.relax = relax 208 | self.zero_pad = zero_pad 209 | 210 | def __call__(self, sample): 211 | _target = sample[self.mask_elem] 212 | if _target.ndim == 2: 213 | _target = np.expand_dims(_target, axis=-1) 214 | for elem in self.crop_elems: 215 | _img = sample[elem] 216 | _crop = [] 217 | if self.mask_elem == elem: 218 | if _img.ndim == 2: 219 | _img = np.expand_dims(_img, axis=-1) 220 | for k in range(0, _target.shape[-1]): 221 | _tmp_img = _img[..., k] 222 | _tmp_target = _target[..., k] 223 | if np.max(_target[..., k]) == 0: 224 | _crop.append(np.zeros(_tmp_img.shape, dtype=_img.dtype)) 225 | else: 226 | _crop.append(helpers.crop_from_mask(_tmp_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad)) 227 | else: 228 | for k in range(0, _target.shape[-1]): 229 | if np.max(_target[..., k]) == 0: 230 | 
_crop.append(np.zeros(_img.shape, dtype=_img.dtype)) 231 | else: 232 | _tmp_target = _target[..., k] 233 | _crop.append(helpers.crop_from_mask(_img, _tmp_target, relax=self.relax, zero_pad=self.zero_pad)) 234 | if len(_crop) == 1: 235 | sample['crop_' + elem] = _crop[0] 236 | else: 237 | sample['crop_' + elem] = _crop 238 | return sample 239 | 240 | def __str__(self): 241 | return 'CropFromMask:(crop_elems='+str(self.crop_elems)+', mask_elem='+str(self.mask_elem)+\ 242 | ', relax='+str(self.relax)+',zero_pad='+str(self.zero_pad)+')' 243 | 244 | 245 | class ToImage(object): 246 | """ 247 | Return the given elements between 0 and 255 248 | """ 249 | def __init__(self, norm_elem='image', custom_max=255.): 250 | self.norm_elem = norm_elem 251 | self.custom_max = custom_max 252 | 253 | def __call__(self, sample): 254 | if isinstance(self.norm_elem, tuple): 255 | for elem in self.norm_elem: 256 | tmp = sample[elem] 257 | sample[elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10) 258 | else: 259 | tmp = sample[self.norm_elem] 260 | sample[self.norm_elem] = self.custom_max * (tmp - tmp.min()) / (tmp.max() - tmp.min() + 1e-10) 261 | return sample 262 | 263 | def __str__(self): 264 | return 'NormalizeImage' 265 | 266 | 267 | class ToTensor(object): 268 | """Convert ndarrays in sample to Tensors.""" 269 | 270 | def __call__(self, sample): 271 | 272 | for elem in sample.keys(): 273 | if 'meta' in elem: 274 | continue 275 | elif 'bbox' in elem: 276 | tmp = sample[elem] 277 | sample[elem] = torch.from_numpy(tmp) 278 | continue 279 | 280 | tmp = sample[elem] 281 | 282 | if tmp.ndim == 2: 283 | tmp = tmp[:, :, np.newaxis] 284 | 285 | # swap color axis because 286 | # numpy image: H x W x C 287 | # torch image: C X H X W 288 | tmp = tmp.transpose((2, 0, 1)) 289 | sample[elem] = torch.from_numpy(tmp) 290 | 291 | return sample 292 | 293 | def __str__(self): 294 | return 'ToTensor' 295 | -------------------------------------------------------------------------------- /train_pascal.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import timeit 3 | from datetime import datetime 4 | import scipy.misc as sm 5 | from collections import OrderedDict 6 | import glob 7 | 8 | # PyTorch includes 9 | import torch.optim as optim 10 | from torchvision import transforms 11 | from torch.utils.data import DataLoader 12 | from torch.nn.functional import upsample 13 | 14 | # Tensorboard include 15 | from tensorboardX import SummaryWriter 16 | 17 | # Custom includes 18 | from dataloaders.combine_dbs import CombineDBs as combine_dbs 19 | import dataloaders.pascal as pascal 20 | import dataloaders.sbd as sbd 21 | from dataloaders import custom_transforms as tr 22 | import networks.deeplab_resnet as resnet 23 | from layers.loss import class_balanced_cross_entropy_loss 24 | from dataloaders.helpers import * 25 | 26 | # Set gpu_id to -1 to run in CPU mode, otherwise set the id of the corresponding gpu 27 | gpu_id = 0 28 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") 29 | if torch.cuda.is_available(): 30 | print('Using GPU: {} '.format(gpu_id)) 31 | 32 | # Setting parameters 33 | use_sbd = False 34 | nEpochs = 100 # Number of epochs for training 35 | resume_epoch = 0 # Default is 0, change if want to resume 36 | 37 | p = OrderedDict() # Parameters to include in report 38 | classifier = 'psp' # Head classifier to use 39 | p['trainBatch'] = 5 # Training batch size 40 | testBatch = 5 # Testing batch size 41 | useTest 
= 1 # See evolution of the test set when training? 42 | nTestInterval = 10 # Run on test set every nTestInterval epochs 43 | snapshot = 20 # Store a model every snapshot epochs 44 | relax_crop = 50 # Enlarge the bounding box by relax_crop pixels 45 | nInputChannels = 4 # Number of input channels (RGB + heatmap of extreme points) 46 | zero_pad_crop = True # Insert zero padding when cropping the image 47 | p['nAveGrad'] = 1 # Average the gradient of several iterations 48 | p['lr'] = 1e-8 # Learning rate 49 | p['wd'] = 0.0005 # Weight decay 50 | p['momentum'] = 0.9 # Momentum 51 | 52 | # Results and model directories (a new directory is generated for every run) 53 | save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) 54 | exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] 55 | if resume_epoch == 0: 56 | runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*'))) 57 | run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0 58 | else: 59 | run_id = 0 60 | save_dir = os.path.join(save_dir_root, 'run_' + str(run_id)) 61 | if not os.path.exists(os.path.join(save_dir, 'models')): 62 | os.makedirs(os.path.join(save_dir, 'models')) 63 | 64 | # Network definition 65 | modelName = 'dextr_pascal' 66 | net = resnet.resnet101(1, pretrained=True, nInputChannels=nInputChannels, classifier=classifier) 67 | if resume_epoch == 0: 68 | print("Initializing from pretrained Deeplab-v2 model") 69 | else: 70 | print("Initializing weights from: {}".format( 71 | os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'))) 72 | net.load_state_dict( 73 | torch.load(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'), 74 | map_location=lambda storage, loc: storage)) 75 | train_params = [{'params': resnet.get_1x_lr_params(net), 'lr': p['lr']}, 76 | {'params': resnet.get_10x_lr_params(net), 'lr': p['lr'] * 10}] 77 | 78 | net.to(device) 79 | 80 | # Training the network 81 | if resume_epoch != nEpochs: 82 | # Logging into Tensorboard 83 | log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) 84 | writer = SummaryWriter(log_dir=log_dir) 85 | 86 | # Use the following optimizer 87 | optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd']) 88 | p['optimizer'] = str(optimizer) 89 | 90 | # Preparation of the data loaders 91 | composed_transforms_tr = transforms.Compose([ 92 | tr.RandomHorizontalFlip(), 93 | tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)), 94 | tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop), 95 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}), 96 | tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt'), 97 | tr.ToImage(norm_elem='extreme_points'), 98 | tr.ConcatInputs(elems=('crop_image', 'extreme_points')), 99 | tr.ToTensor()]) 100 | composed_transforms_ts = transforms.Compose([ 101 | tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop), 102 | tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}), 103 | tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'), 104 | tr.ToImage(norm_elem='extreme_points'), 105 | tr.ConcatInputs(elems=('crop_image', 'extreme_points')), 106 | tr.ToTensor()]) 107 | 108 | voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr) 109 | voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts) 110 | 111 | if use_sbd: 
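# Note: the assignment on the next line rebinds the imported module name `sbd`
# to the SBDSegmentation instance, so the module itself is no longer reachable
# under that name afterwards.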
112 | sbd = sbd.SBDSegmentation(split=['train', 'val'], transform=composed_transforms_tr, retname=True) 113 | db_train = combine_dbs([voc_train, sbd], excluded=[voc_val]) 114 | else: 115 | db_train = voc_train 116 | 117 | p['dataset_train'] = str(db_train) 118 | p['transformations_train'] = [str(tran) for tran in composed_transforms_tr.transforms] 119 | p['dataset_test'] = str(db_train) 120 | p['transformations_test'] = [str(tran) for tran in composed_transforms_ts.transforms] 121 | 122 | trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=2) 123 | testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=2) 124 | 125 | generate_param_report(os.path.join(save_dir, exp_name + '.txt'), p) 126 | 127 | # Train variables 128 | num_img_tr = len(trainloader) 129 | num_img_ts = len(testloader) 130 | running_loss_tr = 0.0 131 | running_loss_ts = 0.0 132 | aveGrad = 0 133 | print("Training Network") 134 | # Main Training and Testing Loop 135 | for epoch in range(resume_epoch, nEpochs): 136 | start_time = timeit.default_timer() 137 | 138 | net.train() 139 | for ii, sample_batched in enumerate(trainloader): 140 | 141 | inputs, gts = sample_batched['concat'], sample_batched['crop_gt'] 142 | 143 | # Forward-Backward of the mini-batch 144 | inputs.requires_grad_() 145 | inputs, gts = inputs.to(device), gts.to(device) 146 | 147 | output = net.forward(inputs) 148 | output = upsample(output, size=(512, 512), mode='bilinear', align_corners=True) 149 | 150 | # Compute the losses, side outputs and fuse 151 | loss = class_balanced_cross_entropy_loss(output, gts, size_average=False, batch_average=True) 152 | running_loss_tr += loss.item() 153 | 154 | # Print stuff 155 | if ii % num_img_tr == num_img_tr - 1: 156 | running_loss_tr = running_loss_tr / num_img_tr 157 | writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch) 158 | print('[Epoch: %d, numImages: %5d]' % (epoch, ii*p['trainBatch']+inputs.data.shape[0])) 159 | print('Loss: %f' % running_loss_tr) 160 | running_loss_tr = 0 161 | stop_time = timeit.default_timer() 162 | print("Execution time: " + str(stop_time - start_time)+"\n") 163 | 164 | # Backward the averaged gradient 165 | loss /= p['nAveGrad'] 166 | loss.backward() 167 | aveGrad += 1 168 | 169 | # Update the weights once in p['nAveGrad'] forward passes 170 | if aveGrad % p['nAveGrad'] == 0: 171 | writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch) 172 | optimizer.step() 173 | optimizer.zero_grad() 174 | aveGrad = 0 175 | 176 | # Save the model 177 | if (epoch % snapshot) == snapshot - 1 and epoch != 0: 178 | torch.save(net.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')) 179 | 180 | # One testing epoch 181 | if useTest and epoch % nTestInterval == (nTestInterval - 1): 182 | net.eval() 183 | with torch.no_grad(): 184 | for ii, sample_batched in enumerate(testloader): 185 | inputs, gts = sample_batched['concat'], sample_batched['crop_gt'] 186 | 187 | # Forward pass of the mini-batch 188 | inputs, gts = inputs.to(device), gts.to(device) 189 | 190 | output = net.forward(inputs) 191 | output = upsample(output, size=(512, 512), mode='bilinear', align_corners=True) 192 | 193 | # Compute the losses, side outputs and fuse 194 | loss = class_balanced_cross_entropy_loss(output, gts, size_average=False) 195 | running_loss_ts += loss.item() 196 | 197 | # Print stuff 198 | if ii % num_img_ts == num_img_ts - 1: 199 | running_loss_ts = running_loss_ts / num_img_ts 200 | 
print('[Epoch: %d, numImages: %5d]' % (epoch, ii*testBatch+inputs.data.shape[0])) 201 | writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch) 202 | print('Loss: %f' % running_loss_ts) 203 | running_loss_ts = 0 204 | 205 | writer.close() 206 | 207 | # Generate result of the validation images 208 | net.eval() 209 | composed_transforms_ts = transforms.Compose([ 210 | tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop), 211 | tr.FixedResize(resolutions={'gt': None, 'crop_image': (512, 512), 'crop_gt': (512, 512)}), 212 | tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'), 213 | tr.ToImage(norm_elem='extreme_points'), 214 | tr.ConcatInputs(elems=('crop_image', 'extreme_points')), 215 | tr.ToTensor()]) 216 | db_test = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts, retname=True) 217 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 218 | 219 | save_dir_res = os.path.join(save_dir, 'Results') 220 | if not os.path.exists(save_dir_res): 221 | os.makedirs(save_dir_res) 222 | 223 | print('Testing Network') 224 | with torch.no_grad(): 225 | # Main Testing Loop 226 | for ii, sample_batched in enumerate(testloader): 227 | 228 | inputs, gts, metas = sample_batched['concat'], sample_batched['gt'], sample_batched['meta'] 229 | 230 | # Forward of the mini-batch 231 | inputs = inputs.to(device) 232 | 233 | outputs = net.forward(inputs) 234 | outputs = upsample(outputs, size=(512, 512), mode='bilinear', align_corners=True) 235 | outputs = outputs.to(torch.device('cpu')) 236 | 237 | for jj in range(int(inputs.size()[0])): 238 | pred = np.transpose(outputs.data.numpy()[jj, :, :, :], (1, 2, 0)) 239 | pred = 1 / (1 + np.exp(-pred)) 240 | pred = np.squeeze(pred) 241 | gt = tens2image(gts[jj, :, :, :]) 242 | bbox = get_bbox(gt, pad=relax_crop, zero_pad=zero_pad_crop) 243 | result = crop2fullmask(pred, bbox, gt, zero_pad=zero_pad_crop, relax=relax_crop) 244 | 245 | # Save the result, attention to the index jj 246 | sm.imsave(os.path.join(save_dir_res, metas['image'][jj] + '-' + metas['object'][jj] + '.png'), result) 247 | -------------------------------------------------------------------------------- /dataloaders/pascal.py: -------------------------------------------------------------------------------- 1 | import torch, cv2 2 | import errno 3 | import hashlib 4 | import os 5 | import sys 6 | import tarfile 7 | import numpy as np 8 | 9 | import torch.utils.data as data 10 | from PIL import Image 11 | from six.moves import urllib 12 | import json 13 | from mypath import Path 14 | 15 | 16 | class VOCSegmentation(data.Dataset): 17 | 18 | URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" 19 | FILE = "VOCtrainval_11-May-2012.tar" 20 | MD5 = '6cd6e144f989b92b3379bac3b3de84fd' 21 | BASE_DIR = 'VOCdevkit/VOC2012' 22 | 23 | category_names = ['background', 24 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 25 | 'bus', 'car', 'cat', 'chair', 'cow', 26 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 27 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 28 | 29 | def __init__(self, 30 | root=Path.db_root_dir('pascal'), 31 | split='val', 32 | transform=None, 33 | download=False, 34 | preprocess=False, 35 | area_thres=0, 36 | retname=True, 37 | suppress_void_pixels=True, 38 | default=False): 39 | 40 | self.root = root 41 | _voc_root = os.path.join(self.root, self.BASE_DIR) 42 | _mask_dir = os.path.join(_voc_root, 'SegmentationObject') 43 | _cat_dir = os.path.join(_voc_root, 
'SegmentationClass') 44 | _image_dir = os.path.join(_voc_root, 'JPEGImages') 45 | self.transform = transform 46 | if isinstance(split, str): 47 | self.split = [split] 48 | else: 49 | split.sort() 50 | self.split = split 51 | self.area_thres = area_thres 52 | self.retname = retname 53 | self.suppress_void_pixels = suppress_void_pixels 54 | self.default = default 55 | 56 | # Build the ids file 57 | area_th_str = "" 58 | if self.area_thres != 0: 59 | area_th_str = '_area_thres-' + str(area_thres) 60 | 61 | self.obj_list_file = os.path.join(self.root, self.BASE_DIR, 'ImageSets', 'Segmentation', 62 | '_'.join(self.split) + '_instances' + area_th_str + '.txt') 63 | 64 | if download: 65 | self._download() 66 | 67 | if not self._check_integrity(): 68 | raise RuntimeError('Dataset not found or corrupted.' + 69 | ' You can use download=True to download it') 70 | 71 | # train/val/test splits are pre-cut 72 | _splits_dir = os.path.join(_voc_root, 'ImageSets', 'Segmentation') 73 | 74 | self.im_ids = [] 75 | self.images = [] 76 | self.categories = [] 77 | self.masks = [] 78 | 79 | for splt in self.split: 80 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: 81 | lines = f.read().splitlines() 82 | 83 | for ii, line in enumerate(lines): 84 | _image = os.path.join(_image_dir, line + ".jpg") 85 | _cat = os.path.join(_cat_dir, line + ".png") 86 | _mask = os.path.join(_mask_dir, line + ".png") 87 | assert os.path.isfile(_image) 88 | assert os.path.isfile(_cat) 89 | assert os.path.isfile(_mask) 90 | self.im_ids.append(line.rstrip('\n')) 91 | self.images.append(_image) 92 | self.categories.append(_cat) 93 | self.masks.append(_mask) 94 | 95 | assert (len(self.images) == len(self.masks)) 96 | assert (len(self.images) == len(self.categories)) 97 | 98 | # Precompute the list of objects and their categories for each image 99 | if (not self._check_preprocess()) or preprocess: 100 | print('Preprocessing of PASCAL VOC dataset, this will take long, but it will be done only once.') 101 | self._preprocess() 102 | 103 | # Build the list of objects 104 | self.obj_list = [] 105 | num_images = 0 106 | for ii in range(len(self.im_ids)): 107 | flag = False 108 | for jj in range(len(self.obj_dict[self.im_ids[ii]])): 109 | if self.obj_dict[self.im_ids[ii]][jj] != -1: 110 | self.obj_list.append([ii, jj]) 111 | flag = True 112 | if flag: 113 | num_images += 1 114 | 115 | # Display stats 116 | print('Number of images: {:d}\nNumber of objects: {:d}'.format(num_images, len(self.obj_list))) 117 | 118 | def __getitem__(self, index): 119 | _img, _target, _void_pixels, _, _, _ = self._make_img_gt_point_pair(index) 120 | sample = {'image': _img, 'gt': _target, 'void_pixels': _void_pixels} 121 | 122 | if self.retname: 123 | _im_ii = self.obj_list[index][0] 124 | _obj_ii = self.obj_list[index][1] 125 | sample['meta'] = {'image': str(self.im_ids[_im_ii]), 126 | 'object': str(_obj_ii), 127 | 'category': self.obj_dict[self.im_ids[_im_ii]][_obj_ii], 128 | 'im_size': (_img.shape[0], _img.shape[1])} 129 | 130 | if self.transform is not None: 131 | sample = self.transform(sample) 132 | 133 | return sample 134 | 135 | def __len__(self): 136 | return len(self.obj_list) 137 | 138 | def _check_integrity(self): 139 | _fpath = os.path.join(self.root, self.FILE) 140 | if not os.path.isfile(_fpath): 141 | print("{} does not exist".format(_fpath)) 142 | return False 143 | _md5c = hashlib.md5(open(_fpath, 'rb').read()).hexdigest() 144 | if _md5c != self.MD5: 145 | print(" MD5({}) did not match MD5({}) expected for {}".format( 146 
| _md5c, self.MD5, _fpath)) 147 | return False 148 | return True 149 | 150 | def _check_preprocess(self): 151 | _obj_list_file = self.obj_list_file 152 | if not os.path.isfile(_obj_list_file): 153 | return False 154 | else: 155 | self.obj_dict = json.load(open(_obj_list_file, 'r')) 156 | 157 | return list(np.sort([str(x) for x in self.obj_dict.keys()])) == list(np.sort(self.im_ids)) 158 | 159 | def _preprocess(self): 160 | self.obj_dict = {} 161 | obj_counter = 0 162 | for ii in range(len(self.im_ids)): 163 | # Read object masks and get number of objects 164 | _mask = np.array(Image.open(self.masks[ii])) 165 | _mask_ids = np.unique(_mask) 166 | if _mask_ids[-1] == 255: 167 | n_obj = _mask_ids[-2] 168 | else: 169 | n_obj = _mask_ids[-1] 170 | 171 | # Get the categories from these objects 172 | _cats = np.array(Image.open(self.categories[ii])) 173 | _cat_ids = [] 174 | for jj in range(n_obj): 175 | tmp = np.where(_mask == jj + 1) 176 | obj_area = len(tmp[0]) 177 | if obj_area > self.area_thres: 178 | _cat_ids.append(int(_cats[tmp[0][0], tmp[1][0]])) 179 | else: 180 | _cat_ids.append(-1) 181 | obj_counter += 1 182 | 183 | self.obj_dict[self.im_ids[ii]] = _cat_ids 184 | 185 | with open(self.obj_list_file, 'w') as outfile: 186 | outfile.write('{{\n\t"{:s}": {:s}'.format(self.im_ids[0], json.dumps(self.obj_dict[self.im_ids[0]]))) 187 | for ii in range(1, len(self.im_ids)): 188 | outfile.write(',\n\t"{:s}": {:s}'.format(self.im_ids[ii], json.dumps(self.obj_dict[self.im_ids[ii]]))) 189 | outfile.write('\n}\n') 190 | 191 | print('Preprocessing finished') 192 | 193 | def _download(self): 194 | _fpath = os.path.join(self.root, self.FILE) 195 | 196 | try: 197 | os.makedirs(self.root) 198 | except OSError as e: 199 | if e.errno == errno.EEXIST: 200 | pass 201 | else: 202 | raise 203 | 204 | if self._check_integrity(): 205 | print('Files already downloaded and verified') 206 | return 207 | else: 208 | print('Downloading ' + self.URL + ' to ' + _fpath) 209 | 210 | def _progress(count, block_size, total_size): 211 | sys.stdout.write('\r>> %s %.1f%%' % 212 | (_fpath, float(count * block_size) / 213 | float(total_size) * 100.0)) 214 | sys.stdout.flush() 215 | 216 | urllib.request.urlretrieve(self.URL, _fpath, _progress) 217 | 218 | # extract file 219 | cwd = os.getcwd() 220 | print('Extracting tar file') 221 | tar = tarfile.open(_fpath) 222 | os.chdir(self.root) 223 | tar.extractall() 224 | tar.close() 225 | os.chdir(cwd) 226 | print('Done!') 227 | 228 | def _make_img_gt_point_pair(self, index): 229 | _im_ii = self.obj_list[index][0] 230 | _obj_ii = self.obj_list[index][1] 231 | 232 | # Read Image 233 | _img = np.array(Image.open(self.images[_im_ii]).convert('RGB')).astype(np.float32) 234 | 235 | # Read Target object 236 | _tmp = (np.array(Image.open(self.masks[_im_ii]))).astype(np.float32) 237 | _void_pixels = (_tmp == 255) 238 | _tmp[_void_pixels] = 0 239 | 240 | _other_same_class = np.zeros(_tmp.shape) 241 | _other_classes = np.zeros(_tmp.shape) 242 | 243 | if self.default: 244 | _target = _tmp 245 | _background = np.logical_and(_tmp == 0, ~_void_pixels) 246 | else: 247 | _target = (_tmp == (_obj_ii + 1)).astype(np.float32) 248 | _background = np.logical_and(_tmp == 0, ~_void_pixels) 249 | obj_cat = self.obj_dict[self.im_ids[_im_ii]][_obj_ii] 250 | for ii in range(1, np.max(_tmp).astype(np.int)+1): 251 | ii_cat = self.obj_dict[self.im_ids[_im_ii]][ii-1] 252 | if obj_cat == ii_cat and ii != _obj_ii+1: 253 | _other_same_class = np.logical_or(_other_same_class, _tmp == ii) 254 | elif ii != _obj_ii+1: 255 | 
_other_classes = np.logical_or(_other_classes, _tmp == ii) 256 | 257 | return _img, _target, _void_pixels.astype(np.float32), \ 258 | _other_classes.astype(np.float32), _other_same_class.astype(np.float32), \ 259 | _background.astype(np.float32) 260 | 261 | def __str__(self): 262 | return 'VOC2012(split=' + str(self.split) + ',area_thres=' + str(self.area_thres) + ')' 263 | 264 | 265 | if __name__ == '__main__': 266 | import matplotlib.pyplot as plt 267 | import dataloaders.helpers as helpers 268 | import torch 269 | import dataloaders.custom_transforms as tr 270 | from torchvision import transforms 271 | 272 | transform = transforms.Compose([tr.ToTensor()]) 273 | 274 | dataset = VOCSegmentation(split=['train', 'val'], transform=transform, retname=True) 275 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) 276 | 277 | for i, sample in enumerate(dataloader): 278 | plt.figure() 279 | overlay = helpers.overlay_mask(helpers.tens2image(sample["image"]) / 255., 280 | np.squeeze(helpers.tens2image(sample["gt"]))) 281 | plt.imshow(overlay) 282 | plt.title(dataset.category_names[sample["meta"]["category"][0]]) 283 | if i == 3: 284 | break 285 | 286 | plt.show(block=True) 287 | -------------------------------------------------------------------------------- /dataloaders/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch, cv2 4 | import random 5 | import numpy as np 6 | 7 | 8 | def tens2image(im): 9 | if im.size()[0] == 1: 10 | tmp = np.squeeze(im.numpy(), axis=0) 11 | else: 12 | tmp = im.numpy() 13 | if tmp.ndim == 2: 14 | return tmp 15 | else: 16 | return tmp.transpose((1, 2, 0)) 17 | 18 | 19 | def crop2fullmask(crop_mask, bbox, im=None, im_size=None, zero_pad=False, relax=0, mask_relax=True, 20 | interpolation=cv2.INTER_CUBIC, scikit=False): 21 | if scikit: 22 | from skimage.transform import resize as sk_resize 23 | assert(not(im is None and im_size is None)), 'You have to provide an image or the image size' 24 | if im is None: 25 | im_si = im_size 26 | else: 27 | im_si = im.shape 28 | # Borers of image 29 | bounds = (0, 0, im_si[1] - 1, im_si[0] - 1) 30 | 31 | # Valid bounding box locations as (x_min, y_min, x_max, y_max) 32 | bbox_valid = (max(bbox[0], bounds[0]), 33 | max(bbox[1], bounds[1]), 34 | min(bbox[2], bounds[2]), 35 | min(bbox[3], bounds[3])) 36 | 37 | # Bounding box of initial mask 38 | bbox_init = (bbox[0] + relax, 39 | bbox[1] + relax, 40 | bbox[2] - relax, 41 | bbox[3] - relax) 42 | 43 | if zero_pad: 44 | # Offsets for x and y 45 | offsets = (-bbox[0], -bbox[1]) 46 | else: 47 | assert((bbox == bbox_valid).all()) 48 | offsets = (-bbox_valid[0], -bbox_valid[1]) 49 | 50 | # Simple per element addition in the tuple 51 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets))) 52 | 53 | if scikit: 54 | crop_mask = sk_resize(crop_mask, (bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), order=0, mode='constant').astype(crop_mask.dtype) 55 | else: 56 | crop_mask = cv2.resize(crop_mask, (bbox[2] - bbox[0] + 1, bbox[3] - bbox[1] + 1), interpolation=interpolation) 57 | result_ = np.zeros(im_si) 58 | result_[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1] = \ 59 | crop_mask[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1] 60 | 61 | result = np.zeros(im_si) 62 | if mask_relax: 63 | result[bbox_init[1]:bbox_init[3]+1, bbox_init[0]:bbox_init[2]+1] = \ 64 | result_[bbox_init[1]:bbox_init[3]+1, bbox_init[0]:bbox_init[2]+1] 65 | else: 66 | result = result_ 67 | 68 | 
return result
69 | 
70 | 
71 | def overlay_mask(im, ma, colors=None, alpha=0.5):
72 |     assert np.max(im) <= 1.0
73 |     if colors is None:
74 |         colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy'))/255.
75 |     else:
76 |         colors = np.append([[0.,0.,0.]], colors, axis=0)
77 | 
78 |     if ma.ndim == 3:
79 |         assert len(colors) >= ma.shape[0], 'Not enough colors'
80 |     ma = ma.astype(bool)
81 |     im = im.astype(np.float32)
82 | 
83 |     if ma.ndim == 2:
84 |         fg = im * alpha + np.ones(im.shape) * (1 - alpha) * colors[1, :3]  # np.array([0,0,255])/255.0
85 |     else:
86 |         fg = []
87 |         for n in range(ma.shape[0]):
88 |             fg.append(im * alpha + np.ones(im.shape) * (1 - alpha) * colors[1+n, :3])
89 |     # Whiten background
90 |     bg = im.copy()
91 |     if ma.ndim == 2:
92 |         bg[ma == 0] = im[ma == 0]
93 |         bg[ma == 1] = fg[ma == 1]
94 |         total_ma = ma
95 |     else:
96 |         total_ma = np.zeros([ma.shape[1], ma.shape[2]])
97 |         for n in range(ma.shape[0]):
98 |             tmp_ma = ma[n, :, :]
99 |             total_ma = np.logical_or(tmp_ma, total_ma)
100 |             tmp_fg = fg[n]
101 |             bg[tmp_ma == 1] = tmp_fg[tmp_ma == 1]
102 |         bg[total_ma == 0] = im[total_ma == 0]
103 | 
104 |     # [-2:] is a trick to be compatible with both opencv 2 and 3
105 |     contours = cv2.findContours(total_ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
106 |     cv2.drawContours(bg, contours[0], -1, (0.0, 0.0, 0.0), 1)
107 | 
108 |     return bg
109 | 
110 | def overlay_masks(im, masks, alpha=0.5):
111 |     colors = np.load(os.path.join(os.path.dirname(__file__), 'pascal_map.npy'))/255.
112 | 
113 |     if isinstance(masks, np.ndarray):
114 |         masks = [masks]
115 | 
116 |     assert len(colors) >= len(masks), 'Not enough colors'
117 | 
118 |     ov = im.copy()
119 |     im = im.astype(np.float32)
120 |     total_ma = np.zeros([im.shape[0], im.shape[1]])
121 |     i = 1
122 |     for ma in masks:
123 |         ma = ma.astype(bool)
124 |         fg = im * alpha + np.ones(im.shape) * (1 - alpha) * colors[i, :3]  # np.array([0,0,255])/255.0
125 |         i = i + 1
126 |         ov[ma == 1] = fg[ma == 1]
127 |         total_ma += ma
128 | 
129 |         # [-2:] is a trick to be compatible with both opencv 2 and 3
130 |         contours = cv2.findContours(ma.copy().astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
131 |         cv2.drawContours(ov, contours[0], -1, (0.0, 0.0, 0.0), 1)
132 |     ov[total_ma == 0] = im[total_ma == 0]
133 | 
134 |     return ov
135 | 
136 | 
137 | def extreme_points(mask, pert):
138 |     def find_point(id_x, id_y, ids):
139 |         sel_id = ids[0][random.randint(0, len(ids[0]) - 1)]
140 |         return [id_x[sel_id], id_y[sel_id]]
141 | 
142 |     # List of coordinates of the mask
143 |     inds_y, inds_x = np.where(mask > 0.5)
144 | 
145 |     # Find extreme points
146 |     return np.array([find_point(inds_x, inds_y, np.where(inds_x <= np.min(inds_x)+pert)),  # left
147 |                      find_point(inds_x, inds_y, np.where(inds_x >= np.max(inds_x)-pert)),  # right
148 |                      find_point(inds_x, inds_y, np.where(inds_y <= np.min(inds_y)+pert)),  # top
149 |                      find_point(inds_x, inds_y, np.where(inds_y >= np.max(inds_y)-pert))   # bottom
150 |                      ])
151 | 
152 | 
153 | def get_bbox(mask, points=None, pad=0, zero_pad=False):
154 |     if points is not None:
155 |         inds = np.flip(points.transpose(), axis=0)
156 |     else:
157 |         inds = np.where(mask > 0)
158 | 
159 |     if inds[0].shape[0] == 0:
160 |         return None
161 | 
162 |     if zero_pad:
163 |         x_min_bound = -np.inf
164 |         y_min_bound = -np.inf
165 |         x_max_bound = np.inf
166 |         y_max_bound = np.inf
167 |     else:
168 |         x_min_bound = 0
169 |         y_min_bound = 0
170 |         x_max_bound = mask.shape[1] - 1
171 |         y_max_bound = mask.shape[0] - 1
172 | 
173 |     x_min = max(inds[1].min() - pad,
x_min_bound) 174 | y_min = max(inds[0].min() - pad, y_min_bound) 175 | x_max = min(inds[1].max() + pad, x_max_bound) 176 | y_max = min(inds[0].max() + pad, y_max_bound) 177 | 178 | return x_min, y_min, x_max, y_max 179 | 180 | 181 | def crop_from_bbox(img, bbox, zero_pad=False): 182 | # Borders of image 183 | bounds = (0, 0, img.shape[1] - 1, img.shape[0] - 1) 184 | 185 | # Valid bounding box locations as (x_min, y_min, x_max, y_max) 186 | bbox_valid = (max(bbox[0], bounds[0]), 187 | max(bbox[1], bounds[1]), 188 | min(bbox[2], bounds[2]), 189 | min(bbox[3], bounds[3])) 190 | 191 | if zero_pad: 192 | # Initialize crop size (first 2 dimensions) 193 | crop = np.zeros((bbox[3] - bbox[1] + 1, bbox[2] - bbox[0] + 1), dtype=img.dtype) 194 | 195 | # Offsets for x and y 196 | offsets = (-bbox[0], -bbox[1]) 197 | 198 | else: 199 | assert(bbox == bbox_valid) 200 | crop = np.zeros((bbox_valid[3] - bbox_valid[1] + 1, bbox_valid[2] - bbox_valid[0] + 1), dtype=img.dtype) 201 | offsets = (-bbox_valid[0], -bbox_valid[1]) 202 | 203 | # Simple per element addition in the tuple 204 | inds = tuple(map(sum, zip(bbox_valid, offsets + offsets))) 205 | 206 | img = np.squeeze(img) 207 | if img.ndim == 2: 208 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1] = \ 209 | img[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1] 210 | else: 211 | crop = np.tile(crop[:, :, np.newaxis], [1, 1, 3]) # Add 3 RGB Channels 212 | crop[inds[1]:inds[3] + 1, inds[0]:inds[2] + 1, :] = \ 213 | img[bbox_valid[1]:bbox_valid[3] + 1, bbox_valid[0]:bbox_valid[2] + 1, :] 214 | 215 | return crop 216 | 217 | 218 | def fixed_resize(sample, resolution, flagval=None): 219 | 220 | if flagval is None: 221 | if ((sample == 0) | (sample == 1)).all(): 222 | flagval = cv2.INTER_NEAREST 223 | else: 224 | flagval = cv2.INTER_CUBIC 225 | 226 | if isinstance(resolution, int): 227 | tmp = [resolution, resolution] 228 | tmp[np.argmax(sample.shape[:2])] = int(round(float(resolution)/np.min(sample.shape[:2])*np.max(sample.shape[:2]))) 229 | resolution = tuple(tmp) 230 | 231 | if sample.ndim == 2 or (sample.ndim == 3 and sample.shape[2] == 3): 232 | sample = cv2.resize(sample, resolution[::-1], interpolation=flagval) 233 | else: 234 | tmp = sample 235 | sample = np.zeros(np.append(resolution, tmp.shape[2]), dtype=np.float32) 236 | for ii in range(sample.shape[2]): 237 | sample[:, :, ii] = cv2.resize(tmp[:, :, ii], resolution[::-1], interpolation=flagval) 238 | return sample 239 | 240 | 241 | def crop_from_mask(img, mask, relax=0, zero_pad=False): 242 | if mask.shape[:2] != img.shape[:2]: 243 | mask = cv2.resize(mask, dsize=tuple(reversed(img.shape[:2])), interpolation=cv2.INTER_NEAREST) 244 | 245 | assert(mask.shape[:2] == img.shape[:2]) 246 | 247 | bbox = get_bbox(mask, pad=relax, zero_pad=zero_pad) 248 | 249 | if bbox is None: 250 | return None 251 | 252 | crop = crop_from_bbox(img, bbox, zero_pad) 253 | 254 | return crop 255 | 256 | 257 | def make_gaussian(size, sigma=10, center=None, d_type=np.float64): 258 | """ Make a square gaussian kernel. 259 | size: is the dimensions of the output gaussian 260 | sigma: is full-width-half-maximum, which 261 | can be thought of as an effective radius. 
262 | """ 263 | 264 | x = np.arange(0, size[1], 1, float) 265 | y = np.arange(0, size[0], 1, float) 266 | y = y[:, np.newaxis] 267 | 268 | if center is None: 269 | x0 = y0 = size[0] // 2 270 | else: 271 | x0 = center[0] 272 | y0 = center[1] 273 | 274 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2).astype(d_type) 275 | 276 | 277 | def make_gt(img, labels, sigma=10, one_mask_per_point=False): 278 | """ Make the ground-truth for landmark. 279 | img: the original color image 280 | labels: label with the Gaussian center(s) [[x0, y0],[x1, y1],...] 281 | sigma: sigma of the Gaussian. 282 | one_mask_per_point: masks for each point in different channels? 283 | """ 284 | h, w = img.shape[:2] 285 | if labels is None: 286 | gt = make_gaussian((h, w), center=(h//2, w//2), sigma=sigma) 287 | else: 288 | labels = np.array(labels) 289 | if labels.ndim == 1: 290 | labels = labels[np.newaxis] 291 | if one_mask_per_point: 292 | gt = np.zeros(shape=(h, w, labels.shape[0])) 293 | for ii in range(labels.shape[0]): 294 | gt[:, :, ii] = make_gaussian((h, w), center=labels[ii, :], sigma=sigma) 295 | else: 296 | gt = np.zeros(shape=(h, w), dtype=np.float64) 297 | for ii in range(labels.shape[0]): 298 | gt = np.maximum(gt, make_gaussian((h, w), center=labels[ii, :], sigma=sigma)) 299 | 300 | gt = gt.astype(dtype=img.dtype) 301 | 302 | return gt 303 | 304 | 305 | def cstm_normalize(im, max_value): 306 | """ 307 | Normalize image to range 0 - max_value 308 | """ 309 | imn = max_value*(im - im.min()) / max((im.max() - im.min()), 1e-8) 310 | return imn 311 | 312 | 313 | def generate_param_report(logfile, param): 314 | log_file = open(logfile, 'w') 315 | for key, val in param.items(): 316 | log_file.write(key+':'+str(val)+'\n') 317 | log_file.close() 318 | -------------------------------------------------------------------------------- /networks/deeplab_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.models.resnet as resnet 3 | import torch 4 | import numpy as np 5 | from copy import deepcopy 6 | import os 7 | from torch.nn import functional as F 8 | from mypath import Path 9 | 10 | affine_par = True 11 | 12 | 13 | def outS(i): 14 | i = int(i) 15 | i = (i+1)/2 16 | i = int(np.ceil((i+1)/2.0)) 17 | i = (i+1)/2 18 | return i 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | expansion = 4 23 | 24 | def __init__(self, inplanes, planes, stride=1, dilation_=1, downsample=None): 25 | super(Bottleneck, self).__init__() 26 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 27 | self.bn1 = nn.BatchNorm2d(planes, affine=affine_par) 28 | for i in self.bn1.parameters(): 29 | i.requires_grad = False 30 | padding = 1 31 | if dilation_ == 2: 32 | padding = 2 33 | elif dilation_ == 4: 34 | padding = 4 35 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 36 | padding=padding, bias=False, dilation=dilation_) 37 | self.bn2 = nn.BatchNorm2d(planes, affine=affine_par) 38 | for i in self.bn2.parameters(): 39 | i.requires_grad = False 40 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 41 | self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par) 42 | for i in self.bn3.parameters(): 43 | i.requires_grad = False 44 | self.relu = nn.ReLU(inplace=True) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = 
self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv3(out) 60 | out = self.bn3(out) 61 | 62 | if self.downsample is not None: 63 | residual = self.downsample(x) 64 | 65 | out += residual 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class ClassifierModule(nn.Module): 72 | 73 | def __init__(self, dilation_series, padding_series, n_classes): 74 | super(ClassifierModule, self).__init__() 75 | self.conv2d_list = nn.ModuleList() 76 | for dilation, padding in zip(dilation_series, padding_series): 77 | self.conv2d_list.append(nn.Conv2d(2048, n_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True)) 78 | 79 | for m in self.conv2d_list: 80 | m.weight.data.normal_(0, 0.01) 81 | 82 | def forward(self, x): 83 | out = self.conv2d_list[0](x) 84 | for i in range(len(self.conv2d_list)-1): 85 | out += self.conv2d_list[i+1](x) 86 | return out 87 | 88 | 89 | class PSPModule(nn.Module): 90 | """ 91 | Pyramid Scene Parsing module 92 | """ 93 | def __init__(self, in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=1): 94 | super(PSPModule, self).__init__() 95 | self.stages = [] 96 | self.stages = nn.ModuleList([self._make_stage_1(in_features, size) for size in sizes]) 97 | self.bottleneck = self._make_stage_2(in_features * (len(sizes)//4 + 1), out_features) 98 | self.relu = nn.ReLU() 99 | self.final = nn.Conv2d(out_features, n_classes, kernel_size=1) 100 | 101 | def _make_stage_1(self, in_features, size): 102 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 103 | conv = nn.Conv2d(in_features, in_features//4, kernel_size=1, bias=False) 104 | bn = nn.BatchNorm2d(in_features//4, affine=affine_par) 105 | relu = nn.ReLU(inplace=True) 106 | 107 | return nn.Sequential(prior, conv, bn, relu) 108 | 109 | def _make_stage_2(self, in_features, out_features): 110 | conv = nn.Conv2d(in_features, out_features, kernel_size=1, bias=False) 111 | bn = nn.BatchNorm2d(out_features, affine=affine_par) 112 | relu = nn.ReLU(inplace=True) 113 | 114 | return nn.Sequential(conv, bn, relu) 115 | 116 | def forward(self, feats): 117 | h, w = feats.size(2), feats.size(3) 118 | priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] 119 | priors.append(feats) 120 | bottle = self.relu(self.bottleneck(torch.cat(priors, 1))) 121 | out = self.final(bottle) 122 | 123 | return out 124 | 125 | 126 | class ResNet(nn.Module): 127 | def __init__(self, block, layers, n_classes, nInputChannels=3, classifier="atrous", 128 | dilations=(2, 4), strides=(2, 2, 2, 1, 1), _print=False): 129 | if _print: 130 | print("Constructing ResNet model...") 131 | print("Dilations: {}".format(dilations)) 132 | print("Number of classes: {}".format(n_classes)) 133 | print("Number of Input Channels: {}".format(nInputChannels)) 134 | self.inplanes = 64 135 | self.classifier = classifier 136 | super(ResNet, self).__init__() 137 | self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=strides[0], padding=3, 138 | bias=False) 139 | self.bn1 = nn.BatchNorm2d(64, affine=affine_par) 140 | for i in self.bn1.parameters(): 141 | i.requires_grad = False 142 | self.relu = nn.ReLU(inplace=True) 143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=strides[1], padding=1, ceil_mode=False) 144 | self.layer1 = self._make_layer(block, 64, layers[0]) 145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[2]) 146 | self.layer3 = self._make_layer(block, 256, layers[2], 
stride=strides[3], dilation__=dilations[0]) 147 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[4], dilation__=dilations[1]) 148 | 149 | if classifier == "atrous": 150 | if _print: 151 | print('Initializing classifier: A-trous pyramid') 152 | self.layer5 = self._make_pred_layer(ClassifierModule, [6, 12, 18, 24], [6, 12, 18, 24], n_classes=n_classes) 153 | elif classifier == "psp": 154 | if _print: 155 | print('Initializing classifier: PSP') 156 | self.layer5 = PSPModule(in_features=2048, out_features=512, sizes=(1, 2, 3, 6), n_classes=n_classes) 157 | else: 158 | self.layer5 = None 159 | for m in self.modules(): 160 | if isinstance(m, nn.Conv2d): 161 | m.weight.data.normal_(0, 0.01) 162 | elif isinstance(m, nn.BatchNorm2d): 163 | m.weight.data.fill_(1) 164 | m.bias.data.zero_() 165 | 166 | def _make_layer(self, block, planes, blocks, stride=1, dilation__=1): 167 | downsample = None 168 | if stride != 1 or self.inplanes != planes * block.expansion or dilation__ == 2 or dilation__ == 4: 169 | downsample = nn.Sequential( 170 | nn.Conv2d(self.inplanes, planes * block.expansion, 171 | kernel_size=1, stride=stride, bias=False), 172 | nn.BatchNorm2d(planes * block.expansion, affine=affine_par), 173 | ) 174 | for i in downsample._modules['1'].parameters(): 175 | i.requires_grad = False 176 | layers = [block(self.inplanes, planes, stride, dilation_=dilation__, downsample=downsample)] 177 | self.inplanes = planes * block.expansion 178 | for i in range(1, blocks): 179 | layers.append(block(self.inplanes, planes, dilation_=dilation__)) 180 | 181 | return nn.Sequential(*layers) 182 | 183 | def _make_pred_layer(self, block, dilation_series, padding_series, n_classes): 184 | return block(dilation_series, padding_series, n_classes) 185 | 186 | def forward(self, x, bbox=None): 187 | x = self.conv1(x) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | x = self.maxpool(x) 191 | x = self.layer1(x) 192 | x = self.layer2(x) 193 | x = self.layer3(x) 194 | x = self.layer4(x) 195 | if self.layer5 is not None: 196 | x = self.layer5(x) 197 | return x 198 | 199 | def load_pretrained_ms(self, base_network, nInputChannels=3): 200 | flag = 0 201 | for module, module_ori in zip(self.modules(), base_network.Scale.modules()): 202 | if isinstance(module, nn.Conv2d) and isinstance(module_ori, nn.Conv2d): 203 | if not flag and nInputChannels != 3: 204 | module.weight[:, :3, :, :].data = deepcopy(module_ori.weight.data) 205 | module.bias = deepcopy(module_ori.bias) 206 | for i in range(3, int(module.weight.data.shape[1])): 207 | module.weight[:, i, :, :].data = deepcopy(module_ori.weight[:, -1, :, :][:, np.newaxis, :, :].data) 208 | flag = 1 209 | elif module.weight.data.shape == module_ori.weight.data.shape: 210 | module.weight.data = deepcopy(module_ori.weight.data) 211 | module.bias = deepcopy(module_ori.bias) 212 | else: 213 | print('Skipping Conv layer with size: {} and target size: {}' 214 | .format(module.weight.data.shape, module_ori.weight.data.shape)) 215 | elif isinstance(module, nn.BatchNorm2d) and isinstance(module_ori, nn.BatchNorm2d) \ 216 | and module.weight.data.shape == module_ori.weight.data.shape: 217 | module.weight.data = deepcopy(module_ori.weight.data) 218 | module.bias.data = deepcopy(module_ori.bias.data) 219 | 220 | 221 | def resnet101(n_classes, pretrained=False, nInputChannels=3, classifier="atrous", 222 | dilations=(2, 4), strides=(2, 2, 2, 1, 1)): 223 | """Constructs a ResNet-101 model. 
224 | """ 225 | model = ResNet(Bottleneck, [3, 4, 23, 3], n_classes, nInputChannels=nInputChannels, 226 | classifier=classifier, dilations=dilations, strides=strides, _print=True) 227 | if pretrained: 228 | model_full = Res_Deeplab(n_classes, pretrained=pretrained) 229 | model.load_pretrained_ms(model_full, nInputChannels=nInputChannels) 230 | return model 231 | 232 | 233 | class MS_Deeplab(nn.Module): 234 | def __init__(self, block, NoLabels, nInputChannels=3): 235 | super(MS_Deeplab, self).__init__() 236 | self.Scale = ResNet(block, [3, 4, 23, 3], NoLabels, nInputChannels=nInputChannels) 237 | 238 | def forward(self, x): 239 | input_size = x.size()[2] 240 | self.interp1 = nn.Upsample(size=(int(input_size*0.75)+1, int(input_size*0.75)+1), mode='bilinear', align_corners=True) 241 | self.interp2 = nn.Upsample(size=(int(input_size*0.5)+1, int(input_size*0.5)+1), mode='bilinear', align_corners=True) 242 | self.interp3 = nn.Upsample(size=(outS(input_size), outS(input_size)), mode='bilinear', align_corners=True) 243 | out = [] 244 | x2 = self.interp1(x) 245 | x3 = self.interp2(x) 246 | out.append(self.Scale(x)) # for original scale 247 | out.append(self.interp3(self.Scale(x2))) # for 0.75x scale 248 | out.append(self.Scale(x3)) # for 0.5x scale 249 | 250 | x2Out_interp = out[1] 251 | x3Out_interp = self.interp3(out[2]) 252 | temp1 = torch.max(out[0], x2Out_interp) 253 | out.append(torch.max(temp1, x3Out_interp)) 254 | return out[-1] 255 | 256 | 257 | def Res_Deeplab(n_classes=21, pretrained=False): 258 | model = MS_Deeplab(Bottleneck, n_classes) 259 | if pretrained: 260 | pth_model = 'MS_DeepLab_resnet_trained_VOC.pth' 261 | saved_state_dict = torch.load(os.path.join(Path.models_dir(), pth_model), 262 | map_location=lambda storage, loc: storage) 263 | if n_classes != 21: 264 | for i in saved_state_dict: 265 | i_parts = i.split('.') 266 | if i_parts[1] == 'layer5': 267 | saved_state_dict[i] = model.state_dict()[i] 268 | model.load_state_dict(saved_state_dict) 269 | return model 270 | 271 | 272 | def get_lr_params(model): 273 | """ 274 | This generator returns all the parameters of the net except for 275 | the last classification layer. Note that for each batchnorm layer, 276 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 277 | any batchnorm parameter 278 | """ 279 | b = [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3, model.layer4, model.layer5] 280 | for i in range(len(b)): 281 | for k in b[i].parameters(): 282 | if k.requires_grad: 283 | yield k 284 | 285 | 286 | def get_1x_lr_params(model): 287 | """ 288 | This generator returns all the parameters of the net except for 289 | the last classification layer. 
Note that for each batchnorm layer,
290 |     requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
291 |     any batchnorm parameters.
292 |     """
293 |     b = [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3, model.layer4]
294 |     for i in range(len(b)):
295 |         for k in b[i].parameters():
296 |             if k.requires_grad:
297 |                 yield k
298 | 
299 | 
300 | def get_10x_lr_params(model):
301 |     """
302 |     This generator returns all the parameters for the last layer of the net,
303 |     which does the classification of pixels into classes
304 |     """
305 |     b = [model.layer5]
306 |     for j in range(len(b)):
307 |         for k in b[j].parameters():
308 |             if k.requires_grad:
309 |                 yield k
310 | 
311 | 
312 | def lr_poly(base_lr, iter_, max_iter=100, power=0.9):
313 |     return base_lr*((1-float(iter_)/max_iter)**power)
314 | 
--------------------------------------------------------------------------------
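
The helpers in dataloaders/helpers.py carry the extreme-point idea of this project: extreme_points samples the four outermost pixels of a mask (left/right/top/bottom, each jittered by up to pert pixels) and make_gt turns those points into a Gaussian heatmap. A minimal sketch of that flow, assuming the repository root is on PYTHONPATH and using a synthetic mask that is not part of the repository:

import numpy as np
import dataloaders.helpers as helpers

# Synthetic binary mask standing in for one object instance.
mask = np.zeros((256, 256), dtype=np.float32)
mask[64:192, 80:200] = 1.0

# Four extreme points (left, right, top, bottom), each jittered by up to 5 pixels.
points = helpers.extreme_points(mask, pert=5)           # shape (4, 2), as (x, y)

# Gaussian heatmap centred on the extreme points; in a DEXTR-style setup this is
# the extra input channel stacked next to the RGB crop.
heatmap = helpers.make_gt(np.zeros((256, 256, 3)), points, sigma=10)

print(points.shape, heatmap.shape)                      # (4, 2) (256, 256)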
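
crop_from_mask, fixed_resize and crop2fullmask are meant to be used together: crop the object with some context (relax), resize the crop to the network resolution, and paste a crop-space prediction back into full-image coordinates. A sketch of that round trip with illustrative sizes and a dummy "prediction" (the actual inference settings may differ):

import numpy as np
import dataloaders.helpers as helpers

image = np.random.rand(300, 400, 3).astype(np.float32)
mask = np.zeros((300, 400), dtype=np.float32)
mask[100:220, 150:320] = 1.0

relax = 50  # extra context around the tight bounding box

# Crop image and mask around the object, then resize to the network input size.
crop_img = helpers.crop_from_mask(image, mask, relax=relax, zero_pad=True)
crop_gt = helpers.crop_from_mask(mask, mask, relax=relax, zero_pad=True)
net_in = helpers.fixed_resize(crop_img, (512, 512))

# Stand-in for the network output at crop resolution.
pred = helpers.fixed_resize(crop_gt, (512, 512))

# Map the crop-space prediction back to full-image coordinates.
bbox = helpers.get_bbox(mask, pad=relax, zero_pad=True)
full_mask = helpers.crop2fullmask(pred, bbox, im_size=image.shape[:2],
                                  zero_pad=True, relax=relax)
print(net_in.shape, full_mask.shape)    # (512, 512, 3) (300, 400)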
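
The generators at the end of networks/deeplab_resnet.py split the model into a backbone group (conv1 through layer4) and a classifier group (layer5), and lr_poly implements a polynomial learning-rate decay. A sketch of how these can be wired into an optimizer; the hyperparameter values below are illustrative assumptions, not taken from the training script:

import torch.optim as optim
from networks.deeplab_resnet import resnet101, get_1x_lr_params, get_10x_lr_params, lr_poly

base_lr = 1e-7
max_iter = 100

# Example configuration: 4 input channels (RGB plus one extra heatmap channel), PSP head.
net = resnet101(n_classes=1, pretrained=False, nInputChannels=4, classifier='psp')

optimizer = optim.SGD([
    {'params': get_1x_lr_params(net), 'lr': base_lr},        # conv1 .. layer4
    {'params': get_10x_lr_params(net), 'lr': 10 * base_lr},  # layer5 (classifier)
], lr=base_lr, momentum=0.9, weight_decay=5e-4)

for it in range(max_iter):
    # Polynomial decay of both groups, keeping the 1x / 10x ratio.
    optimizer.param_groups[0]['lr'] = lr_poly(base_lr, it, max_iter=max_iter, power=0.9)
    optimizer.param_groups[1]['lr'] = 10 * lr_poly(base_lr, it, max_iter=max_iter, power=0.9)
    # ... forward pass, loss, backward() and optimizer.step() go here ...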