├── .gitignore ├── LICENSE ├── README.md ├── base_trainer.py ├── configs ├── voc_resnet101.yaml ├── voc_resnet38.yaml ├── voc_resnet50.yaml └── voc_vgg16.yaml ├── core ├── __init__.py └── config.py ├── data ├── test.txt ├── train_augvoc.txt ├── train_voc.txt └── val_voc.txt ├── datasets ├── __init__.py ├── pascal_voc.py ├── pascal_voc_ms.py ├── transforms.py └── utils.py ├── eval_seg.py ├── figures ├── results.gif └── results.png ├── fonts └── UbuntuMono-R.ttf ├── infer_val.py ├── launch ├── eval_seg.sh ├── infer_val.sh ├── run_bsl_resnet101.sh ├── run_bsl_resnet38.sh ├── run_bsl_resnet50.sh ├── run_bsl_vgg16.sh ├── run_voc_resnet101.sh ├── run_voc_resnet38.sh ├── run_voc_resnet50.sh └── run_voc_vgg16.sh ├── losses └── __init__.py ├── models ├── __init__.py ├── backbones │ ├── __init__.py │ ├── base_net.py │ ├── resnet38d.py │ ├── resnets.py │ └── vgg16d.py ├── mods │ ├── __init__.py │ ├── aspp.py │ ├── gci.py │ ├── pamr.py │ └── sg.py └── stage_net.py ├── opts.py ├── requirements.txt ├── tools └── convert_sbd.py ├── train.py └── utils ├── __init__.py ├── checkpoints.py ├── collections.py ├── dcrf.py ├── inference_tools.py ├── metrics.py ├── pallete.py ├── stat_manager.py └── timer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | logs/ 4 | data/ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, and 10 | distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by the 13 | copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all other 16 | entities that control, are controlled by, or are under common control with 17 | that entity. For the purposes of this definition, "control" means (i) the 18 | power, direct or indirect, to cause the direction or management of such 19 | entity, whether by contract or otherwise, or (ii) ownership of 20 | fifty percent (50%) or more of the outstanding shares, or (iii) beneficial 21 | ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity exercising 24 | permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation source, 28 | and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical transformation 31 | or translation of a Source form, including but not limited to compiled 32 | object code, generated documentation, and conversions to 33 | other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or Object 36 | form, made available under the License, as indicated by a copyright notice 37 | that is included in or attached to the work (an example is provided in the 38 | Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object form, 41 | that is based on (or derived from) the Work and for which the editorial 42 | revisions, annotations, elaborations, or other modifications represent, 43 | as a whole, an original work of authorship. For the purposes of this 44 | License, Derivative Works shall not include works that remain separable 45 | from, or merely link (or bind by name) to the interfaces of, the Work and 46 | Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including the original 49 | version of the Work and any modifications or additions to that Work or 50 | Derivative Works thereof, that is intentionally submitted to Licensor for 51 | inclusion in the Work by the copyright owner or by an individual or 52 | Legal Entity authorized to submit on behalf of the copyright owner. 53 | For the purposes of this definition, "submitted" means any form of 54 | electronic, verbal, or written communication sent to the Licensor or its 55 | representatives, including but not limited to communication on electronic 56 | mailing lists, source code control systems, and issue tracking systems 57 | that are managed by, or on behalf of, the Licensor for the purpose of 58 | discussing and improving the Work, but excluding communication that is 59 | conspicuously marked or otherwise designated in writing by the copyright 60 | owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity on 63 | behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 67 | 68 | Subject to the terms and conditions of this License, each Contributor 69 | hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, 70 | royalty-free, irrevocable copyright license to reproduce, prepare 71 | Derivative Works of, publicly display, publicly perform, sublicense, 72 | and distribute the Work and such Derivative Works in 73 | Source or Object form. 74 | 75 | 3. Grant of Patent License. 76 | 77 | Subject to the terms and conditions of this License, each Contributor 78 | hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, 79 | royalty-free, irrevocable (except as stated in this section) patent 80 | license to make, have made, use, offer to sell, sell, import, and 81 | otherwise transfer the Work, where such license applies only to those 82 | patent claims licensable by such Contributor that are necessarily 83 | infringed by their Contribution(s) alone or by combination of their 84 | Contribution(s) with the Work to which such Contribution(s) was submitted. 85 | If You institute patent litigation against any entity (including a 86 | cross-claim or counterclaim in a lawsuit) alleging that the Work or a 87 | Contribution incorporated within the Work constitutes direct or 88 | contributory patent infringement, then any patent licenses granted to 89 | You under this License for that Work shall terminate as of the date such 90 | litigation is filed. 91 | 92 | 4. Redistribution. 93 | 94 | You may reproduce and distribute copies of the Work or Derivative Works 95 | thereof in any medium, with or without modifications, and in Source or 96 | Object form, provided that You meet the following conditions: 97 | 98 | 1. You must give any other recipients of the Work or Derivative Works a 99 | copy of this License; and 100 | 101 | 2. You must cause any modified files to carry prominent notices stating 102 | that You changed the files; and 103 | 104 | 3. You must retain, in the Source form of any Derivative Works that You 105 | distribute, all copyright, patent, trademark, and attribution notices from 106 | the Source form of the Work, excluding those notices that do not pertain 107 | to any part of the Derivative Works; and 108 | 109 | 4. If the Work includes a "NOTICE" text file as part of its distribution, 110 | then any Derivative Works that You distribute must include a readable copy 111 | of the attribution notices contained within such NOTICE file, excluding 112 | those notices that do not pertain to any part of the Derivative Works, 113 | in at least one of the following places: within a NOTICE text file 114 | distributed as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, within a 116 | display generated by the Derivative Works, if and wherever such 117 | third-party notices normally appear. The contents of the NOTICE file are 118 | for informational purposes only and do not modify the License. 119 | You may add Your own attribution notices within Derivative Works that You 120 | distribute, alongside or as an addendum to the NOTICE text from the Work, 121 | provided that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and may 125 | provide additional or different license terms and conditions for use, 126 | reproduction, or distribution of Your modifications, or for any such 127 | Derivative Works as a whole, provided Your use, reproduction, and 128 | distribution of the Work otherwise complies with the conditions 129 | stated in this License. 130 | 131 | 5. Submission of Contributions. 132 | 133 | Unless You explicitly state otherwise, any Contribution intentionally 134 | submitted for inclusion in the Work by You to the Licensor shall be under 135 | the terms and conditions of this License, without any additional 136 | terms or conditions. Notwithstanding the above, nothing herein shall 137 | supersede or modify the terms of any separate license agreement you may 138 | have executed with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. 141 | 142 | This License does not grant permission to use the trade names, trademarks, 143 | service marks, or product names of the Licensor, except as required for 144 | reasonable and customary use in describing the origin of the Work and 145 | reproducing the content of the NOTICE file. 146 | 147 | 7. Disclaimer of Warranty. 148 | 149 | Unless required by applicable law or agreed to in writing, Licensor 150 | provides the Work (and each Contributor provides its Contributions) 151 | on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 152 | either express or implied, including, without limitation, any warranties 153 | or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS 154 | FOR A PARTICULAR PURPOSE. You are solely responsible for determining the 155 | appropriateness of using or redistributing the Work and assume any risks 156 | associated with Your exercise of permissions under this License. 157 | 158 | 8. Limitation of Liability. 159 | 160 | In no event and under no legal theory, whether in tort 161 | (including negligence), contract, or otherwise, unless required by 162 | applicable law (such as deliberate and grossly negligent acts) or agreed 163 | to in writing, shall any Contributor be liable to You for damages, 164 | including any direct, indirect, special, incidental, or consequential 165 | damages of any character arising as a result of this License or out of 166 | the use or inability to use the Work (including but not limited to damages 167 | for loss of goodwill, work stoppage, computer failure or malfunction, 168 | or any and all other commercial damages or losses), even if such 169 | Contributor has been advised of the possibility of such damages. 170 | 171 | 9. Accepting Warranty or Additional Liability. 172 | 173 | While redistributing the Work or Derivative Works thereof, You may choose 174 | to offer, and charge a fee for, acceptance of support, warranty, 175 | indemnity, or other liability obligations and/or rights consistent with 176 | this License. However, in accepting such obligations, You may act only 177 | on Your own behalf and on Your sole responsibility, not on behalf of any 178 | other Contributor, and only if You agree to indemnify, defend, and hold 179 | each Contributor harmless for any liability incurred by, or claims 180 | asserted against, such Contributor by reason of your accepting any such 181 | warranty or additional liability. 182 | 183 | END OF TERMS AND CONDITIONS 184 | 185 | Copyright 2020 TU Darmstadt 186 | 187 | Author: Nikita Araslanov 188 | 189 | Licensed under the Apache License, Version 2.0 (the "License"); 190 | you may not use this file except in compliance with the License. 191 | You may obtain a copy of the License at 192 | 193 | http://www.apache.org/licenses/LICENSE-2.0 194 | 195 | Unless required by applicable law or agreed to in writing, software 196 | distributed under the License is distributed on an "AS IS" BASIS, 197 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 198 | or implied. See the License for the specific language governing 199 | permissions and limitations under the License. 200 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Single-Stage Semantic Segmentation from Image Labels 2 | 3 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 4 | [![Framework](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?&logo=PyTorch&logoColor=white)](https://pytorch.org/) 5 | 6 | This repository contains the original implementation of our paper: 7 | 8 | 9 | **Single-stage Semantic Segmentation from Image Labels**
10 | *[Nikita Araslanov](https://arnike.github.io) and [Stefan Roth](https://www.visinf.tu-darmstadt.de/team_members/sroth/sroth.en.jsp)*
11 | CVPR 2020. [[pdf](https://openaccess.thecvf.com/content_CVPR_2020/papers/Araslanov_Single-Stage_Semantic_Segmentation_From_Image_Labels_CVPR_2020_paper.pdf)] [[supp](https://openaccess.thecvf.com/content_CVPR_2020/supplemental/Araslanov_Single-Stage_Semantic_Segmentation_CVPR_2020_supplemental.pdf)] 12 | [[arXiv](https://arxiv.org/abs/2005.08104)] 13 | 14 | Contact: Nikita Araslanov 15 | 16 | 17 | | drawing
| 18 | |:---| 19 | | We attain competitive results by training a single network model
for segmentation in a self-supervised fashion using only
image-level annotations (one run of 20 epochs on Pascal VOC). | 20 | 21 | ### Setup 22 | 0. **Minimum requirements.** This project was originally developed with Python 3.6, PyTorch 1.0 and CUDA 9.0. The training requires at least two Titan X GPUs (12Gb memory each). 23 | 1. **Setup your Python environment.** Please, clone the repository and install the dependencies. We recommend using Anaconda 3 distribution: 24 | ``` 25 | conda create -n --file requirements.txt 26 | ``` 27 | 2. **Download and link to the dataset.** We train our model on the original Pascal VOC 2012 augmented with the SBD data (10K images in total). Download the data from: 28 | - VOC: [Training/Validation (2GB .tar file)](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) 29 | - SBD: [Training (1.4GB .tgz file)](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz) 30 | 31 | Link to the data: 32 | ``` 33 | ln -s /data/voc 34 | ln -s /data/sbd 35 | ``` 36 | Make sure that the first directory in `data/voc` is `VOCdevkit`; the first directory in `data/sbd` is `benchmark_RELEASE`. 37 | 3. **Download pre-trained models.** Download the initial weights (pre-trained on ImageNet) for the backbones you are planning to use and place them into `/models/weights/`. 38 | 39 | | Backbone | Initial Weights | Comment | 40 | |:---:|:---:|:---:| 41 | | WideResNet38 | [ilsvrc-cls_rna-a1_cls1000_ep-0001.pth (402M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/models/ilsvrc-cls_rna-a1_cls1000_ep-0001.pth) | Converted from [mxnet](https://github.com/itijyou/ademxapp) | 42 | | VGG16 | [vgg16_20M.pth (79M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/models/vgg16_20M.pth) | Converted from [Caffe](http://liangchiehchen.com/projects/Init%20Models.html) | 43 | | ResNet50 | [resnet50-19c8e357.pth](https://download.pytorch.org/models/resnet50-19c8e357.pth) | PyTorch official | 44 | | ResNet101 | [resnet101-5d3b4d8f.pth](https://download.pytorch.org/models/resnet101-5d3b4d8f.pth) | PyTorch official | 45 | 46 | 47 | ### Training, Inference and Evaluation 48 | The directory `launch` contains template bash scripts for training, inference and evaluation. 49 | 50 | **Training.** For each run, you need to specify names of two variables, for example 51 | ```bash 52 | EXP=baselines 53 | RUN_ID=v01 54 | ``` 55 | Running `bash ./launch/run_voc_resnet38.sh` will create a directory `./logs/pascal_voc/baselines/v01` with tensorboard events and will save snapshots into `./snapshots/pascal_voc/baselines/v01`. 56 | 57 | **Inference.** To generate final masks, please, use the script `./launch/infer_val.sh`. You will need to specify: 58 | * `EXP` and `RUN_ID` you used for training; 59 | * `OUTPUT_DIR` the path where to save the masks; 60 | * `FILELIST` specifies the file to the data split; 61 | * `SNAPSHOT` specifies the model suffix in the format `e000Xs0.000`. For example, `e020Xs0.928`; 62 | * (optionally) `EXTRA_ARGS` specify additional arguments to the inference script. 63 | 64 | **Evaluation.** To compute IoU of the masks, please, run `./launch/eval_seg.sh`. You will need to specify `SAVE_DIR` that contains the masks and `FILELIST` specifying the split for evaluation. 65 | 66 | ### Pre-trained model 67 | For testing, we provide our pre-trained WideResNet38 model: 68 | 69 | | Backbone | Val | Val (+ CRF) | Link | 70 | |:---:|:---:|:---:|---:| 71 | | WideResNet38 | 59.7 | 62.7 | [model_enc_e020Xs0.928.pth (527M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/models/model_enc_e020Xs0.928.pth) | 72 | 73 | The also release the masks predicted by this model: 74 | 75 | | Split | IoU | IoU (+ CRF) | Link | Comment | 76 | |:---:|:---:|:---:|:---:|:---:| 77 | | train-clean (VOC+SBD) | 64.7 | 66.9 | [train_results_clean.tgz (2.9G)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/results/train_results_clean.tgz) | Reported IoU is for VOC | 78 | | val-clean | 63.4 | 65.3 | [val_results_clean.tgz (423M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/results/val_results_clean.tgz) | | 79 | | val | 59.7 | 62.7 | [val_results.tgz (427M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/results/val_results.tgz) | | 80 | | test | 62.7 | 64.3 | [test_results.tgz (368M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/results/test_results.tgz) | | 81 | 82 | The suffix `-clean` means we used ground-truth image-level labels to remove masks of the categories not present in the image. 83 | These masks are commonly used as pseudo ground truth to train another segmentation model in fully supervised regime. 84 | 85 | ## Acknowledgements 86 | We thank PyTorch team, and Jiwoon Ahn for releasing his [code](https://github.com/jiwoon-ahn/psa) that helped in the early stages of this project. 87 | 88 | ## Citation 89 | We hope that you find this work useful. If you would like to acknowledge us, please, use the following citation: 90 | ``` 91 | @InProceedings{Araslanov:2020:SSS, 92 | author = {Araslanov, Nikita and Roth, Stefan}, 93 | title = {Single-Stage Semantic Segmentation From Image Labels}, 94 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 95 | month = {June}, 96 | pages = {4253--4262} 97 | year = {2020} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /base_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | import math 5 | import numpy as np 6 | 7 | import torchvision.utils as vutils 8 | 9 | from PIL import Image, ImageDraw, ImageFont 10 | from utils.checkpoints import Checkpoint 11 | 12 | try: # backward compatibility 13 | from tensorboardX import SummaryWriter 14 | except ImportError: 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | from core.config import cfg, cfg_from_file, cfg_from_list 18 | 19 | class BaseTrainer(object): 20 | 21 | def __del__(self): 22 | # commented out, because hangs on exit 23 | # (presumably some bug with threading in TensorboardX) 24 | """ 25 | if not self.quiet: 26 | self.writer.close() 27 | self.writer_val.close() 28 | """ 29 | pass 30 | 31 | def __init__(self, args, quiet=False): 32 | self.args = args 33 | self.quiet = quiet 34 | 35 | # config 36 | # Reading the config 37 | if type(args.cfg_file) is str \ 38 | and os.path.isfile(args.cfg_file): 39 | 40 | cfg_from_file(args.cfg_file) 41 | if args.set_cfgs is not None: 42 | cfg_from_list(args.set_cfgs) 43 | 44 | self.start_epoch = 0 45 | self.best_score = -1e16 46 | self.checkpoint = Checkpoint(args.snapshot_dir, max_n = 5) 47 | 48 | if not quiet: 49 | #self.model_id = "%s" % args.run 50 | logdir = os.path.join(args.logdir, 'train') 51 | logdir_val = os.path.join(args.logdir, 'val') 52 | 53 | self.writer = SummaryWriter(logdir) 54 | self.writer_val = SummaryWriter(logdir_val) 55 | 56 | def _define_checkpoint(self, name, model, optim): 57 | self.checkpoint.add_model(name, model, optim) 58 | 59 | def _load_checkpoint(self, suffix): 60 | if self.checkpoint.load(suffix): 61 | # loading the epoch and the best score 62 | tmpl = re.compile("^e(\d+)Xs([\.\d+\-]+)$") 63 | match = tmpl.match(suffix) 64 | if not match: 65 | print("Warning: epoch and score could not be recovered") 66 | return 67 | else: 68 | epoch, score = match.groups() 69 | self.start_epoch = int(epoch) + 1 70 | self.best_score = float(score) 71 | 72 | def checkpoint_epoch(self, score, epoch): 73 | 74 | if score > self.best_score: 75 | self.best_score = score 76 | 77 | print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch)) 78 | suffix = "e{:03d}Xs{:4.3f}".format(epoch, score) 79 | self.checkpoint.checkpoint(suffix) 80 | 81 | return True 82 | 83 | def checkpoint_best(self, score, epoch): 84 | 85 | if score > self.best_score: 86 | print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch)) 87 | self.best_score= score 88 | 89 | suffix = "e{:03d}Xs{:4.3f}".format(epoch, score) 90 | self.checkpoint.checkpoint(suffix) 91 | 92 | return True 93 | 94 | return False 95 | 96 | @staticmethod 97 | def get_optim(params, cfg): 98 | 99 | if not hasattr(torch.optim, cfg.OPT): 100 | print("Optimiser {} not supported".format(cfg.OPT)) 101 | raise NotImplementedError 102 | 103 | optim = getattr(torch.optim, cfg.OPT) 104 | 105 | if cfg.OPT == 'Adam': 106 | upd = torch.optim.Adam(params, lr=cfg.LR, \ 107 | betas=(cfg.BETA1, 0.999), \ 108 | weight_decay=cfg.WEIGHT_DECAY) 109 | elif cfg.OPT == 'SGD': 110 | print("Using SGD >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}".format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY)) 111 | upd = torch.optim.SGD(params, lr=cfg.LR, \ 112 | momentum=cfg.MOMENTUM, \ 113 | weight_decay=cfg.WEIGHT_DECAY) 114 | 115 | else: 116 | upd = optim(params, lr=cfg.LR) 117 | 118 | upd.zero_grad() 119 | 120 | return upd 121 | 122 | @staticmethod 123 | def set_lr(optim, lr): 124 | for param_group in optim.param_groups: 125 | param_group['lr'] = lr 126 | 127 | 128 | def _visualise_grid(self, x_all, labels, t, ious=None, tag="visualisation", scores=None): 129 | 130 | # adding the labels to images 131 | bs, ch, h, w = x_all.size() 132 | x_all_new = torch.zeros(bs, ch, h + 16, w) 133 | _, y_labels_idx = torch.max(labels, -1) 134 | classNamesOffset = len(self.classNames) - labels.size(1) 135 | classNames = self.classNames[classNamesOffset:] 136 | for b in range(bs): 137 | label_idx = labels[b] 138 | label_names = [name for i,name in enumerate(classNames) if label_idx[i].item()] 139 | label_name = ", ".join(label_names) 140 | 141 | ndarr = x_all[b].mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() 142 | arr = np.zeros((16, w, ch), dtype=ndarr.dtype) 143 | ndarr = np.concatenate((arr, ndarr), 0) 144 | im = Image.fromarray(ndarr) 145 | draw = ImageDraw.Draw(im) 146 | 147 | font = ImageFont.truetype("fonts/UbuntuMono-R.ttf", 12) 148 | 149 | # draw.text((x, y),"Sample Text",(r,g,b)) 150 | draw.text((5, 1), label_name, (255,255,255), font=font) 151 | im_np = np.array(im).astype(np.float) 152 | x_all_new[b] = (torch.from_numpy(im_np)/255.0).permute(2,0,1) 153 | 154 | summary_grid = vutils.make_grid(x_all_new, nrow=1, padding=8, pad_value=0.9) 155 | self.writer.add_image(tag, summary_grid, t) 156 | 157 | def _apply_cmap(self, mask_idx, mask_conf): 158 | palette = self.trainloader.dataset.get_palette() 159 | 160 | masks = [] 161 | col = Colorize() 162 | mask_conf = mask_conf.float() / 255.0 163 | for mask, conf in zip(mask_idx.split(1), mask_conf.split(1)): 164 | m = col(mask).float() 165 | m = m * conf 166 | masks.append(m[None, ...]) 167 | 168 | return torch.cat(masks, 0) 169 | 170 | def _mask_rgb(self, masks, image_norm, alpha=0.3): 171 | # visualising masks 172 | masks_conf, masks_idx = torch.max(masks, 1) 173 | masks_conf = masks_conf - F.relu(masks_conf - 1, 0) 174 | 175 | masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), masks_conf.cpu()) 176 | return alpha * image_norm + (1 - alpha) * masks_idx_rgb 177 | 178 | def _init_norm(self): 179 | self.trainloader.dataset.set_norm(self.enc.normalize) 180 | self.valloader.dataset.set_norm(self.enc.normalize) 181 | self.trainloader_val.dataset.set_norm(self.enc.normalize) 182 | -------------------------------------------------------------------------------- /configs/voc_resnet101.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | DATASET: 3 | CROP_SIZE: 448 4 | SCALE_FROM: 0.9 5 | SCALE_TO: 1.0 6 | TRAIN: 7 | BATCH_SIZE: 16 8 | NUM_EPOCHS: 20 9 | NUM_WORKERS: 8 10 | PRETRAIN: 5 11 | NET: 12 | BACKBONE: "resnet101" 13 | MODEL: "ae" 14 | PRE_WEIGHTS_PATH: "./models/weights/resnet101-5d3b4d8f.pth" 15 | LR: 0.001 16 | OPT: "SGD" 17 | LOSS: "SoftMargin" 18 | WEIGHT_DECAY: 0.0005 19 | TEST: 20 | METHOD: "multiscale" # multiscale | crop 21 | DATA_ROOT: "/fastdata/naraslanov" 22 | FLIP: True 23 | BATCH_SIZE: 8 # 4 scales, +1 flip for each 24 | PAD_SIZE: [768, 768] 25 | SCALES: [1, 0.75, 1.25, 1.5] 26 | FP_CUT_SCORE: 0.3 27 | USE_GT_LABELS: True 28 | -------------------------------------------------------------------------------- /configs/voc_resnet38.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | DATASET: 3 | CROP_SIZE: 321 4 | SCALE_FROM: 0.9 5 | SCALE_TO: 1.0 6 | TRAIN: 7 | BATCH_SIZE: 16 8 | NUM_EPOCHS: 24 9 | NUM_WORKERS: 8 10 | PRETRAIN: 5 11 | NET: 12 | BACKBONE: "resnet38" 13 | MODEL: "ae" 14 | PRE_WEIGHTS_PATH: "./models/weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.pth" 15 | LR: 0.001 16 | OPT: "SGD" 17 | LOSS: "SoftMargin" 18 | WEIGHT_DECAY: 0.0005 19 | PAMR_ITER: 10 20 | FOCAL_LAMBDA: 0.01 21 | FOCAL_P: 3 22 | SG_PSI: 0.3 23 | TEST: 24 | METHOD: "multiscale" 25 | DATA_ROOT: "/fastdata/naraslanov" 26 | FLIP: True 27 | BATCH_SIZE: 8 # 4 scales, +1 flip for each 28 | PAD_SIZE: [1024, 1024] 29 | SCALES: [1, 0.5, 1.5, 2.0] 30 | FP_CUT_SCORE: 0.1 31 | BG_POW: 3 32 | USE_GT_LABELS: False 33 | -------------------------------------------------------------------------------- /configs/voc_resnet50.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | DATASET: 3 | CROP_SIZE: 448 4 | SCALE_FROM: 0.9 5 | SCALE_TO: 1.0 6 | TRAIN: 7 | BATCH_SIZE: 16 8 | NUM_EPOCHS: 20 9 | NUM_WORKERS: 8 10 | PRETRAIN: 5 11 | NET: 12 | BACKBONE: "resnet50" 13 | MODEL: "ae" 14 | PRE_WEIGHTS_PATH: "./models/weights/resnet50-19c8e357.pth" 15 | LR: 0.0005 16 | OPT: "SGD" 17 | LOSS: "SoftMargin" 18 | WEIGHT_DECAY: 0.0005 19 | TEST: 20 | METHOD: "multiscale" # multiscale | crop 21 | DATA_ROOT: "/fastdata/naraslanov" 22 | FLIP: True 23 | BATCH_SIZE: 8 # 4 scales, +1 flip for each 24 | PAD_SIZE: [768, 768] 25 | SCALES: [1, 0.75, 1.25, 1.5] 26 | FP_CUT_SCORE: 0.3 27 | USE_GT_LABELS: True 28 | -------------------------------------------------------------------------------- /configs/voc_vgg16.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | DATASET: 3 | CROP_SIZE: 448 4 | SCALE_FROM: 0.9 5 | SCALE_TO: 1.0 6 | TRAIN: 7 | BATCH_SIZE: 16 8 | NUM_EPOCHS: 32 9 | NUM_WORKERS: 8 10 | PRETRAIN: 5 11 | NET: 12 | BACKBONE: "vgg16" 13 | MODEL: "ae" 14 | PRE_WEIGHTS_PATH: "./models/weights/vgg16_20M.pth" 15 | LR: 0.001 16 | OPT: "SGD" 17 | LOSS: "SoftMargin" 18 | WEIGHT_DECAY: 0.0005 19 | TEST: 20 | METHOD: "multiscale" # multiscale | crop 21 | DATA_ROOT: "/fastdata/naraslanov" 22 | FLIP: True 23 | BATCH_SIZE: 8 # 4 scales, +1 flip for each 24 | PAD_SIZE: [768, 768] 25 | SCALES: [1, 0.75, 1.25, 1.5] 26 | FP_CUT_SCORE: 0.3 27 | USE_GT_LABELS: True 28 | BG_POW: 1 29 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/core/__init__.py -------------------------------------------------------------------------------- /core/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | import yaml 8 | import six 9 | import os 10 | import os.path as osp 11 | import copy 12 | from ast import literal_eval 13 | 14 | import numpy as np 15 | from packaging import version 16 | 17 | from utils.collections import AttrDict 18 | 19 | __C = AttrDict() 20 | # Consumers can get config by: 21 | # from fast_rcnn_config import cfg 22 | cfg = __C 23 | 24 | __C.NUM_GPUS = 1 25 | # Random note: avoid using '.ON' as a config key since yaml converts it to True; 26 | # prefer 'ENABLED' instead 27 | 28 | # ---------------------------------------------------------------------------- # 29 | # Training options 30 | # ---------------------------------------------------------------------------- # 31 | __C.TRAIN = AttrDict() 32 | __C.TRAIN.BATCH_SIZE = 20 33 | __C.TRAIN.NUM_EPOCHS = 15 34 | __C.TRAIN.NUM_WORKERS = 4 35 | __C.TRAIN.MASK_LOSS = 0.0 36 | __C.TRAIN.PRETRAIN = 5 37 | 38 | # ---------------------------------------------------------------------------- # 39 | # Inference options 40 | # ---------------------------------------------------------------------------- # 41 | __C.TEST = AttrDict() 42 | __C.TEST.METHOD = "multiscale" # multiscale | crop 43 | __C.TEST.DATA_ROOT = "/data/your_directory" 44 | __C.TEST.SCALES = [1, 0.5, 1.5, 2.0] 45 | __C.TEST.FLIP = True 46 | __C.TEST.PAD_SIZE = [1024, 1024] 47 | __C.TEST.CROP_SIZE = [448, 448] 48 | __C.TEST.CROP_GRID_SIZE = [2, 2] 49 | __C.TEST.BATCH_SIZE = 8 50 | __C.TEST.BG_POW = 3 51 | __C.TEST.NUM_CLASSES = 21 52 | 53 | # use ground-truth labels to remove 54 | # false positive masks 55 | __C.TEST.USE_GT_LABELS = False 56 | 57 | # if class confidence does not exceed this threshold 58 | # the mask is removed (count as false positive) 59 | # used only if MASKS.USE_GT_LABELS is False 60 | __C.TEST.FP_CUT_SCORE = 0.1 61 | 62 | # ---------------------------------------------------------------------------- # 63 | # Dataset options 64 | # ---------------------------------------------------------------------------- # 65 | __C.DATASET = AttrDict() 66 | 67 | __C.DATASET.CROP_SIZE = 321 68 | __C.DATASET.SCALE_FROM = 0.9 69 | __C.DATASET.SCALE_TO = 1.0 70 | __C.DATASET.PATH = "data/images" 71 | 72 | # ---------------------------------------------------------------------------- # 73 | # Network options 74 | # ---------------------------------------------------------------------------- # 75 | __C.NET = AttrDict() 76 | __C.NET.MODEL = 'vgg16' 77 | __C.NET.BACKBONE = 'resnet50' 78 | __C.NET.PRE_WEIGHTS_PATH = "" 79 | __C.NET.OPT = 'SGD' 80 | __C.NET.LR = 0.001 81 | __C.NET.BETA1 = 0.5 82 | __C.NET.MOMENTUM = 0.9 83 | __C.NET.WEIGHT_DECAY = 1e-5 84 | __C.NET.LOSS = 'SoftMargin' 85 | __C.NET.MASK_LOSS_BCE = 1.0 86 | __C.NET.BG_SCORE = 0.1 # background score (only for CAM) 87 | __C.NET.FOCAL_P = 3 88 | __C.NET.FOCAL_LAMBDA = 0.01 89 | __C.NET.PAMR_KERNEL = [1, 2, 4, 8, 12, 24] 90 | __C.NET.PAMR_ITER = 10 91 | __C.NET.SG_PSI = 0.3 92 | 93 | # Mask Inference 94 | __C.MASKS = AttrDict() 95 | 96 | # CRF options 97 | __C.MASKS.CRF = AttrDict() 98 | __C.MASKS.CRF.ALPHA_LOW = 4 99 | __C.MASKS.CRF.ALPHA_HIGH = 32 100 | 101 | # [Infered value] 102 | __C.CUDA = False 103 | 104 | __C.DEBUG = False 105 | 106 | # [Infered value] 107 | __C.PYTORCH_VERSION_LESS_THAN_040 = False 108 | 109 | def assert_and_infer_cfg(make_immutable=True): 110 | """Call this function in your script after you have finished setting all cfg 111 | values that are necessary (e.g., merging a config from a file, merging 112 | command line config options, etc.). By default, this function will also 113 | mark the global cfg as immutable to prevent changing the global cfg settings 114 | during script execution (which can lead to hard to debug errors or code 115 | that's harder to understand than is necessary). 116 | """ 117 | if make_immutable: 118 | cfg.immutable(True) 119 | 120 | 121 | def merge_cfg_from_file(cfg_filename): 122 | """Load a yaml config file and merge it into the global config.""" 123 | with open(cfg_filename, 'r') as f: 124 | if hasattr(yaml, "FullLoader"): 125 | yaml_cfg = AttrDict(yaml.load(f, Loader=yaml.FullLoader)) 126 | else: 127 | yaml_cfg = AttrDict(yaml.load(f)) 128 | 129 | _merge_a_into_b(yaml_cfg, __C) 130 | 131 | cfg_from_file = merge_cfg_from_file 132 | 133 | 134 | def merge_cfg_from_cfg(cfg_other): 135 | """Merge `cfg_other` into the global config.""" 136 | _merge_a_into_b(cfg_other, __C) 137 | 138 | 139 | def merge_cfg_from_list(cfg_list): 140 | """Merge config keys, values in a list (e.g., from command line) into the 141 | global config. For example, `cfg_list = ['TEST.NMS', 0.5]`. 142 | """ 143 | assert len(cfg_list) % 2 == 0 144 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 145 | key_list = full_key.split('.') 146 | d = __C 147 | for subkey in key_list[:-1]: 148 | assert subkey in d, 'Non-existent key: {}'.format(full_key) 149 | d = d[subkey] 150 | subkey = key_list[-1] 151 | assert subkey in d, 'Non-existent key: {}'.format(full_key) 152 | value = _decode_cfg_value(v) 153 | value = _check_and_coerce_cfg_value_type( 154 | value, d[subkey], subkey, full_key 155 | ) 156 | d[subkey] = value 157 | 158 | cfg_from_list = merge_cfg_from_list 159 | 160 | 161 | def _merge_a_into_b(a, b, stack=None): 162 | """Merge config dictionary a into config dictionary b, clobbering the 163 | options in b whenever they are also specified in a. 164 | """ 165 | assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict' 166 | assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict' 167 | 168 | for k, v_ in a.items(): 169 | full_key = '.'.join(stack) + '.' + k if stack is not None else k 170 | # a must specify keys that are in b 171 | if k not in b: 172 | raise KeyError('Non-existent config key: {}'.format(full_key)) 173 | 174 | v = copy.deepcopy(v_) 175 | v = _decode_cfg_value(v) 176 | v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) 177 | 178 | # Recursively merge dicts 179 | if isinstance(v, AttrDict): 180 | try: 181 | stack_push = [k] if stack is None else stack + [k] 182 | _merge_a_into_b(v, b[k], stack=stack_push) 183 | except BaseException: 184 | raise 185 | else: 186 | b[k] = v 187 | 188 | 189 | def _decode_cfg_value(v): 190 | """Decodes a raw config value (e.g., from a yaml config files or command 191 | line argument) into a Python object. 192 | """ 193 | # Configs parsed from raw yaml will contain dictionary keys that need to be 194 | # converted to AttrDict objects 195 | if isinstance(v, dict): 196 | return AttrDict(v) 197 | # All remaining processing is only applied to strings 198 | if not isinstance(v, six.string_types): 199 | return v 200 | # Try to interpret `v` as a: 201 | # string, number, tuple, list, dict, boolean, or None 202 | try: 203 | v = literal_eval(v) 204 | # The following two excepts allow v to pass through when it represents a 205 | # string. 206 | # 207 | # Longer explanation: 208 | # The type of v is always a string (before calling literal_eval), but 209 | # sometimes it *represents* a string and other times a data structure, like 210 | # a list. In the case that v represents a string, what we got back from the 211 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 212 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 213 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 214 | # will raise a SyntaxError. 215 | except ValueError: 216 | pass 217 | except SyntaxError: 218 | pass 219 | return v 220 | 221 | 222 | def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key): 223 | """Checks that `value_a`, which is intended to replace `value_b` is of the 224 | right type. The type is correct if it matches exactly or is one of a few 225 | cases in which the type can be easily coerced. 226 | """ 227 | # The types must match (with some exceptions) 228 | type_b = type(value_b) 229 | type_a = type(value_a) 230 | if type_a is type_b: 231 | return value_a 232 | 233 | # Exceptions: numpy arrays, strings, tuple<->list 234 | if isinstance(value_b, np.ndarray): 235 | value_a = np.array(value_a, dtype=value_b.dtype) 236 | elif isinstance(value_b, six.string_types): 237 | value_a = str(value_a) 238 | elif isinstance(value_a, tuple) and isinstance(value_b, list): 239 | value_a = list(value_a) 240 | elif isinstance(value_a, list) and isinstance(value_b, tuple): 241 | value_a = tuple(value_a) 242 | else: 243 | raise ValueError( 244 | 'Type mismatch ({} vs. {}) with values ({} vs. {}) for config ' 245 | 'key: {}'.format(type_b, type_a, value_b, value_a, full_key) 246 | ) 247 | return value_a 248 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from .pascal_voc import VOCSegmentation 3 | 4 | datasets = { 5 | 'pascal_voc': VOCSegmentation 6 | } 7 | 8 | def get_num_classes(args): 9 | return datasets[args.dataset.lower()].NUM_CLASS 10 | 11 | def get_class_names(args): 12 | return datasets[args.dataset.lower()].CLASSES 13 | 14 | def get_dataloader(args, cfg, split, batch_size=None, test_mode=None): 15 | assert split in ('train', 'train_voc', 'val'), "Unknown split '{}'".format(split) 16 | 17 | dataset_name = args.dataset.lower() 18 | dataset_cls = datasets[dataset_name] 19 | dataset = dataset_cls(cfg, split, test_mode) 20 | 21 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 22 | shuffle, drop_last = [True, True] if split == 'train' else [False, False] 23 | 24 | if batch_size is None: 25 | batch_size = cfg.TRAIN.BATCH_SIZE 26 | 27 | return data.DataLoader(dataset, batch_size=batch_size, 28 | drop_last=drop_last, shuffle=shuffle, 29 | **kwargs) 30 | -------------------------------------------------------------------------------- /datasets/pascal_voc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | from PIL import Image, ImagePalette 6 | from .utils import colormap 7 | import datasets.transforms as tf 8 | 9 | class PascalVOC(Dataset): 10 | 11 | CLASSES = [ 12 | 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 13 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 14 | 'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 15 | 'tv/monitor', 'ambiguous' 16 | ] 17 | 18 | CLASS_IDX = { 19 | 'background': 0, 20 | 'aeroplane': 1, 21 | 'bicycle': 2, 22 | 'bird': 3, 23 | 'boat': 4, 24 | 'bottle': 5, 25 | 'bus': 6, 26 | 'car': 7, 27 | 'cat': 8, 28 | 'chair': 9, 29 | 'cow': 10, 30 | 'diningtable': 11, 31 | 'dog': 12, 32 | 'horse': 13, 33 | 'motorbike': 14, 34 | 'person': 15, 35 | 'potted-plant': 16, 36 | 'sheep': 17, 37 | 'sofa': 18, 38 | 'train': 19, 39 | 'tv/monitor': 20, 40 | 'ambiguous': 255 41 | } 42 | 43 | CLASS_IDX_INV = { 44 | 0: 'background', 45 | 1: 'aeroplane', 46 | 2: 'bicycle', 47 | 3: 'bird', 48 | 4: 'boat', 49 | 5: 'bottle', 50 | 6: 'bus', 51 | 7: 'car', 52 | 8: 'cat', 53 | 9: 'chair', 54 | 10: 'cow', 55 | 11: 'diningtable', 56 | 12: 'dog', 57 | 13: 'horse', 58 | 14: 'motorbike', 59 | 15: 'person', 60 | 16: 'potted-plant', 61 | 17: 'sheep', 62 | 18: 'sofa', 63 | 19: 'train', 64 | 20: 'tv/monitor', 65 | 255: 'ambiguous'} 66 | 67 | NUM_CLASS = 21 68 | 69 | MEAN = (0.485, 0.456, 0.406) 70 | STD = (0.229, 0.224, 0.225) 71 | 72 | def __init__(self): 73 | super().__init__() 74 | self._init_palette() 75 | 76 | def _init_palette(self): 77 | self.cmap = colormap() 78 | self.palette = ImagePalette.ImagePalette() 79 | for rgb in self.cmap: 80 | self.palette.getcolor(rgb) 81 | 82 | def get_palette(self): 83 | return self.palette 84 | 85 | def denorm(self, image): 86 | 87 | if image.dim() == 3: 88 | assert image.dim() == 3, "Expected image [CxHxW]" 89 | assert image.size(0) == 3, "Expected RGB image [3xHxW]" 90 | 91 | for t, m, s in zip(image, self.MEAN, self.STD): 92 | t.mul_(s).add_(m) 93 | elif image.dim() == 4: 94 | # batch mode 95 | assert image.size(1) == 3, "Expected RGB image [3xHxW]" 96 | 97 | for t, m, s in zip((0,1,2), self.MEAN, self.STD): 98 | image[:, t, :, :].mul_(s).add_(m) 99 | 100 | return image 101 | 102 | 103 | class VOCSegmentation(PascalVOC): 104 | 105 | def __init__(self, cfg, split, test_mode, root=os.path.expanduser('./data')): 106 | super(VOCSegmentation, self).__init__() 107 | 108 | self.cfg = cfg 109 | self.root = root 110 | self.split = split 111 | self.test_mode = test_mode 112 | 113 | # train/val/test splits are pre-cut 114 | if self.split == 'train': 115 | _split_f = os.path.join(self.root, 'train_augvoc.txt') 116 | elif self.split == 'train_voc': 117 | _split_f = os.path.join(self.root, 'train_voc.txt') 118 | elif self.split == 'val': 119 | _split_f = os.path.join(self.root, 'val_voc.txt') 120 | elif self.split == 'test': 121 | _split_f = os.path.join(self.root, 'test.txt') 122 | else: 123 | raise RuntimeError('Unknown dataset split.') 124 | 125 | assert os.path.isfile(_split_f), "%s not found" % _split_f 126 | 127 | self.images = [] 128 | self.masks = [] 129 | with open(_split_f, "r") as lines: 130 | for line in lines: 131 | _image, _mask = line.strip("\n").split(' ') 132 | _image = os.path.join(self.root, _image) 133 | assert os.path.isfile(_image), '%s not found' % _image 134 | self.images.append(_image) 135 | 136 | if self.split != 'test': 137 | _mask = os.path.join(self.root, _mask.lstrip('/')) 138 | assert os.path.isfile(_mask), '%s not found' % _mask 139 | self.masks.append(_mask) 140 | 141 | if self.split != 'test': 142 | assert (len(self.images) == len(self.masks)) 143 | if self.split == 'train': 144 | assert len(self.images) == 10582 145 | elif self.split == 'val': 146 | assert len(self.images) == 1449 147 | 148 | self.transform = tf.Compose([tf.MaskRandResizedCrop(self.cfg.DATASET), \ 149 | tf.MaskHFlip(), \ 150 | tf.MaskColourJitter(p = 1.0), \ 151 | tf.MaskNormalise(self.MEAN, self.STD), \ 152 | tf.MaskToTensor()]) 153 | 154 | def __len__(self): 155 | return len(self.images) 156 | 157 | def __getitem__(self, index): 158 | 159 | image = Image.open(self.images[index]).convert('RGB') 160 | mask = Image.open(self.masks[index]) 161 | 162 | unique_labels = np.unique(mask) 163 | 164 | # ambigious 165 | if unique_labels[-1] == self.CLASS_IDX['ambiguous']: 166 | unique_labels = unique_labels[:-1] 167 | 168 | # ignoring BG 169 | labels = torch.zeros(self.NUM_CLASS - 1) 170 | if unique_labels[0] == self.CLASS_IDX['background']: 171 | unique_labels = unique_labels[1:] 172 | unique_labels -= 1 # shifting since no BG class 173 | 174 | assert unique_labels.size > 0, 'No labels found in %s' % self.masks[index] 175 | labels[unique_labels.tolist()] = 1 176 | 177 | # general resize, normalize and toTensor 178 | image, mask = self.transform(image, mask) 179 | 180 | return image, labels, os.path.basename(self.images[index]) 181 | 182 | @property 183 | def pred_offset(self): 184 | return 0 185 | -------------------------------------------------------------------------------- /datasets/pascal_voc_ms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multi-scale dataloader 3 | Credit PSA 4 | """ 5 | 6 | from PIL import Image 7 | from .pascal_voc import PascalVOC 8 | 9 | import math 10 | import numpy as np 11 | import torch 12 | import os.path 13 | import torchvision.transforms.functional as F 14 | 15 | 16 | def load_img_name_list(dataset_path, index = 0): 17 | 18 | img_gt_name_list = open(dataset_path).read().splitlines() 19 | img_name_list = [img_gt_name.split(' ')[index].strip('/') for img_gt_name in img_gt_name_list] 20 | 21 | return img_name_list 22 | 23 | def load_label_name_list(dataset_path): 24 | return load_img_name_list(dataset_path, index = 1) 25 | 26 | class VOC12ImageDataset(PascalVOC): 27 | 28 | def __init__(self, img_name_list_path, voc12_root): 29 | super().__init__() 30 | 31 | self.img_name_list = load_img_name_list(img_name_list_path) 32 | self.voc12_root = voc12_root 33 | self.batch_size = 1 34 | 35 | def __len__(self): 36 | return len(self.img_name_list) 37 | 38 | def __getitem__(self, idx): 39 | fullpath = os.path.join(self.voc12_root, self.img_name_list[idx]) 40 | img = Image.open(fullpath).convert("RGB") 41 | return fullpath, img 42 | 43 | class VOC12ClsDataset(VOC12ImageDataset): 44 | 45 | def __init__(self, img_name_list_path, voc12_root): 46 | super(VOC12ClsDataset, self).__init__(img_name_list_path, voc12_root) 47 | self.label_list = load_label_name_list(img_name_list_path) 48 | 49 | def __len__(self): 50 | return self.batch_size * len(self.img_name_list) 51 | 52 | def _pad(self, image): 53 | w, h = image.size 54 | 55 | pad_mask = Image.new("L", image.size) 56 | pad_height = self.pad_size[0] - h 57 | pad_width = self.pad_size[1] - w 58 | 59 | assert pad_height >= 0 and pad_width >= 0 60 | 61 | pad_l = max(0, pad_width // 2) 62 | pad_r = max(0, pad_width - pad_l) 63 | pad_t = max(0, pad_height // 2) 64 | pad_b = max(0, pad_height - pad_t) 65 | 66 | image = F.pad(image, (pad_l, pad_t, pad_r, pad_b), fill=0, padding_mode="constant") 67 | pad_mask = F.pad(pad_mask, (pad_l, pad_t, pad_r, pad_b), fill=1, padding_mode="constant") 68 | 69 | return image, pad_mask, [pad_t, pad_l] 70 | 71 | def __getitem__(self, idx): 72 | name, img = super(VOC12ClsDataset, self).__getitem__(idx) 73 | 74 | label_fullpath = self.label_list[idx] 75 | assert len(label_fullpath) < 256, "Expected label path less than 256 for padding" 76 | 77 | mask = Image.open(os.path.join(self.voc12_root, label_fullpath)) 78 | mask = np.array(mask) 79 | 80 | labels = torch.zeros(self.NUM_CLASS - 1) 81 | 82 | # it will also be sorted 83 | unique_labels = np.unique(mask) 84 | 85 | # ambigious 86 | if unique_labels[-1] == self.CLASS_IDX['ambiguous']: 87 | unique_labels = unique_labels[:-1] 88 | 89 | # background 90 | if unique_labels[0] == self.CLASS_IDX['background']: 91 | unique_labels = unique_labels[1:] 92 | 93 | assert unique_labels.size > 0, 'No labels found in %s' % self.masks[index] 94 | unique_labels -= 1 # shifting since no BG class 95 | labels[unique_labels.tolist()] = 1 96 | 97 | return name, img, labels, mask.astype(np.int) 98 | 99 | 100 | class MultiscaleLoader(VOC12ClsDataset): 101 | 102 | def __init__(self, img_list, cfg, transform): 103 | super().__init__(img_list, cfg.DATA_ROOT) 104 | 105 | self.scales = cfg.SCALES 106 | self.pad_size = cfg.PAD_SIZE 107 | self.use_flips = cfg.FLIP 108 | self.transform = transform 109 | 110 | self.batch_size = len(self.scales) 111 | if self.use_flips: 112 | self.batch_size *= 2 113 | 114 | print("Inference batch size: ", self.batch_size) 115 | assert self.batch_size == cfg.BATCH_SIZE 116 | 117 | def __getitem__(self, idx): 118 | im_idx = idx // self.batch_size 119 | sub_idx = idx % self.batch_size 120 | 121 | scale = self.scales[sub_idx // (2 if self.use_flips else 1)] 122 | flip = self.use_flips and sub_idx % 2 123 | 124 | name, img, label, mask = super().__getitem__(im_idx) 125 | 126 | target_size = (int(round(img.size[0]*scale)), 127 | int(round(img.size[1]*scale))) 128 | 129 | s_img = img.resize(target_size, resample=Image.CUBIC) 130 | 131 | if flip: 132 | s_img = F.hflip(s_img) 133 | 134 | w, h = s_img.size 135 | im_msc, ignore, pads_tl = self._pad(s_img) 136 | pad_t, pad_l = pads_tl 137 | 138 | im_msc = self.transform(im_msc) 139 | img = F.to_tensor(self.transform(img)) 140 | 141 | pads = torch.Tensor([pad_t, pad_l, h, w]) 142 | 143 | ignore = np.array(ignore).astype(im_msc.dtype)[..., np.newaxis] 144 | im_msc = F.to_tensor(im_msc * (1 - ignore)) 145 | 146 | return name, img, im_msc, pads, label, mask 147 | 148 | 149 | class CropLoader(VOC12ClsDataset): 150 | 151 | def __init__(self, img_list, cfg, transform): 152 | super().__init__(img_list, cfg.DATA_ROOT) 153 | 154 | self.use_flips = cfg.FLIP 155 | self.transform = transform 156 | 157 | self.grid_h, self.grid_w = cfg.CROP_GRID_SIZE 158 | self.crop_h, self.crop_w = cfg.CROP_SIZE 159 | self.pad_size = cfg.PAD_SIZE 160 | 161 | self.stride_h = int(math.ceil(self.pad_size[0] / self.grid_h)) 162 | self.stride_w = int(math.ceil(self.pad_size[1] / self.grid_w)) 163 | 164 | assert self.stride_h <= self.crop_h and \ 165 | self.stride_w <= self.crop_w 166 | 167 | self.batch_size = self.grid_h * self.grid_w 168 | if self.use_flips: 169 | self.batch_size *= 2 170 | 171 | print("Inference batch size: ", self.batch_size) 172 | assert self.batch_size == cfg.BATCH_SIZE 173 | 174 | 175 | def __getitem__(self, index): 176 | image_index = index // self.batch_size 177 | batch_index = index % self.batch_size 178 | grid_index = batch_index // (2 if self.use_flips else 1) 179 | 180 | index_h = grid_index // self.grid_w 181 | index_w = grid_index % self.grid_w 182 | flip = self.use_flips and batch_index % 2 == 0 183 | 184 | name, image, label, mask = super().__getitem__(image_index) 185 | 186 | image_pad, pad_mask, pads = self._pad(image) 187 | assert image_pad.size[0] == self.pad_size[1] and \ 188 | image_pad.size[1] == self.pad_size[0] 189 | 190 | s_h = index_h * self.stride_h 191 | e_h = min(s_h + self.crop_h, self.pad_size[0]) 192 | s_h = e_h - self.crop_h 193 | 194 | s_w = index_w * self.stride_w 195 | e_w = min(s_w + self.crop_w, self.pad_size[1]) 196 | s_w = e_w - self.crop_w 197 | 198 | image_pad = self.transform(image_pad) 199 | pad_mask = np.array(pad_mask).astype(image_pad.dtype)[..., np.newaxis] 200 | image_pad *= 1 - pad_mask 201 | 202 | image_pad = F.to_tensor(image_pad) 203 | image_crop = image_pad[:, s_h:e_h, s_w:e_w].clone() 204 | 205 | pads = torch.LongTensor([s_h, e_h, s_w, e_w] + pads) 206 | 207 | if flip: 208 | image_crop = image_crop.flip(-1) 209 | 210 | image = F.to_tensor(self.transform(image)) 211 | 212 | return name, image, image_crop, pads, label, mask 213 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | 5 | from PIL import Image 6 | import torchvision.transforms as tf 7 | import torchvision.transforms.functional as F 8 | 9 | from functools import partial 10 | 11 | class Compose: 12 | def __init__(self, segtransform): 13 | self.segtransform = segtransform 14 | 15 | def __call__(self, image, label): 16 | # allow for intermediate representations 17 | result = (image, label) 18 | for t in self.segtransform: 19 | result = t(*result) 20 | 21 | # ensure we have just the image 22 | # and the label in the end 23 | image, label = result 24 | return image, label 25 | 26 | class MaskRandResizedCrop: 27 | 28 | def __init__(self, cfg): 29 | self.rnd_crop = tf.RandomResizedCrop(cfg.CROP_SIZE, \ 30 | scale=(cfg.SCALE_FROM, \ 31 | cfg.SCALE_TO)) 32 | 33 | def get_params(self, image): 34 | return self.rnd_crop.get_params(image, \ 35 | self.rnd_crop.scale, \ 36 | self.rnd_crop.ratio) 37 | 38 | def __call__(self, image, labels): 39 | 40 | i, j, h, w = self.get_params(image) 41 | 42 | image = F.resized_crop(image, i, j, h, w, self.rnd_crop.size, Image.CUBIC) 43 | labels = F.resized_crop(labels, i, j, h, w, self.rnd_crop.size, Image.NEAREST) 44 | 45 | return image, labels 46 | 47 | class MaskHFlip: 48 | 49 | def __init__(self, p=0.5): 50 | self.p = p 51 | 52 | def __call__(self, image, mask): 53 | 54 | if random.random() < self.p: 55 | image = F.hflip(image) 56 | mask = F.hflip(mask) 57 | 58 | return image, mask 59 | 60 | class MaskNormalise: 61 | 62 | def __init__(self, mean, std): 63 | self.norm = tf.Normalize(mean, std) 64 | 65 | def __toByteTensor(self, pic): 66 | return torch.from_numpy(np.array(pic, np.int32, copy=False)) 67 | 68 | def __call__(self, image, labels): 69 | 70 | image = F.to_tensor(image) 71 | image = self.norm(image) 72 | labels = self.__toByteTensor(labels) 73 | 74 | return image, labels 75 | 76 | class MaskToTensor: 77 | 78 | def __call__(self, image, mask): 79 | gt_labels = torch.arange(0, 21) 80 | gt_labels = gt_labels.unsqueeze(-1).unsqueeze(-1) 81 | mask = mask.unsqueeze(0).type_as(gt_labels) 82 | mask = torch.eq(mask, gt_labels).float() 83 | return image, mask 84 | 85 | class MaskColourJitter: 86 | 87 | def __init__(self, p=0.5, brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1): 88 | self.p = p 89 | self.jitter = tf.ColorJitter(brightness=0.3, \ 90 | contrast=0.3, \ 91 | saturation=0.3, \ 92 | hue=0.1) 93 | 94 | def __call__(self, image, mask): 95 | 96 | if random.random() < self.p: 97 | image = self.jitter(image) 98 | 99 | return image, mask 100 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def colormap(N=256): 5 | def bitget(byteval, idx): 6 | return ((byteval & (1 << idx)) != 0) 7 | 8 | dtype = 'uint8' 9 | cmap = [] 10 | for i in range(N): 11 | r = g = b = 0 12 | c = i 13 | for j in range(8): 14 | r = r | (bitget(c, 0) << 7-j) 15 | g = g | (bitget(c, 1) << 7-j) 16 | b = b | (bitget(c, 2) << 7-j) 17 | c = c >> 3 18 | 19 | cmap.append((r, g, b)) 20 | 21 | return cmap 22 | 23 | """ 24 | Python implementation of the color map function for the PASCAL VOC data set. 25 | Official Matlab version can be found in the PASCAL VOC devkit 26 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 27 | """ 28 | 29 | def uint82bin(n, count=8): 30 | """returns the binary of integer n, count refers to amount of bits""" 31 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 32 | 33 | def labelcolormap(N): 34 | cmap = np.zeros((N, 3), dtype=np.uint8) 35 | for i in range(N): 36 | r = 0 37 | g = 0 38 | b = 0 39 | id = i 40 | for j in range(7): 41 | str_id = uint82bin(id) 42 | r = r ^ (np.uint8(str_id[-1]) << (7-j)) 43 | g = g ^ (np.uint8(str_id[-2]) << (7-j)) 44 | b = b ^ (np.uint8(str_id[-3]) << (7-j)) 45 | id = id >> 3 46 | cmap[i, 0] = r 47 | cmap[i, 1] = g 48 | cmap[i, 2] = b 49 | return cmap 50 | 51 | class Colorize(object): 52 | 53 | def __init__(self, n=22): 54 | self.cmap = labelcolormap(22) 55 | self.cmap = torch.from_numpy(self.cmap[:n]) 56 | 57 | def __call__(self, gray_image): 58 | size = gray_image.size() 59 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 60 | 61 | for label in range(0, len(self.cmap)): 62 | mask = (label == gray_image[0]).cpu() 63 | color_image[0][mask] = self.cmap[label][0] 64 | color_image[1][mask] = self.cmap[label][1] 65 | color_image[2][mask] = self.cmap[label][2] 66 | 67 | return color_image 68 | 69 | 70 | -------------------------------------------------------------------------------- /eval_seg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluating the masks 3 | 4 | TODO: 5 | Parallelise with 6 | 7 | from multiprocessing import Pool 8 | 9 | ConfM = ConfusionMatrix(class_num) 10 | f = ConfM.generateM 11 | pool = Pool() 12 | m_list = pool.map(f, data_list) 13 | pool.close() 14 | pool.join() 15 | 16 | """ 17 | 18 | import sys 19 | import os 20 | import numpy as np 21 | import argparse 22 | import scipy 23 | 24 | from tqdm import tqdm 25 | from datasets.pascal_voc import PascalVOC 26 | from PIL import Image 27 | from utils.metrics import Metric 28 | 29 | # Defining 30 | parser = argparse.ArgumentParser(description="Mask Evaluation") 31 | 32 | parser.add_argument("--data", type=str, default='./data/annotation', 33 | help="The prefix for data directory") 34 | parser.add_argument("--filelist", type=str, default='./data/val.txt', 35 | help="A text file containing the paths to masks") 36 | parser.add_argument("--masks", type=str, default='./masks', 37 | help="A path to generated masks") 38 | parser.add_argument("--oracle-from", type=str, default="", 39 | help="Use GT mask but down- then upscale them") 40 | parser.add_argument("--log-scores", type=str, default='./scores.log', 41 | help="Logging scores for invididual images") 42 | 43 | def check_args(args): 44 | """Check the files/directories exist""" 45 | 46 | assert os.path.isdir(args.data), \ 47 | "Directory {} does not exist".format(args.data) 48 | assert os.path.isfile(args.filelist), \ 49 | "File {} does not exist".format(args.filelist) 50 | if len(args.oracle_from) > 0: 51 | vals = args.oracle_from.split('x') 52 | assert len(vals) == 2, "HxW expected" 53 | h, w = vals 54 | assert int(h) > 2, "Meaningless resolution" 55 | assert int(w) > 2, "Meaningless resolution" 56 | else: 57 | assert os.path.isdir(args.masks), \ 58 | "Directory {} does not exist".format(args.masks) 59 | 60 | def format_num(x): 61 | return round(x*100., 1) 62 | 63 | 64 | def get_stats(M, i): 65 | 66 | TP = M[i, i] 67 | FN = np.sum(M[i, :]) - TP # false negatives 68 | FP = np.sum(M[:, i]) - TP # false positives 69 | 70 | return TP, FN, FP 71 | 72 | def summarise_one(class_stats, M, name, labels): 73 | 74 | for i in labels: 75 | 76 | # skipping the ambiguous 77 | if i == 255: 78 | continue 79 | 80 | # category name 81 | TP, FN, FP = get_stats(M, i) 82 | score = TP - FN - FP 83 | 84 | class_stats[i].append((name, score)) 85 | 86 | def summarise_per_class(class_stats, filename): 87 | 88 | data = "" 89 | for cat in PascalVOC.CLASSES: 90 | 91 | if cat == "ambiguous": 92 | continue 93 | 94 | i = PascalVOC.CLASS_IDX[cat] 95 | sorted_by_score = sorted(class_stats[i], key=lambda x: -x[1]) 96 | data += cat + "\n" 97 | for name, score in sorted_by_score: 98 | data += "{:05d} | {}\n".format(int(score), name) 99 | 100 | with open(filename, 'w') as f: 101 | f.write(data) 102 | 103 | def summarise_stats(M): 104 | 105 | eps = 1e-20 106 | 107 | mean = Metric() 108 | mean.add_metric(Metric.IoU) 109 | mean.add_metric(Metric.Precision) 110 | mean.add_metric(Metric.Recall) 111 | 112 | mean_bkg = Metric() 113 | mean_bkg.add_metric(Metric.IoU) 114 | mean_bkg.add_metric(Metric.Precision) 115 | mean_bkg.add_metric(Metric.Recall) 116 | 117 | head_fmt = "{:>12} | {:>5}" + " | {:>5}"*3 118 | row_fmt = "{:>12} | {:>5}" + " | {:>5.1f}"*3 119 | split = "-"*44 120 | 121 | def print_row(fmt, row): 122 | print(fmt.format(*row)) 123 | 124 | print_row(head_fmt, ("Class", "#", "IoU", "Pr", "Re")) 125 | print(split) 126 | 127 | for cat in PascalVOC.CLASSES: 128 | 129 | if cat == "ambiguous": 130 | continue 131 | 132 | i = PascalVOC.CLASS_IDX[cat] 133 | 134 | TP, FN, FP = get_stats(M, i) 135 | 136 | iou = 100. * TP / (eps + FN + FP + TP) 137 | pr = 100. * TP / (eps + TP + FP) 138 | re = 100. * TP / (eps + TP + FN) 139 | 140 | mean_bkg.update_value(Metric.IoU, iou) 141 | mean_bkg.update_value(Metric.Precision, pr) 142 | mean_bkg.update_value(Metric.Recall, re) 143 | 144 | if cat != "background": 145 | mean.update_value(Metric.IoU, iou) 146 | mean.update_value(Metric.Precision, pr) 147 | mean.update_value(Metric.Recall, re) 148 | 149 | count = int(np.sum(M[i, :])) 150 | print_row(row_fmt, (cat, count, iou, pr, re)) 151 | 152 | 153 | print(split) 154 | sys.stdout.write("mIoU: {:.2f}\t".format(mean.summarize(Metric.IoU))) 155 | sys.stdout.write(" Pr: {:.2f}\t".format(mean.summarize(Metric.Precision))) 156 | sys.stdout.write(" Re: {:.2f}\n".format(mean.summarize(Metric.Recall))) 157 | 158 | print(split) 159 | print("With background: ") 160 | sys.stdout.write("mIoU: {:.2f}\t".format(mean_bkg.summarize(Metric.IoU))) 161 | sys.stdout.write(" Pr: {:.2f}\t".format(mean_bkg.summarize(Metric.Precision))) 162 | sys.stdout.write(" Re: {:.2f}\n".format(mean_bkg.summarize(Metric.Recall))) 163 | 164 | 165 | def evaluate_one(conf_mat, mask_gt, mask): 166 | 167 | gt = mask_gt.reshape(-1) 168 | pred = mask.reshape(-1) 169 | conf_mat_one = np.zeros_like(conf_mat) 170 | 171 | assert(len(gt) == len(pred)) 172 | 173 | for i in range(len(gt)): 174 | if gt[i] < conf_mat.shape[0]: 175 | conf_mat[gt[i], pred[i]] += 1.0 176 | conf_mat_one[gt[i], pred[i]] += 1.0 177 | 178 | return conf_mat_one 179 | 180 | def read_mask_file(filepath): 181 | return np.array(Image.open(filepath)) 182 | 183 | def oracle_lower(mask, h, w, alpha): 184 | 185 | mask_dict = {} 186 | labels = np.unique(mask) 187 | new_mask = np.zeros_like(mask) 188 | H, W = mask.shape 189 | 190 | # skipping background 191 | for l in labels: 192 | if l in (0, 255): 193 | continue 194 | 195 | mask_l = (mask == l).astype(np.float) 196 | mask_down = scipy.misc.imresize(mask_l, (h, w), interp='bilinear') 197 | mask_up = scipy.misc.imresize(mask_down, (H, W), interp='bilinear') 198 | new_mask[mask_up > alpha] = l 199 | 200 | return new_mask 201 | 202 | def get_image_name(name): 203 | base = os.path.basename(name) 204 | base = base.replace(".jpg", "") 205 | return base 206 | 207 | def evaluate_all(args): 208 | 209 | with_oracle = False 210 | if len(args.oracle_from) > 0: 211 | oh, ow = [int(x) for x in args.oracle_from.split("x")] 212 | with_oracle = (oh > 1 and ow > 1) 213 | 214 | if with_oracle: 215 | print(">>> Using oracle {}x{}".format(oh, ow)) 216 | 217 | # initialising the confusion matrix 218 | conf_mat = np.zeros((21, 21)) 219 | class_stats = {} 220 | for class_idx in range(21): 221 | class_stats[class_idx] = [] 222 | 223 | # count of the images 224 | num_im = 0 225 | 226 | # opening the filelist 227 | with open(args.filelist) as fd: 228 | 229 | for line in tqdm(fd.readlines()): 230 | 231 | files = [x.strip('/ \n') for x in line.split(' ')] 232 | 233 | if len(files) < 2: 234 | print("No path to GT mask found in line\n") 235 | print("\t{}".format(line)) 236 | continue 237 | 238 | filepath_gt = os.path.join(args.data, files[1]) 239 | if not os.path.isfile(filepath_gt): 240 | print("File not found (GT): {}".format(filepath_gt)) 241 | continue 242 | 243 | mask_gt = read_mask_file(filepath_gt) 244 | 245 | if with_oracle: 246 | mask = oracle_lower(mask_gt, oh, ow, alpha=0.5) 247 | else: 248 | basename = os.path.basename(files[1]) 249 | filepath = os.path.join(args.masks, basename) 250 | if not os.path.isfile(filepath): 251 | print("File not found: {}".format(filepath)) 252 | continue 253 | 254 | mask = read_mask_file(filepath) 255 | 256 | if mask.shape != mask_gt.shape: 257 | print("Mask shape mismatch in {}: ".format(basename), \ 258 | mask.shape, " vs ", mask_gt.shape) 259 | continue 260 | 261 | conf_mat_one = evaluate_one(conf_mat, mask_gt, mask) 262 | 263 | image_name = get_image_name(files[0]) 264 | image_labels = np.unique(mask_gt) 265 | summarise_one(class_stats, conf_mat_one, image_name, image_labels) 266 | 267 | num_im += 1 268 | 269 | 270 | print("# of images: {}".format(num_im)) 271 | summarise_per_class(class_stats, args.log_scores) 272 | 273 | return conf_mat 274 | 275 | if __name__ == "__main__": 276 | 277 | args = parser.parse_args(sys.argv[1:]) 278 | check_args(args) 279 | stats = evaluate_all(args) 280 | summarise_stats(stats) 281 | -------------------------------------------------------------------------------- /figures/results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/figures/results.gif -------------------------------------------------------------------------------- /figures/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/figures/results.png -------------------------------------------------------------------------------- /fonts/UbuntuMono-R.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/fonts/UbuntuMono-R.ttf -------------------------------------------------------------------------------- /infer_val.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluating class activation maps from a given snapshot 3 | Supports multi-scale fusion of the masks 4 | Based on PSA 5 | """ 6 | 7 | import os 8 | import sys 9 | import numpy as np 10 | import scipy 11 | import torch.multiprocessing as mp 12 | from tqdm import tqdm 13 | from PIL import Image 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torchvision 18 | import torchvision.transforms as tf 19 | from torch.utils.data import DataLoader 20 | from torch.backends import cudnn 21 | cudnn.enabled = True 22 | cudnn.benchmark = False 23 | cudnn.deterministic = True 24 | 25 | from opts import get_arguments 26 | from core.config import cfg, cfg_from_file, cfg_from_list 27 | from models import get_model 28 | 29 | from utils.checkpoints import Checkpoint 30 | from utils.timer import Timer 31 | from utils.dcrf import crf_inference 32 | from utils.inference_tools import get_inference_io 33 | 34 | def check_dir(base_path, name): 35 | 36 | # create the directory 37 | fullpath = os.path.join(base_path, name) 38 | if not os.path.exists(fullpath): 39 | os.makedirs(fullpath) 40 | 41 | return fullpath 42 | 43 | def HWC_to_CHW(img): 44 | return np.transpose(img, (2, 0, 1)) 45 | 46 | if __name__ == '__main__': 47 | 48 | # loading the model 49 | args = get_arguments(sys.argv[1:]) 50 | 51 | # reading the config 52 | cfg_from_file(args.cfg_file) 53 | if args.set_cfgs is not None: 54 | cfg_from_list(args.set_cfgs) 55 | 56 | # initialising the dirs 57 | check_dir(args.mask_output_dir, "vis") 58 | check_dir(args.mask_output_dir, "crf") 59 | 60 | # Loading the model 61 | model = get_model(cfg.NET, num_classes=cfg.TEST.NUM_CLASSES) 62 | checkpoint = Checkpoint(args.snapshot_dir, max_n = 5) 63 | checkpoint.add_model('enc', model) 64 | checkpoint.load(args.resume) 65 | 66 | for p in model.parameters(): 67 | p.requires_grad = False 68 | 69 | # setting the evaluation mode 70 | model.eval() 71 | 72 | assert hasattr(model, 'normalize') 73 | transform = tf.Compose([np.asarray, model.normalize]) 74 | 75 | WriterClass, DatasetClass = get_inference_io(cfg.TEST.METHOD) 76 | 77 | dataset = DatasetClass(args.infer_list, cfg.TEST, transform=transform) 78 | 79 | dataloader = DataLoader(dataset, shuffle=False, num_workers=args.workers, \ 80 | pin_memory=True, batch_size=cfg.TEST.BATCH_SIZE) 81 | 82 | model = nn.DataParallel(model).cuda() 83 | 84 | timer = Timer() 85 | N = len(dataloader) 86 | 87 | palette = dataset.get_palette() 88 | pool = mp.Pool(processes=args.workers) 89 | writer = WriterClass(cfg.TEST, palette, args.mask_output_dir) 90 | 91 | for iter, (img_name, img_orig, images_in, pads, labels, gt_mask) in enumerate(tqdm(dataloader)): 92 | 93 | # cutting the masks 94 | masks = [] 95 | 96 | with torch.no_grad(): 97 | cls_raw, masks_pred = model(images_in) 98 | 99 | if not cfg.TEST.USE_GT_LABELS: 100 | cls_sigmoid = torch.sigmoid(cls_raw) 101 | cls_sigmoid, _ = cls_sigmoid.max(0) 102 | #cls_sigmoid = cls_sigmoid.mean(0) 103 | # threshold class scores 104 | labels = (cls_sigmoid > cfg.TEST.FP_CUT_SCORE) 105 | else: 106 | labels = labels[0] 107 | 108 | # saving the raw npy 109 | image = dataset.denorm(img_orig[0]).numpy() 110 | masks_pred = masks_pred.cpu() 111 | labels = labels.type_as(masks_pred) 112 | 113 | #writer.save(img_name[0], image, masks_pred, pads, labels, gt_mask[0]) 114 | pool.apply_async(writer.save, args=(img_name[0], image, masks_pred, pads, labels, gt_mask[0])) 115 | 116 | timer.update_progress(float(iter + 1) / N) 117 | if iter % 100 == 0: 118 | msg = "Finish time: {}".format(timer.str_est_finish()) 119 | tqdm.write(msg) 120 | sys.stdout.flush() 121 | 122 | pool.close() 123 | pool.join() 124 | -------------------------------------------------------------------------------- /launch/eval_seg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=pascal_voc 4 | FILELIST=data/val_voc.txt # validation 5 | 6 | ## You values here: 7 | # 8 | OUTPUT_DIR= 9 | EXP= 10 | RUN_ID= 11 | # 12 | ## 13 | 14 | 15 | LISTNAME=`basename $FILELIST .txt` 16 | 17 | # without CRF 18 | SAVE_DIR=$OUTPUT_DIR/$DATASET/$EXP/$RUN_ID/$LISTNAME 19 | nohup python eval_seg.py --data ./data --filelist $FILELIST --masks $SAVE_DIR > $SAVE_DIR.eval 2>&1 & 20 | 21 | # with CRF 22 | SAVE_DIR=$OUTPUT_DIR/$DATASET/$EXP/$RUN_ID/$LISTNAME/crf 23 | nohup python eval_seg.py --data ./data --filelist $FILELIST --masks $SAVE_DIR > $SAVE_DIR.eval 2>&1 & 24 | 25 | sleep 1 26 | 27 | echo "Log: ${SAVE_DIR}.eval" 28 | tail -f $SAVE_DIR.eval 29 | -------------------------------------------------------------------------------- /launch/infer_val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Set your argument here 5 | # 6 | CONFIG=configs/voc_resnet38.yaml 7 | DATASET=pascal_voc 8 | FILELIST=data/val_voc.txt 9 | 10 | ## You values here (see below how they're used) 11 | # 12 | OUTPUT_DIR= 13 | EXP= 14 | RUN_ID= 15 | SNAPSHOT= 16 | EXTRA_ARGS= 17 | SAVE_ID= 18 | # 19 | ## 20 | 21 | # limiting threads 22 | NUM_THREADS=6 23 | 24 | set OMP_NUM_THREADS=$NUM_THREADS 25 | export OMP_NUM_THREADS=$NUM_THREADS 26 | 27 | # 28 | # Code goes here 29 | # 30 | LISTNAME=`basename $FILELIST .txt` 31 | SAVE_DIR=$OUTPUT_DIR/$DATASET/$EXP/$SAVE_ID/$LISTNAME 32 | LOG_FILE=$OUTPUT_DIR/$DATASET/$EXP/$SAVE_ID/$LISTNAME.log 33 | 34 | CMD="python infer_val.py --dataset $DATASET \ 35 | --cfg $CONFIG \ 36 | --exp $EXP \ 37 | --run $RUN_ID \ 38 | --resume $SNAPSHOT \ 39 | --infer-list $FILELIST \ 40 | --workers $NUM_THREADS \ 41 | --mask-output-dir $SAVE_DIR \ 42 | $EXTRA_ARGS" 43 | 44 | if [ ! -d $SAVE_DIR ]; then 45 | echo "Creating directory: $SAVE_DIR" 46 | mkdir -p $SAVE_DIR 47 | else 48 | echo "Saving to: $SAVE_DIR" 49 | fi 50 | 51 | git rev-parse HEAD > ${SAVE_DIR}.head 52 | git diff > ${SAVE_DIR}.diff 53 | echo $CMD > ${SAVE_DIR}.cmd 54 | 55 | echo $CMD 56 | nohup $CMD > $LOG_FILE 2>&1 & 57 | 58 | sleep 1 59 | tail -f $LOG_FILE 60 | -------------------------------------------------------------------------------- /launch/run_bsl_resnet101.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Your values here: 4 | # 5 | DS=pascal_voc 6 | EXP= 7 | RUN_ID= 8 | # 9 | ## 10 | 11 | # 12 | # Script 13 | # 14 | 15 | LOG_DIR=logs/${DS}/${EXP} 16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet101.yaml --exp $EXP --run $RUN_ID --set NET.MODEL bsl TRAIN.NUM_EPOCHS 6" 17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log 18 | 19 | if [ ! -d "$LOG_DIR" ]; then 20 | echo "Creating directory $LOG_DIR" 21 | mkdir -p $LOG_DIR 22 | fi 23 | 24 | echo $CMD 25 | echo "LOG: $LOG_FILE" 26 | 27 | nohup $CMD > $LOG_FILE 2>&1 & 28 | sleep 1 29 | tail -f $LOG_FILE 30 | -------------------------------------------------------------------------------- /launch/run_bsl_resnet38.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Your values here: 4 | # 5 | DS=pascal_voc 6 | EXP= 7 | RUN_ID= 8 | # 9 | ## 10 | 11 | # 12 | # Script 13 | # 14 | 15 | LOG_DIR=logs/${DS}/${EXP} 16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet38.yaml --exp $EXP --run $RUN_ID --set NET.MODEL bsl TRAIN.NUM_EPOCHS 6" 17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log 18 | 19 | if [ ! -d "$LOG_DIR" ]; then 20 | echo "Creating directory $LOG_DIR" 21 | mkdir -p $LOG_DIR 22 | fi 23 | 24 | echo $CMD 25 | echo "LOG: $LOG_FILE" 26 | 27 | nohup $CMD > $LOG_FILE 2>&1 & 28 | sleep 1 29 | tail -f $LOG_FILE 30 | -------------------------------------------------------------------------------- /launch/run_bsl_resnet50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Your values here: 4 | # 5 | DS=pascal_voc 6 | EXP= 7 | RUN_ID= 8 | # 9 | ## 10 | 11 | # 12 | # Script 13 | # 14 | 15 | LOG_DIR=logs/${DS}/${EXP} 16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet50.yaml --exp $EXP --run $RUN_ID --set NET.MODEL bsl TRAIN.NUM_EPOCHS 6" 17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log 18 | 19 | if [ ! -d "$LOG_DIR" ]; then 20 | echo "Creating directory $LOG_DIR" 21 | mkdir -p $LOG_DIR 22 | fi 23 | 24 | echo $CMD 25 | echo "LOG: $LOG_FILE" 26 | 27 | nohup $CMD > $LOG_FILE 2>&1 & 28 | sleep 1 29 | tail -f $LOG_FILE 30 | -------------------------------------------------------------------------------- /launch/run_bsl_vgg16.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Your values here: 4 | # 5 | DS=pascal_voc 6 | EXP= 7 | RUN_ID= 8 | # 9 | ## 10 | 11 | # 12 | # Script 13 | # 14 | 15 | LOG_DIR=logs/${DS}/${EXP} 16 | CMD="python train.py --dataset $DS --cfg configs/voc_vgg16.yaml --exp $EXP --run $RUN_ID --set NET.MODEL bsl TRAIN.NUM_EPOCHS 6" 17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log 18 | 19 | if [ ! -d "$LOG_DIR" ]; then 20 | echo "Creating directory $LOG_DIR" 21 | mkdir -p $LOG_DIR 22 | fi 23 | 24 | echo $CMD 25 | echo "LOG: $LOG_FILE" 26 | 27 | nohup $CMD > $LOG_FILE 2>&1 & 28 | sleep 1 29 | tail -f $LOG_FILE 30 | -------------------------------------------------------------------------------- /launch/run_voc_resnet101.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Your values here: 4 | # 5 | DS=pascal_voc 6 | EXP= 7 | RUN_ID= 8 | # 9 | ## 10 | 11 | # 12 | # Script 13 | # 14 | 15 | LOG_DIR=logs/${DS}/${EXP} 16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet101.yaml --exp $EXP --run $RUN_ID" 17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log 18 | 19 | if [ ! -d "$LOG_DIR" ]; then 20 | echo "Creating directory $LOG_DIR" 21 | mkdir -p $LOG_DIR 22 | fi 23 | 24 | echo $CMD 25 | echo "LOG: $LOG_FILE" 26 | 27 | nohup $CMD > $LOG_FILE 2>&1 & 28 | sleep 1 29 | tail -f $LOG_FILE 30 | -------------------------------------------------------------------------------- /launch/run_voc_resnet38.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Your values here: 4 | # 5 | DS=pascal_voc 6 | EXP= 7 | RUN_ID= 8 | # 9 | ## 10 | 11 | # 12 | # Script 13 | # 14 | 15 | LOG_DIR=logs/${DS}/${EXP} 16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet38.yaml --exp $EXP --run $RUN_ID" 17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log 18 | 19 | if [ ! -d "$LOG_DIR" ]; then 20 | echo "Creating directory $LOG_DIR" 21 | mkdir -p $LOG_DIR 22 | fi 23 | 24 | echo $CMD 25 | echo "LOG: $LOG_FILE" 26 | 27 | nohup $CMD > $LOG_FILE 2>&1 & 28 | sleep 1 29 | tail -f $LOG_FILE 30 | -------------------------------------------------------------------------------- /launch/run_voc_resnet50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Your values here: 4 | # 5 | DS=pascal_voc 6 | EXP= 7 | RUN_ID= 8 | # 9 | ## 10 | 11 | # 12 | # Script 13 | # 14 | 15 | LOG_DIR=logs/${DS}/${EXP} 16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet50.yaml --exp $EXP --run $RUN_ID" 17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log 18 | 19 | if [ ! -d "$LOG_DIR" ]; then 20 | echo "Creating directory $LOG_DIR" 21 | mkdir -p $LOG_DIR 22 | fi 23 | 24 | echo $CMD 25 | echo "LOG: $LOG_FILE" 26 | 27 | nohup $CMD > $LOG_FILE 2>&1 & 28 | sleep 1 29 | tail -f $LOG_FILE 30 | -------------------------------------------------------------------------------- /launch/run_voc_vgg16.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Your values here: 4 | # 5 | DS=pascal_voc 6 | EXP= 7 | RUN_ID= 8 | # 9 | ## 10 | 11 | # 12 | # Script 13 | # 14 | 15 | LOG_DIR=logs/${DS}/${EXP} 16 | CMD="python train.py --dataset $DS --cfg configs/voc_vgg16.yaml --exp $EXP --run $RUN_ID" 17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log 18 | 19 | if [ ! -d "$LOG_DIR" ]; then 20 | echo "Creating directory $LOG_DIR" 21 | mkdir -p $LOG_DIR 22 | fi 23 | 24 | echo $CMD 25 | echo "LOG: $LOG_FILE" 26 | 27 | nohup $CMD > $LOG_FILE 2>&1 & 28 | sleep 1 29 | tail -f $LOG_FILE 30 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | class MLHingeLoss(nn.Module): 7 | 8 | def forward(self, x, y, reduction='mean'): 9 | """ 10 | y: labels have standard {0,1} form and will be converted to indices 11 | """ 12 | b, c = x.size() 13 | idx = (torch.arange(c) + 1).type_as(x) 14 | y_idx, _ = (idx * y).sort(-1, descending=True) 15 | y_idx = (y_idx - 1).long() 16 | 17 | return F.multilabel_margin_loss(x, y_idx, reduction=reduction) 18 | 19 | def get_criterion(loss_name, **kwargs): 20 | 21 | losses = { 22 | "SoftMargin": nn.MultiLabelSoftMarginLoss, 23 | "Hinge": MLHingeLoss 24 | } 25 | 26 | return losses[loss_name](**kwargs) 27 | 28 | 29 | # 30 | # Mask self-supervision 31 | # 32 | def mask_loss_ce(mask, pseudo_gt, ignore_index=255): 33 | mask = F.interpolate(mask, size=pseudo_gt.size()[-2:], mode="bilinear", align_corners=True) 34 | 35 | # indices of the max classes 36 | mask_gt = torch.argmax(pseudo_gt, 1) 37 | 38 | # for each pixel there should be at least one 1 39 | # otherwise, ignore 40 | weight = pseudo_gt.sum(1).type_as(mask_gt) 41 | mask_gt += (1 - weight) * ignore_index 42 | 43 | # BCE loss 44 | loss = F.cross_entropy(mask, mask_gt, ignore_index=ignore_index) 45 | return loss.mean() 46 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from .stage_net import network_factory 3 | 4 | def get_model(cfg, *args, **kwargs): 5 | net = partial(network_factory(cfg), config=cfg, pre_weights=cfg.PRE_WEIGHTS_PATH) 6 | return net(*args, **kwargs) 7 | -------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/models/backbones/__init__.py -------------------------------------------------------------------------------- /models/backbones/base_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Normalize(): 7 | def __init__(self, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)): 8 | 9 | self.mean = mean 10 | self.std = std 11 | 12 | def undo(self, imgarr): 13 | proc_img = imgarr.copy() 14 | 15 | proc_img[..., 0] = (self.std[0] * imgarr[..., 0] + self.mean[0]) * 255. 16 | proc_img[..., 1] = (self.std[1] * imgarr[..., 1] + self.mean[1]) * 255. 17 | proc_img[..., 2] = (self.std[2] * imgarr[..., 2] + self.mean[2]) * 255. 18 | 19 | return proc_img 20 | 21 | def __call__(self, img): 22 | imgarr = np.asarray(img) 23 | proc_img = np.empty_like(imgarr, np.float32) 24 | 25 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0] 26 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1] 27 | proc_img[..., 2] = (imgarr[..., 2] / 255. - self.mean[2]) / self.std[2] 28 | 29 | return proc_img 30 | 31 | class BaseNet(nn.Module): 32 | 33 | def __init__(self): 34 | super().__init__() 35 | self.normalize = Normalize() 36 | self.NormLayer = nn.BatchNorm2d 37 | 38 | self.not_training = [] # freezing parameters 39 | self.bn_frozen = [] # freezing running stats 40 | self.from_scratch_layers = [] # new layers -> higher LR 41 | 42 | def _init_weights(self, path_to_weights): 43 | print("Loading weights from: ", path_to_weights) 44 | weights_dict = torch.load(path_to_weights) 45 | self.load_state_dict(weights_dict, strict=False) 46 | 47 | def fan_out(self): 48 | raise NotImplementedError 49 | 50 | def fixed_layers(self): 51 | return self.not_training 52 | 53 | def _fix_running_stats(self, layer, fix_params=False): 54 | 55 | if isinstance(layer, self.NormLayer): 56 | self.bn_frozen.append(layer) 57 | if fix_params and not layer in self.not_training: 58 | self.not_training.append(layer) 59 | elif isinstance(layer, list): 60 | for m in layer: 61 | self._fix_running_stats(m, fix_params) 62 | else: 63 | for m in layer.children(): 64 | self._fix_running_stats(m, fix_params) 65 | 66 | def _fix_params(self, layer): 67 | 68 | if isinstance(layer, nn.Conv2d) or \ 69 | isinstance(layer, self.NormLayer) or \ 70 | isinstance(layer, nn.Linear): 71 | self.not_training.append(layer) 72 | if isinstance(layer, self.NormLayer): 73 | self.bn_frozen.append(layer) 74 | elif isinstance(layer, list): 75 | for m in layer: 76 | self._fix_params(m) 77 | elif isinstance(layer, nn.Module): 78 | if hasattr(layer, "weight") or hasattr(layer, "bias"): 79 | print("Ignoring fixed weight/bias layer: ", layer) 80 | 81 | for m in layer.children(): 82 | self._fix_params(m) 83 | 84 | def _freeze_bn(self, layer): 85 | 86 | if isinstance(layer, self.NormLayer): 87 | # freezing the layer 88 | layer.eval() 89 | elif isinstance(layer, nn.Module): 90 | for m in layer.children(): 91 | self._freeze_bn(m) 92 | 93 | def train(self, mode=True): 94 | 95 | super().train(mode) 96 | 97 | for layer in self.not_training: 98 | 99 | if hasattr(layer, "weight") and not layer.weight is None: 100 | layer.weight.requires_grad = False 101 | 102 | if hasattr(layer, "bias") and not layer.bias is None: 103 | layer.bias.requires_grad = False 104 | 105 | elif isinstance(layer, torch.nn.Module): 106 | print("Unkown layer to fix: ", layer) 107 | 108 | for bn_layer in self.bn_frozen: 109 | self._freeze_bn(bn_layer) 110 | 111 | def _lr_mult(self): 112 | return 1., 2., 10., 20 113 | 114 | def parameter_groups(self, base_lr, wd): 115 | 116 | w_old, b_old, w_new, b_new = self._lr_mult() 117 | 118 | groups = ({"params": [], "weight_decay": wd, "lr": w_old*base_lr}, # weight learning 119 | {"params": [], "weight_decay": 0.0, "lr": b_old*base_lr}, # bias finetuning 120 | {"params": [], "weight_decay": wd, "lr": w_new*base_lr}, # weight finetuning 121 | {"params": [], "weight_decay": 0.0, "lr": b_new*base_lr}) # bias learning 122 | 123 | fixed_layers = self.fixed_layers() 124 | 125 | for m in self.modules(): 126 | 127 | if m in fixed_layers: 128 | # skipping fixed layers 129 | continue 130 | 131 | if isinstance(m, nn.Conv2d) or \ 132 | isinstance(m, nn.Linear) or \ 133 | isinstance(m, self.NormLayer): 134 | 135 | if not m.weight is None: 136 | if m in self.from_scratch_layers: 137 | groups[2]["params"].append(m.weight) 138 | else: 139 | groups[0]["params"].append(m.weight) 140 | 141 | if not m.bias is None: 142 | if m in self.from_scratch_layers: 143 | groups[3]["params"].append(m.bias) 144 | else: 145 | groups[1]["params"].append(m.bias) 146 | 147 | elif hasattr(m, "weight"): 148 | print("! Skipping learnable: ", m) 149 | 150 | for i, g in enumerate(groups): 151 | print("Group {}: #{}, LR={:4.3e}".format(i, len(g["params"]), g["lr"])) 152 | 153 | return groups 154 | -------------------------------------------------------------------------------- /models/backbones/resnet38d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | import torch.nn.functional as F 6 | 7 | from models.backbones.base_net import BaseNet 8 | 9 | class ResBlock(nn.Module): 10 | def __init__(self, in_channels, mid_channels, out_channels, stride=1, first_dilation=None, dilation=1): 11 | super(ResBlock, self).__init__() 12 | 13 | self.same_shape = (in_channels == out_channels and stride == 1) 14 | 15 | if first_dilation == None: first_dilation = dilation 16 | 17 | self.bn_branch2a = nn.BatchNorm2d(in_channels) 18 | 19 | self.conv_branch2a = nn.Conv2d(in_channels, mid_channels, 3, stride, 20 | padding=first_dilation, dilation=first_dilation, bias=False) 21 | 22 | self.bn_branch2b1 = nn.BatchNorm2d(mid_channels) 23 | 24 | self.conv_branch2b1 = nn.Conv2d(mid_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False) 25 | 26 | if not self.same_shape: 27 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 28 | 29 | def forward(self, x, get_x_bn_relu=False): 30 | 31 | branch2 = self.bn_branch2a(x) 32 | branch2 = F.relu(branch2) 33 | 34 | x_bn_relu = branch2 35 | 36 | if not self.same_shape: 37 | branch1 = self.conv_branch1(branch2) 38 | else: 39 | branch1 = x 40 | 41 | branch2 = self.conv_branch2a(branch2) 42 | branch2 = self.bn_branch2b1(branch2) 43 | branch2 = F.relu(branch2) 44 | branch2 = self.conv_branch2b1(branch2) 45 | 46 | x = branch1 + branch2 47 | 48 | if get_x_bn_relu: 49 | return x, x_bn_relu 50 | 51 | return x 52 | 53 | def __call__(self, x, get_x_bn_relu=False): 54 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 55 | 56 | class ResBlock_bot(nn.Module): 57 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, dropout=0.): 58 | super(ResBlock_bot, self).__init__() 59 | 60 | self.same_shape = (in_channels == out_channels and stride == 1) 61 | 62 | self.bn_branch2a = nn.BatchNorm2d(in_channels) 63 | self.conv_branch2a = nn.Conv2d(in_channels, out_channels//4, 1, stride, bias=False) 64 | 65 | self.bn_branch2b1 = nn.BatchNorm2d(out_channels//4) 66 | self.dropout_2b1 = torch.nn.Dropout2d(dropout) 67 | self.conv_branch2b1 = nn.Conv2d(out_channels//4, out_channels//2, 3, padding=dilation, dilation=dilation, bias=False) 68 | 69 | self.bn_branch2b2 = nn.BatchNorm2d(out_channels//2) 70 | self.dropout_2b2 = torch.nn.Dropout2d(dropout) 71 | self.conv_branch2b2 = nn.Conv2d(out_channels//2, out_channels, 1, bias=False) 72 | 73 | if not self.same_shape: 74 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 75 | 76 | def forward(self, x, get_x_bn_relu=False): 77 | 78 | branch2 = self.bn_branch2a(x) 79 | branch2 = F.relu(branch2) 80 | x_bn_relu = branch2 81 | 82 | branch1 = self.conv_branch1(branch2) 83 | 84 | branch2 = self.conv_branch2a(branch2) 85 | 86 | branch2 = self.bn_branch2b1(branch2) 87 | branch2 = F.relu(branch2) 88 | branch2 = self.dropout_2b1(branch2) 89 | branch2 = self.conv_branch2b1(branch2) 90 | 91 | branch2 = self.bn_branch2b2(branch2) 92 | branch2 = F.relu(branch2) 93 | branch2 = self.dropout_2b2(branch2) 94 | branch2 = self.conv_branch2b2(branch2) 95 | 96 | x = branch1 + branch2 97 | 98 | if get_x_bn_relu: 99 | return x, x_bn_relu 100 | 101 | return x 102 | 103 | def __call__(self, x, get_x_bn_relu=False): 104 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 105 | 106 | class ResNet38(BaseNet): 107 | 108 | def __init__(self): 109 | super(ResNet38, self).__init__() 110 | 111 | self.conv1a = nn.Conv2d(3, 64, 3, padding=1, bias=False) 112 | 113 | self.b2 = ResBlock(64, 128, 128, stride=2) 114 | self.b2_1 = ResBlock(128, 128, 128) 115 | self.b2_2 = ResBlock(128, 128, 128) 116 | 117 | self.b3 = ResBlock(128, 256, 256, stride=2) 118 | self.b3_1 = ResBlock(256, 256, 256) 119 | self.b3_2 = ResBlock(256, 256, 256) 120 | 121 | self.b4 = ResBlock(256, 512, 512, stride=2) 122 | self.b4_1 = ResBlock(512, 512, 512) 123 | self.b4_2 = ResBlock(512, 512, 512) 124 | self.b4_3 = ResBlock(512, 512, 512) 125 | self.b4_4 = ResBlock(512, 512, 512) 126 | self.b4_5 = ResBlock(512, 512, 512) 127 | 128 | self.b5 = ResBlock(512, 512, 1024, stride=1, first_dilation=1, dilation=2) 129 | self.b5_1 = ResBlock(1024, 512, 1024, dilation=2) 130 | self.b5_2 = ResBlock(1024, 512, 1024, dilation=2) 131 | 132 | self.b6 = ResBlock_bot(1024, 2048, stride=1, dilation=4, dropout=0.3) 133 | 134 | self.b7 = ResBlock_bot(2048, 4096, dilation=4, dropout=0.5) 135 | 136 | self.bn7 = nn.BatchNorm2d(4096) 137 | 138 | # fixing the parameters 139 | self._fix_params([self.conv1a, self.b2, self.b2_1, self.b2_2]) 140 | 141 | def fan_out(self): 142 | return 4096 143 | 144 | def forward(self, x): 145 | return self.forward_as_dict(x)['conv6'] 146 | 147 | def forward_as_dict(self, x): 148 | 149 | x = self.conv1a(x) 150 | 151 | x = self.b2(x) 152 | x = self.b2_1(x) 153 | x = self.b2_2(x) 154 | 155 | x = self.b3(x) 156 | x = self.b3_1(x) 157 | x = self.b3_2(x) 158 | conv3 = x 159 | 160 | x = self.b4(x) 161 | x = self.b4_1(x) 162 | x = self.b4_2(x) 163 | x = self.b4_3(x) 164 | x = self.b4_4(x) 165 | x = self.b4_5(x) 166 | 167 | x, conv4 = self.b5(x, get_x_bn_relu=True) 168 | x = self.b5_1(x) 169 | x = self.b5_2(x) 170 | 171 | x, conv5 = self.b6(x, get_x_bn_relu=True) 172 | 173 | x = self.b7(x) 174 | conv6 = F.relu(self.bn7(x)) 175 | 176 | return dict({'conv3': conv3, 'conv6': conv6}) 177 | -------------------------------------------------------------------------------- /models/backbones/resnets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.backbones.base_net import BaseNet 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=dilation, groups=groups, bias=False, dilation=dilation) 11 | 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | """1x1 convolution""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | class Bottleneck(nn.Module): 18 | expansion = 4 19 | __constants__ = ['downsample'] 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 22 | base_width=64, dilation=1, norm_layer=None): 23 | super(Bottleneck, self).__init__() 24 | if norm_layer is None: 25 | norm_layer = nn.BatchNorm2d 26 | width = int(planes * (base_width / 64.)) * groups 27 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 28 | self.conv1 = conv1x1(inplanes, width) 29 | self.bn1 = norm_layer(width) 30 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 31 | self.bn2 = norm_layer(width) 32 | self.conv3 = conv1x1(width, planes * self.expansion) 33 | self.bn3 = norm_layer(planes * self.expansion) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | identity = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv3(out) 50 | out = self.bn3(out) 51 | 52 | if self.downsample is not None: 53 | identity = self.downsample(x) 54 | 55 | out += identity 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | class ResNet(BaseNet): 61 | 62 | def __init__(self, block, layers, zero_init_residual=False, 63 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 64 | norm_layer=None, deep_base=False): 65 | super(ResNet, self).__init__() 66 | if norm_layer is None: 67 | norm_layer = nn.BatchNorm2d 68 | self._norm_layer = norm_layer 69 | 70 | self.dilation = 1 71 | if replace_stride_with_dilation is None: 72 | # each element in the tuple indicates if we should replace 73 | # the 2x2 stride with a dilated convolution instead 74 | replace_stride_with_dilation = [False, False, False] 75 | if len(replace_stride_with_dilation) != 3: 76 | raise ValueError("replace_stride_with_dilation should be None " 77 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 78 | self.groups = groups 79 | self.base_width = width_per_group 80 | self.deep_base = deep_base 81 | if not self.deep_base: # see PSPNet implementation 82 | self.inplanes = 64 83 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 84 | bias=False) 85 | self.bn1 = norm_layer(self.inplanes) 86 | else: 87 | self.inplanes = 128 88 | self.conv1 = conv3x3(3, 64, stride=2) 89 | self.bn1 = norm_layer(64) 90 | self.conv2 = conv3x3(64, 64) 91 | self.bn2 = norm_layer(64) 92 | self.conv3 = conv3x3(64, 128) 93 | self.bn3 = norm_layer(128) 94 | 95 | self.relu = nn.ReLU(inplace=True) 96 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 97 | self.layer1 = self._make_layer(block, 64, layers[0]) 98 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 99 | dilate=replace_stride_with_dilation[0]) 100 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 101 | dilate=replace_stride_with_dilation[1]) 102 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 103 | dilate=replace_stride_with_dilation[2]) 104 | 105 | # note no global pooling of fully-connected layers 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 111 | nn.init.constant_(m.weight, 1) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | # Zero-initialize the last BN in each residual branch, 115 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 116 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 117 | if zero_init_residual: 118 | for m in self.modules(): 119 | if isinstance(m, Bottleneck): 120 | nn.init.constant_(m.bn3.weight, 0) 121 | elif isinstance(m, BasicBlock): 122 | nn.init.constant_(m.bn2.weight, 0) 123 | 124 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 125 | norm_layer = self._norm_layer 126 | downsample = None 127 | previous_dilation = self.dilation 128 | if dilate: 129 | self.dilation *= stride 130 | stride = 1 131 | if stride != 1 or self.inplanes != planes * block.expansion: 132 | downsample = nn.Sequential( 133 | conv1x1(self.inplanes, planes * block.expansion, stride), 134 | norm_layer(planes * block.expansion), 135 | ) 136 | 137 | layers = [] 138 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 139 | self.base_width, previous_dilation, norm_layer)) 140 | self.inplanes = planes * block.expansion 141 | for _ in range(1, blocks): 142 | layers.append(block(self.inplanes, planes, groups=self.groups, 143 | base_width=self.base_width, dilation=self.dilation, 144 | norm_layer=norm_layer)) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def _forward_impl(self, x): 149 | x = self.conv1(x) 150 | x = self.bn1(x) 151 | x = self.relu(x) 152 | 153 | if self.deep_base: 154 | x = self.relu(self.bn2(self.conv2(x))) 155 | x = self.relu(self.bn3(self.conv3(x))) 156 | 157 | x = self.maxpool(x) 158 | 159 | x = self.layer1(x) 160 | x = self.layer2(x) 161 | x = self.layer3(x) 162 | x = self.layer4(x) 163 | 164 | return x 165 | 166 | def forward_as_dict(self, x): 167 | # See note [TorchScript super()] 168 | x = self.conv1(x) 169 | x = self.bn1(x) 170 | x = self.relu(x) 171 | 172 | if self.deep_base: 173 | x = self.relu(self.bn2(self.conv2(x))) 174 | x = self.relu(self.bn3(self.conv3(x))) 175 | 176 | x = self.maxpool(x) 177 | 178 | x = self.layer1(x) 179 | conv3 = x 180 | 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | x = self.layer4(x) 184 | 185 | return {"conv6": x, "conv3": conv3} 186 | 187 | def forward(self, x): 188 | return self.forward_as_dict(x)["conv6"] 189 | 190 | def _lr_mult(self): 191 | return 1., 1., 10., 10. 192 | 193 | class ResNet50(ResNet): 194 | 195 | def __init__(self): 196 | super(ResNet50, self).__init__(Bottleneck, [3, 4, 6, 3], \ 197 | replace_stride_with_dilation=[False, False, False]) 198 | 199 | # fixing the parameters 200 | self._fix_params([self.conv1, self.bn1]) 201 | 202 | assert not self.deep_base 203 | 204 | def fan_out(self): 205 | return 2048 206 | 207 | class ResNet101(ResNet): 208 | 209 | def __init__(self): 210 | super(ResNet101, self).__init__(Bottleneck, [3, 4, 23, 3], \ 211 | replace_stride_with_dilation=[False, False, False]) 212 | 213 | # fixing the parameters 214 | self._fix_params([self.conv1, self.bn1]) 215 | 216 | assert not self.deep_base 217 | 218 | def fan_out(self): 219 | return 2048 220 | -------------------------------------------------------------------------------- /models/backbones/vgg16d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from models.backbones.base_net import BaseNet 8 | 9 | class VGG16(BaseNet): 10 | def __init__(self, fc6_dilation = 1): 11 | super(VGG16, self).__init__() 12 | 13 | self.conv1_1 = nn.Conv2d(3,64,3,padding = 1) 14 | self.conv1_2 = nn.Conv2d(64,64,3,padding = 1) 15 | self.pool1 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1) 16 | self.conv2_1 = nn.Conv2d(64,128,3,padding = 1) 17 | self.conv2_2 = nn.Conv2d(128,128,3,padding = 1) 18 | self.pool2 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1) 19 | self.conv3_1 = nn.Conv2d(128,256,3,padding = 1) 20 | self.conv3_2 = nn.Conv2d(256,256,3,padding = 1) 21 | self.conv3_3 = nn.Conv2d(256,256,3,padding = 1) 22 | self.pool3 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1) 23 | self.conv4_1 = nn.Conv2d(256,512,3,padding = 1) 24 | self.conv4_2 = nn.Conv2d(512,512,3,padding = 1) 25 | self.conv4_3 = nn.Conv2d(512,512,3,padding = 1) 26 | self.pool4 = nn.MaxPool2d(kernel_size = 3, stride = 1, padding=1) 27 | self.conv5_1 = nn.Conv2d(512,512,3,padding = 2, dilation = 2) 28 | self.conv5_2 = nn.Conv2d(512,512,3,padding = 2, dilation = 2) 29 | self.conv5_3 = nn.Conv2d(512,512,3,padding = 2, dilation = 2) 30 | 31 | self.fc6 = nn.Conv2d(512,1024, 3, padding = fc6_dilation, dilation = fc6_dilation) 32 | 33 | self.drop6 = nn.Dropout2d(p=0.5) 34 | self.fc7 = nn.Conv2d(1024, 1024, 1) 35 | 36 | # fixing the parameters 37 | self._fix_params([self.conv1_1, self.conv1_2]) 38 | 39 | def fan_out(self): 40 | return 1024 41 | 42 | def forward(self, x): 43 | return self.forward_as_dict(x)['conv6'] 44 | 45 | def forward_as_dict(self, x): 46 | 47 | x = F.relu(self.conv1_1(x), inplace=True) 48 | x = F.relu(self.conv1_2(x), inplace=True) 49 | x = self.pool1(x) 50 | 51 | x = F.relu(self.conv2_1(x), inplace=True) 52 | x = F.relu(self.conv2_2(x), inplace=True) 53 | x = self.pool2(x) 54 | 55 | x = F.relu(self.conv3_1(x), inplace=True) 56 | x = F.relu(self.conv3_2(x), inplace=True) 57 | x = F.relu(self.conv3_3(x), inplace=True) 58 | conv3 = x 59 | 60 | x = self.pool3(x) 61 | 62 | x = F.relu(self.conv4_1(x), inplace=True) 63 | x = F.relu(self.conv4_2(x), inplace=True) 64 | x = F.relu(self.conv4_3(x), inplace=True) 65 | 66 | x = self.pool4(x) 67 | 68 | x = F.relu(self.conv5_1(x), inplace=True) 69 | x = F.relu(self.conv5_2(x), inplace=True) 70 | x = F.relu(self.conv5_3(x), inplace=True) 71 | 72 | x = F.relu(self.fc6(x), inplace=True) 73 | x = self.drop6(x) 74 | x = F.relu(self.fc7(x), inplace=True) 75 | 76 | conv6 = x 77 | 78 | return dict({'conv3': conv3, 'conv6': conv6}) 79 | -------------------------------------------------------------------------------- /models/mods/__init__.py: -------------------------------------------------------------------------------- 1 | from .sg import StochasticGate 2 | from .pamr import PAMR 3 | from .aspp import ASPP 4 | from .gci import GCI 5 | -------------------------------------------------------------------------------- /models/mods/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class _ASPPModule(nn.Module): 7 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 8 | super(_ASPPModule, self).__init__() 9 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 10 | stride=1, padding=padding, dilation=dilation, bias=False) 11 | self.bn = BatchNorm(planes) 12 | self.relu = nn.ReLU() 13 | 14 | def forward(self, x): 15 | x = self.atrous_conv(x) 16 | x = self.bn(x) 17 | 18 | return self.relu(x) 19 | 20 | class ASPP(nn.Module): 21 | def __init__(self, inplanes, output_stride, BatchNorm): 22 | super(ASPP, self).__init__() 23 | 24 | if output_stride == 16: 25 | dilations = [1, 6, 12, 18] 26 | elif output_stride == 8: 27 | dilations = [1, 12, 24, 36] 28 | else: 29 | raise NotImplementedError 30 | 31 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 32 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 33 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 34 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 35 | 36 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 37 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 38 | BatchNorm(256), 39 | nn.ReLU()) 40 | 41 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 42 | self.bn1 = BatchNorm(256) 43 | self.relu = nn.ReLU() 44 | self.dropout = nn.Dropout(0.5) 45 | self._init_weight() 46 | 47 | def forward(self, x): 48 | x1 = self.aspp1(x) 49 | x2 = self.aspp2(x) 50 | x3 = self.aspp3(x) 51 | x4 = self.aspp4(x) 52 | x5 = self.global_avg_pool(x) 53 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 54 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 55 | 56 | x = self.conv1(x) 57 | x = self.bn1(x) 58 | x = self.relu(x) 59 | 60 | return self.dropout(x) 61 | 62 | def _init_weight(self): 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv2d): 65 | torch.nn.init.kaiming_normal_(m.weight) 66 | elif isinstance(m, nn.BatchNorm2d): 67 | if not m.weight is None: 68 | m.weight.data.fill_(1) 69 | else: 70 | print("ASPP has not weight: ", m) 71 | 72 | if not m.bias is None: 73 | m.bias.data.zero_() 74 | else: 75 | print("ASPP has not bias: ", m) 76 | 77 | 78 | def build_aspp(backbone, output_stride, BatchNorm): 79 | return ASPP(backbone, output_stride, BatchNorm) 80 | -------------------------------------------------------------------------------- /models/mods/gci.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GCI(nn.Module): 7 | """Global Cue Injection 8 | Takes shallow features with low receptive 9 | field and augments it with global info via 10 | adaptive instance normalisation""" 11 | 12 | def __init__(self, NormLayer=nn.BatchNorm2d): 13 | super(GCI, self).__init__() 14 | 15 | self.NormLayer = NormLayer 16 | self.from_scratch_layers = [] 17 | 18 | self._init_params() 19 | 20 | def _conv2d(self, *args, **kwargs): 21 | conv = nn.Conv2d(*args, **kwargs) 22 | self.from_scratch_layers.append(conv) 23 | torch.nn.init.kaiming_normal_(conv.weight) 24 | return conv 25 | 26 | def _bnorm(self, *args, **kwargs): 27 | bn = self.NormLayer(*args, **kwargs) 28 | #self.bn_learn.append(bn) 29 | self.from_scratch_layers.append(bn) 30 | if not bn.weight is None: 31 | bn.weight.data.fill_(1) 32 | bn.bias.data.zero_() 33 | return bn 34 | 35 | def _init_params(self): 36 | 37 | self.fc_deep = nn.Sequential(self._conv2d(256, 512, 1, bias=False), \ 38 | self._bnorm(512), nn.ReLU()) 39 | 40 | self.fc_skip = nn.Sequential(self._conv2d(256, 256, 1, bias=False), \ 41 | self._bnorm(256, affine=False)) 42 | 43 | self.fc_cls = nn.Sequential(self._conv2d(256, 256, 1, bias=False), \ 44 | self._bnorm(256), nn.ReLU()) 45 | 46 | def forward(self, x, y): 47 | """Forward pass 48 | 49 | Args: 50 | x: shalow features 51 | y: deep features 52 | """ 53 | 54 | # extract global attributes 55 | y = self.fc_deep(y) 56 | attrs, _ = y.view(y.size(0), y.size(1), -1).max(-1) 57 | 58 | # pre-process shallow features 59 | x = self.fc_skip(x) 60 | x = F.relu(self._adin_conv(x, attrs)) 61 | 62 | return self.fc_cls(x) 63 | 64 | def _adin_conv(self, x, y): 65 | 66 | bs, num_c, _, _ = x.size() 67 | assert 2*num_c == y.size(1), "AdIN: dimension mismatch" 68 | 69 | y = y.view(bs, 2, num_c) 70 | gamma, beta = y[:, 0], y[:, 1] 71 | 72 | gamma = gamma.unsqueeze(-1).unsqueeze(-1) 73 | beta = beta.unsqueeze(-1).unsqueeze(-1) 74 | 75 | return x * (gamma + 1) + beta 76 | -------------------------------------------------------------------------------- /models/mods/pamr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | from functools import partial 6 | 7 | # 8 | # Helper modules 9 | # 10 | class LocalAffinity(nn.Module): 11 | 12 | def __init__(self, dilations=[1]): 13 | super(LocalAffinity, self).__init__() 14 | self.dilations = dilations 15 | weight = self._init_aff() 16 | self.register_buffer('kernel', weight) 17 | 18 | def _init_aff(self): 19 | # initialising the shift kernel 20 | weight = torch.zeros(8, 1, 3, 3) 21 | 22 | for i in range(weight.size(0)): 23 | weight[i, 0, 1, 1] = 1 24 | 25 | weight[0, 0, 0, 0] = -1 26 | weight[1, 0, 0, 1] = -1 27 | weight[2, 0, 0, 2] = -1 28 | 29 | weight[3, 0, 1, 0] = -1 30 | weight[4, 0, 1, 2] = -1 31 | 32 | weight[5, 0, 2, 0] = -1 33 | weight[6, 0, 2, 1] = -1 34 | weight[7, 0, 2, 2] = -1 35 | 36 | self.weight_check = weight.clone() 37 | 38 | return weight 39 | 40 | def forward(self, x): 41 | 42 | self.weight_check = self.weight_check.type_as(x) 43 | assert torch.all(self.weight_check.eq(self.kernel)) 44 | 45 | B,K,H,W = x.size() 46 | x = x.view(B*K,1,H,W) 47 | 48 | x_affs = [] 49 | for d in self.dilations: 50 | x_pad = F.pad(x, [d]*4, mode='replicate') 51 | x_aff = F.conv2d(x_pad, self.kernel, dilation=d) 52 | x_affs.append(x_aff) 53 | 54 | x_aff = torch.cat(x_affs, 1) 55 | return x_aff.view(B,K,-1,H,W) 56 | 57 | class LocalAffinityCopy(LocalAffinity): 58 | 59 | def _init_aff(self): 60 | # initialising the shift kernel 61 | weight = torch.zeros(8, 1, 3, 3) 62 | 63 | weight[0, 0, 0, 0] = 1 64 | weight[1, 0, 0, 1] = 1 65 | weight[2, 0, 0, 2] = 1 66 | 67 | weight[3, 0, 1, 0] = 1 68 | weight[4, 0, 1, 2] = 1 69 | 70 | weight[5, 0, 2, 0] = 1 71 | weight[6, 0, 2, 1] = 1 72 | weight[7, 0, 2, 2] = 1 73 | 74 | self.weight_check = weight.clone() 75 | return weight 76 | 77 | class LocalStDev(LocalAffinity): 78 | 79 | def _init_aff(self): 80 | weight = torch.zeros(9, 1, 3, 3) 81 | weight.zero_() 82 | 83 | weight[0, 0, 0, 0] = 1 84 | weight[1, 0, 0, 1] = 1 85 | weight[2, 0, 0, 2] = 1 86 | 87 | weight[3, 0, 1, 0] = 1 88 | weight[4, 0, 1, 1] = 1 89 | weight[5, 0, 1, 2] = 1 90 | 91 | weight[6, 0, 2, 0] = 1 92 | weight[7, 0, 2, 1] = 1 93 | weight[8, 0, 2, 2] = 1 94 | 95 | self.weight_check = weight.clone() 96 | return weight 97 | 98 | def forward(self, x): 99 | # returns (B,K,P,H,W), where P is the number 100 | # of locations 101 | x = super(LocalStDev, self).forward(x) 102 | 103 | return x.std(2, keepdim=True) 104 | 105 | class LocalAffinityAbs(LocalAffinity): 106 | 107 | def forward(self, x): 108 | x = super(LocalAffinityAbs, self).forward(x) 109 | return torch.abs(x) 110 | 111 | # 112 | # PAMR module 113 | # 114 | class PAMR(nn.Module): 115 | 116 | def __init__(self, num_iter=1, dilations=[1]): 117 | super(PAMR, self).__init__() 118 | 119 | self.num_iter = num_iter 120 | self.aff_x = LocalAffinityAbs(dilations) 121 | self.aff_m = LocalAffinityCopy(dilations) 122 | self.aff_std = LocalStDev(dilations) 123 | 124 | def forward(self, x, mask): 125 | mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True) 126 | 127 | # x: [BxKxHxW] 128 | # mask: [BxCxHxW] 129 | B,K,H,W = x.size() 130 | _,C,_,_ = mask.size() 131 | 132 | x_std = self.aff_std(x) 133 | 134 | x = -self.aff_x(x) / (1e-8 + 0.1 * x_std) 135 | x = x.mean(1, keepdim=True) 136 | x = F.softmax(x, 2) 137 | 138 | for _ in range(self.num_iter): 139 | m = self.aff_m(mask) # [BxCxPxHxW] 140 | mask = (m * x).sum(2) 141 | 142 | # xvals: [BxCxHxW] 143 | return mask 144 | -------------------------------------------------------------------------------- /models/mods/sg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class StochasticGate(nn.Module): 7 | """Stochastically merges features from two levels 8 | with varying size of the receptive field 9 | """ 10 | 11 | def __init__(self): 12 | super(StochasticGate, self).__init__() 13 | self._mask_drop = None 14 | 15 | def forward(self, x1, x2, alpha_rate=0.3): 16 | """Stochastic Gate (SG) 17 | 18 | SG stochastically mixes deep and shallow features 19 | at training time and deterministically combines 20 | them at test time with a hyperparam. alpha 21 | """ 22 | 23 | if self.training: # training time 24 | # dropout: selecting either x1 or x2 25 | if self._mask_drop is None: 26 | bs, c, h, w = x1.size() 27 | assert c == x2.size(1), "Number of features is different" 28 | self._mask_drop = torch.ones_like(x1) 29 | 30 | # a mask of {0,1} 31 | mask_drop = (1 - alpha_rate) * F.dropout(self._mask_drop, alpha_rate) 32 | 33 | # shift and scale deep features 34 | # at train time: E[x] = x1 35 | x1 = (x1 - alpha_rate * x2) / max(1e-8, 1 - alpha_rate) 36 | 37 | # combine the features 38 | x = mask_drop * x1 + (1 - mask_drop) * x2 39 | else: 40 | # inference time: deterministic 41 | x = (1 - alpha_rate) * x1 + alpha_rate * x2 42 | 43 | return x 44 | 45 | -------------------------------------------------------------------------------- /models/stage_net.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # backbone nets 7 | from models.backbones.resnet38d import ResNet38 8 | from models.backbones.vgg16d import VGG16 9 | from models.backbones.resnets import ResNet101, ResNet50 10 | 11 | # modules 12 | from models.mods import ASPP 13 | from models.mods import PAMR 14 | from models.mods import StochasticGate 15 | from models.mods import GCI 16 | 17 | # 18 | # Helper classes 19 | # 20 | def rescale_as(x, y, mode="bilinear", align_corners=True): 21 | h, w = y.size()[2:] 22 | x = F.interpolate(x, size=[h, w], mode=mode, align_corners=align_corners) 23 | return x 24 | 25 | def focal_loss(x, p = 1, c = 0.1): 26 | return torch.pow(1 - x, p) * torch.log(c + x) 27 | 28 | def pseudo_gtmask(mask, cutoff_top=0.6, cutoff_low=0.2, eps=1e-8): 29 | """Convert continuous mask into binary mask""" 30 | bs,c,h,w = mask.size() 31 | mask = mask.view(bs,c,-1) 32 | 33 | # for each class extract the max confidence 34 | mask_max, _ = mask.max(-1, keepdim=True) 35 | mask_max[:, :1] *= 0.7 36 | mask_max[:, 1:] *= cutoff_top 37 | #mask_max *= cutoff_top 38 | 39 | # if the top score is too low, ignore it 40 | lowest = torch.Tensor([cutoff_low]).type_as(mask_max) 41 | mask_max = mask_max.max(lowest) 42 | 43 | pseudo_gt = (mask > mask_max).type_as(mask) 44 | 45 | # remove ambiguous pixels 46 | ambiguous = (pseudo_gt.sum(1, keepdim=True) > 1).type_as(mask) 47 | pseudo_gt = (1 - ambiguous) * pseudo_gt 48 | 49 | return pseudo_gt.view(bs,c,h,w) 50 | 51 | def balanced_mask_loss_ce(mask, pseudo_gt, gt_labels, ignore_index=255): 52 | """Class-balanced CE loss 53 | - cancel loss if only one class in pseudo_gt 54 | - weight loss equally between classes 55 | """ 56 | 57 | mask = F.interpolate(mask, size=pseudo_gt.size()[-2:], mode="bilinear", align_corners=True) 58 | 59 | # indices of the max classes 60 | mask_gt = torch.argmax(pseudo_gt, 1) 61 | 62 | # for each pixel there should be at least one 1 63 | # otherwise, ignore 64 | ignore_mask = pseudo_gt.sum(1) < 1. 65 | mask_gt[ignore_mask] = ignore_index 66 | 67 | # class weight balances the loss w.r.t. number of pixels 68 | # because we are equally interested in all classes 69 | bs,c,h,w = pseudo_gt.size() 70 | num_pixels_per_class = pseudo_gt.view(bs,c,-1).sum(-1) 71 | num_pixels_total = num_pixels_per_class.sum(-1, keepdim=True) 72 | class_weight = (num_pixels_total - num_pixels_per_class) / (1 + num_pixels_total) 73 | class_weight = (pseudo_gt * class_weight[:,:,None,None]).sum(1).view(bs, -1) 74 | 75 | # BCE loss 76 | loss = F.cross_entropy(mask, mask_gt, ignore_index=ignore_index, reduction="none") 77 | loss = loss.view(bs, -1) 78 | 79 | # we will have the loss only for batch indices 80 | # which have all classes in pseudo mask 81 | gt_num_labels = gt_labels.sum(-1).type_as(loss) + 1 # + BG 82 | ps_num_labels = (num_pixels_per_class > 0).type_as(loss).sum(-1) 83 | batch_weight = (gt_num_labels == ps_num_labels).type_as(loss) 84 | 85 | loss = batch_weight * (class_weight * loss).mean(-1) 86 | return loss 87 | 88 | class Flatten(nn.Module): 89 | def forward(self, input): 90 | return input.view(input.size(0), -1) 91 | 92 | # 93 | # Dynamic change of the base class 94 | # 95 | def network_factory(cfg): 96 | 97 | if cfg.BACKBONE == "resnet38": 98 | print("Backbone: ResNet38") 99 | backbone = ResNet38 100 | elif cfg.BACKBONE == "vgg16": 101 | print("Backbone: VGG16") 102 | backbone = VGG16 103 | elif cfg.BACKBONE == "resnet50": 104 | print("Backbone: ResNet50") 105 | backbone = ResNet50 106 | elif cfg.BACKBONE == "resnet101": 107 | print("Backbone: ResNet101") 108 | backbone = ResNet101 109 | else: 110 | raise NotImplementedError("No backbone found for '{}'".format(cfg.BACKBONE)) 111 | 112 | # 113 | # Class definitions 114 | # 115 | class BaselineCAM(backbone): 116 | 117 | def __init__(self, config, pre_weights=None, num_classes=21, dropout=True): 118 | super().__init__() 119 | 120 | self.cfg = config 121 | 122 | self.fc8 = nn.Conv2d(self.fan_out(), num_classes - 1, 1, bias=False) 123 | nn.init.xavier_uniform_(self.fc8.weight) 124 | 125 | cls_modules = [nn.AdaptiveAvgPool2d((1, 1)), self.fc8, Flatten()] 126 | if dropout: 127 | cls_modules.insert(0, nn.Dropout2d(0.5)) 128 | 129 | self.cls_branch = nn.Sequential(*cls_modules) 130 | self.mask_branch = nn.Sequential(self.fc8, nn.ReLU()) 131 | 132 | self.from_scratch_layers = [self.fc8] 133 | self._init_weights(pre_weights) 134 | self._mask_logits = None 135 | 136 | self._fix_running_stats(self, fix_params=True) # freeze backbone BNs 137 | 138 | def forward_backbone(self, x): 139 | self._mask_logits = super().forward(x) 140 | return self._mask_logits 141 | 142 | def forward_cls(self, x): 143 | return self.cls_branch(x) 144 | 145 | def forward_mask(self, x, size): 146 | logits = self.fc8(x) 147 | masks = F.interpolate(logits, size=size, mode='bilinear', align_corners=True) 148 | masks = F.relu(masks) 149 | 150 | # CAMs are unbounded 151 | # so let's normalised it first 152 | # (see jiwoon-ahn/psa) 153 | b,c,h,w = masks.size() 154 | masks_ = masks.view(b,c,-1) 155 | z, _ = masks_.max(-1, keepdim=True) 156 | masks_ /= (1e-5 + z) 157 | masks = masks.view(b,c,h,w) 158 | 159 | bg = torch.ones_like(masks[:, :1]) 160 | masks = torch.cat([self.cfg.BG_SCORE * bg, masks], 1) 161 | 162 | # note, that the masks contain the background as the first channel 163 | return logits, masks 164 | 165 | def forward(self, y, _, labels=None): 166 | test_mode = labels is None 167 | 168 | x = self.forward_backbone(y) 169 | 170 | cls = self.forward_cls(x) 171 | logits, masks = self.forward_mask(x, y.size()[-2:]) 172 | 173 | if test_mode: 174 | return cls, masks 175 | 176 | # foreground stats 177 | b,c,h,w = masks.size() 178 | masks_ = masks.view(b,c,-1) 179 | masks_ = masks_[:, 1:] 180 | cls_fg = (masks_.mean(-1) * labels).sum(-1) / labels.sum(-1) 181 | 182 | # upscale the masks & clean 183 | masks = self._rescale_and_clean(masks, y, labels) 184 | 185 | return cls, cls_fg, {"cam": masks}, logits, None, None 186 | 187 | def _rescale_and_clean(self, masks, image, labels): 188 | masks = F.interpolate(masks, size=image.size()[-2:], mode='bilinear', align_corners=True) 189 | masks[:, 1:] *= labels[:, :, None, None].type_as(masks) 190 | return masks 191 | 192 | # 193 | # Softmax unit 194 | # 195 | class SoftMaxAE(backbone): 196 | 197 | def __init__(self, config, pre_weights=None, num_classes=21, dropout=True): 198 | super().__init__() 199 | 200 | self.cfg = config 201 | self.num_classes = num_classes 202 | 203 | self._init_weights(pre_weights) # initialise backbone weights 204 | self._fix_running_stats(self, fix_params=True) # freeze backbone BNs 205 | 206 | # Decoder 207 | self._init_aspp() 208 | self._init_decoder(num_classes) 209 | 210 | self._backbone = None 211 | self._mask_logits = None 212 | 213 | def _init_aspp(self): 214 | self.aspp = ASPP(self.fan_out(), 8, self.NormLayer) 215 | 216 | for m in self.aspp.modules(): 217 | if isinstance(m, nn.Conv2d) or isinstance(m, self.NormLayer): 218 | self.from_scratch_layers.append(m) 219 | 220 | self._fix_running_stats(self.aspp) # freeze BN 221 | 222 | def _init_decoder(self, num_classes): 223 | 224 | self._aff = PAMR(self.cfg.PAMR_ITER, self.cfg.PAMR_KERNEL) 225 | 226 | def conv2d(*args, **kwargs): 227 | conv = nn.Conv2d(*args, **kwargs) 228 | self.from_scratch_layers.append(conv) 229 | torch.nn.init.kaiming_normal_(conv.weight) 230 | return conv 231 | 232 | def bnorm(*args, **kwargs): 233 | bn = self.NormLayer(*args, **kwargs) 234 | self.from_scratch_layers.append(bn) 235 | if not bn.weight is None: 236 | bn.weight.data.fill_(1) 237 | bn.bias.data.zero_() 238 | return bn 239 | 240 | # pre-processing for shallow features 241 | self.shallow_mask = GCI(self.NormLayer) 242 | self.from_scratch_layers += self.shallow_mask.from_scratch_layers 243 | 244 | # Stochastic Gate 245 | self.sg = StochasticGate() 246 | self.fc8_skip = nn.Sequential(conv2d(256, 48, 1, bias=False), bnorm(48), nn.ReLU()) 247 | self.fc8_x = nn.Sequential(conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 248 | bnorm(256), nn.ReLU()) 249 | 250 | # decoder 251 | self.last_conv = nn.Sequential(conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 252 | bnorm(256), nn.ReLU(), 253 | nn.Dropout(0.5), 254 | conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 255 | bnorm(256), nn.ReLU(), 256 | nn.Dropout(0.1), 257 | conv2d(256, num_classes - 1, kernel_size=1, stride=1)) 258 | 259 | def run_pamr(self, im, mask): 260 | im = F.interpolate(im, mask.size()[-2:], mode="bilinear", align_corners=True) 261 | masks_dec = self._aff(im, mask) 262 | return masks_dec 263 | 264 | def forward_backbone(self, x): 265 | self._backbone = super().forward_as_dict(x) 266 | return self._backbone['conv6'] 267 | 268 | def forward(self, y, y_raw=None, labels=None): 269 | test_mode = y_raw is None and labels is None 270 | 271 | # 1. backbone pass 272 | x = self.forward_backbone(y) 273 | 274 | # 2. ASPP modules 275 | x = self.aspp(x) 276 | 277 | # 278 | # 3. merging deep and shallow features 279 | # 280 | 281 | # 3.1 skip connection for deep features 282 | x2_x = self.fc8_skip(self._backbone['conv3']) 283 | x_up = rescale_as(x, x2_x) 284 | x = self.fc8_x(torch.cat([x_up, x2_x], 1)) 285 | 286 | # 3.2 deep feature context for shallow features 287 | x2 = self.shallow_mask(self._backbone['conv3'], x) 288 | 289 | # 3.3 stochastically merging the masks 290 | x = self.sg(x, x2, alpha_rate=self.cfg.SG_PSI) 291 | 292 | # 4. final convs to get the masks 293 | x = self.last_conv(x) 294 | 295 | # 296 | # 5. Finalising the masks and scores 297 | # 298 | 299 | # constant BG scores 300 | bg = torch.ones_like(x[:, :1]) 301 | x = torch.cat([bg, x], 1) 302 | 303 | bs, c, h, w = x.size() 304 | 305 | masks = F.softmax(x, dim=1) 306 | 307 | # reshaping 308 | features = x.view(bs, c, -1) 309 | masks_ = masks.view(bs, c, -1) 310 | 311 | # classification loss 312 | cls_1 = (features * masks_).sum(-1) / (1.0 + masks_.sum(-1)) 313 | 314 | # focal penalty loss 315 | cls_2 = focal_loss(masks_.mean(-1), \ 316 | p=self.cfg.FOCAL_P, \ 317 | c=self.cfg.FOCAL_LAMBDA) 318 | 319 | # adding the losses together 320 | cls = cls_1[:, 1:] + cls_2[:, 1:] 321 | 322 | if test_mode: 323 | # if in test mode, not mask 324 | # cleaning is performed 325 | return cls, rescale_as(masks, y) 326 | 327 | self._mask_logits = x 328 | 329 | # foreground stats 330 | masks_ = masks_[:, 1:] 331 | cls_fg = (masks_.mean(-1) * labels).sum(-1) / labels.sum(-1) 332 | 333 | # mask refinement with PAMR 334 | masks_dec = self.run_pamr(y_raw, masks.detach()) 335 | 336 | # upscale the masks & clean 337 | masks = self._rescale_and_clean(masks, y, labels) 338 | masks_dec = self._rescale_and_clean(masks_dec, y, labels) 339 | 340 | # create pseudo GT 341 | pseudo_gt = pseudo_gtmask(masks_dec).detach() 342 | loss_mask = balanced_mask_loss_ce(self._mask_logits, pseudo_gt, labels) 343 | 344 | return cls, cls_fg, {"cam": masks, "dec": masks_dec}, self._mask_logits, pseudo_gt, loss_mask 345 | 346 | def _rescale_and_clean(self, masks, image, labels): 347 | """Rescale to fit the image size and remove any masks 348 | of labels that are not present""" 349 | masks = F.interpolate(masks, size=image.size()[-2:], mode='bilinear', align_corners=True) 350 | masks[:, 1:] *= labels[:, :, None, None].type_as(masks) 351 | return masks 352 | 353 | 354 | if cfg.MODEL == 'ae': 355 | print("Model: AE") 356 | return SoftMaxAE 357 | elif cfg.MODEL == 'bsl': 358 | print("Model: Baseline") 359 | return BaselineCAM 360 | else: 361 | raise NotImplementedError("Unknown model '{}'".format(cfg.MODEL)) 362 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import torch 5 | import argparse 6 | from core.config import cfg 7 | 8 | def add_global_arguments(parser): 9 | 10 | parser.add_argument("--dataset", type=str, 11 | help="Determines dataloader to use (only Pascal VOC supported)") 12 | parser.add_argument("--exp", type=str, default="main", 13 | help="ID of the experiment (multiple runs)") 14 | parser.add_argument("--resume", type=str, default=None, 15 | help="Snapshot \"ID,iter\" to load") 16 | parser.add_argument("--run", type=str, help="ID of the run") 17 | parser.add_argument('--workers', type=int, default=8, 18 | metavar='N', help='dataloader threads') 19 | parser.add_argument("--snapshot-dir", type=str, default='./snapshots', 20 | help="Where to save snapshots of the model.") 21 | parser.add_argument("--logdir", type=str, default='./logs', 22 | help="Where to save log files of the model.") 23 | 24 | # used at inference only 25 | parser.add_argument("--infer-list", type=str, default='data/val_augvoc.txt', 26 | help="Path to a file list") 27 | parser.add_argument("--mask-output-dir", type=str, default='results/', 28 | help="Path where to save masks") 29 | 30 | # 31 | # Configuration 32 | # 33 | parser.add_argument( 34 | '--cfg', dest='cfg_file', required=True, 35 | help='Config file for training (and optionally testing)') 36 | parser.add_argument( 37 | '--set', dest='set_cfgs', 38 | help='Set config keys. Key value sequence seperate by whitespace.' 39 | 'e.g. [key] [value] [key] [value]', 40 | default=[], nargs='+') 41 | 42 | parser.add_argument("--random-seed", type=int, default=64, help="Random seed") 43 | 44 | 45 | def maybe_create_dir(path): 46 | if not os.path.exists(path): 47 | os.makedirs(path) 48 | 49 | def check_global_arguments(args): 50 | 51 | torch.set_num_threads(args.workers) 52 | if args.workers != torch.get_num_threads(): 53 | print("Warning: # of threads is only ", torch.get_num_threads()) 54 | 55 | setattr(args, "fixed_batch_path", os.path.join(args.logdir, args.dataset, args.exp, "fixed_batch.pt")) 56 | args.logdir = os.path.join(args.logdir, args.dataset, args.exp, args.run) 57 | maybe_create_dir(args.logdir) 58 | #print("Saving events in: {}".format(args.logdir)) 59 | 60 | # 61 | # Model directories 62 | # 63 | args.snapshot_dir = os.path.join(args.snapshot_dir, args.dataset, args.exp, args.run) 64 | maybe_create_dir(args.snapshot_dir) 65 | #print("Saving snapshots in: {}".format(args.snapshot_dir)) 66 | 67 | def get_arguments(args_in): 68 | """Parse all the arguments provided from the CLI. 69 | 70 | Returns: 71 | A list of parsed arguments. 72 | """ 73 | parser = argparse.ArgumentParser(description="Model Evaluation") 74 | 75 | add_global_arguments(parser) 76 | args = parser.parse_args(args_in) 77 | check_global_arguments(args) 78 | 79 | return args 80 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | python=3.6.3 5 | numpy=1.15.4 6 | pillow=5.4.1 7 | pydensecrf=1.0rc3 8 | pytorch=1.0.1 9 | pyyaml=3.13 10 | torchvision=0.2.1 11 | tqdm==4.31.1 12 | tensorboardX=1.6 13 | packaging=19.0 14 | scikit-learn=0.20.2 15 | scikit-image=0.16.1 16 | six=1.12.0 17 | scipy=1.2.1 18 | -------------------------------------------------------------------------------- /tools/convert_sbd.py: -------------------------------------------------------------------------------- 1 | """Convert .mat segmentation mask of SBD to .png 2 | See: https://github.com/visinf/1-stage-wseg/issues/9 3 | """ 4 | 5 | import os 6 | import sys 7 | import glob 8 | import argparse 9 | from PIL import Image 10 | from scipy.io import loadmat 11 | 12 | # load tqdm optionally 13 | try: 14 | from tqdm import tqdm 15 | except ImportError: 16 | tqdm = lambda x: x 17 | 18 | 19 | def args(): 20 | parser = argparse.ArgumentParser(description="Convert SBD .mat to .png") 21 | parser.add_argument("--inp", type=str, default='./dataset/cls/', 22 | help="Directory with .mat files") 23 | parser.add_argument("--out", type=str, default='./dataset/cls_png/', 24 | help="Directory where to save .png files") 25 | return parser.parse_args(sys.argv[1:]) 26 | 27 | 28 | def convert(opts): 29 | 30 | # searching for files .mat 31 | opts.inp = opts.inp + ("" if opts.inp[-1] == "/" else "/") 32 | filelist = glob.glob(opts.inp + "*.mat") 33 | print("Found {:d} files".format(len(filelist))) 34 | 35 | if len(filelist) == 0: 36 | return 37 | 38 | # check output directory 39 | if not os.path.isdir(opts.out): 40 | print("Creating {}".format(opts.out)) 41 | os.makedirs(opts.out) 42 | 43 | 44 | for filepath in tqdm(filelist): 45 | 46 | x = loadmat(filepath) 47 | y = x['GTcls']['Segmentation'][0][0] 48 | 49 | # converting to PIL image 50 | png = Image.fromarray(y) 51 | 52 | name = os.path.basename(filepath).replace(".mat", ".png") 53 | png.save(os.path.join(opts.out, name)) 54 | 55 | 56 | if __name__ == "__main__": 57 | opts = args() 58 | convert(opts) 59 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import math 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | from sklearn.metrics import average_precision_score 11 | 12 | from datasets import get_dataloader, get_num_classes, get_class_names 13 | from models import get_model 14 | 15 | from base_trainer import BaseTrainer 16 | from functools import partial 17 | 18 | from opts import get_arguments 19 | from core.config import cfg, cfg_from_file, cfg_from_list 20 | from datasets.utils import Colorize 21 | from losses import get_criterion, mask_loss_ce 22 | 23 | from utils.timer import Timer 24 | from utils.stat_manager import StatManager 25 | from utils.metrics import compute_jaccard 26 | 27 | # specific to pytorch-v1 cuda-9.0 28 | # see: https://github.com/pytorch/pytorch/issues/15054#issuecomment-450191923 29 | # and: https://github.com/pytorch/pytorch/issues/14456 30 | torch.backends.cudnn.benchmark = True 31 | #torch.backends.cudnn.deterministic = True 32 | DEBUG = False 33 | 34 | def rescale_as(x, y, mode="bilinear", align_corners=True): 35 | h, w = y.size()[2:] 36 | x = F.interpolate(x, size=[h, w], mode=mode, align_corners=align_corners) 37 | return x 38 | 39 | class DecTrainer(BaseTrainer): 40 | 41 | def __init__(self, args, **kwargs): 42 | super(DecTrainer, self).__init__(args, **kwargs) 43 | 44 | # dataloader 45 | self.trainloader = get_dataloader(args, cfg, 'train') 46 | self.trainloader_val = get_dataloader(args, cfg, 'train_voc') 47 | self.valloader = get_dataloader(args, cfg, 'val') 48 | self.denorm = self.trainloader.dataset.denorm 49 | 50 | self.nclass = get_num_classes(args) 51 | self.classNames = get_class_names(args)[:-1] 52 | assert self.nclass == len(self.classNames) 53 | 54 | self.classIndex = {} 55 | for i, cname in enumerate(self.classNames): 56 | self.classIndex[cname] = i 57 | 58 | # model 59 | self.enc = get_model(cfg.NET, num_classes=self.nclass) 60 | self.criterion_cls = get_criterion(cfg.NET.LOSS) 61 | print(self.enc) 62 | 63 | # optimizer using different LR 64 | enc_params = self.enc.parameter_groups(cfg.NET.LR, cfg.NET.WEIGHT_DECAY) 65 | self.optim_enc = self.get_optim(enc_params, cfg.NET) 66 | 67 | # checkpoint management 68 | self._define_checkpoint('enc', self.enc, self.optim_enc) 69 | self._load_checkpoint(args.resume) 70 | 71 | self.fixed_batch = None 72 | self.fixed_batch_path = args.fixed_batch_path 73 | if os.path.isfile(self.fixed_batch_path): 74 | print("Loading fixed batch from {}".format(self.fixed_batch_path)) 75 | self.fixed_batch = torch.load(self.fixed_batch_path) 76 | 77 | # using cuda 78 | self.enc = nn.DataParallel(self.enc).cuda() 79 | self.criterion_cls = nn.DataParallel(self.criterion_cls).cuda() 80 | 81 | def step(self, epoch, image, gt_labels, train=False, visualise=False): 82 | 83 | PRETRAIN = epoch < (11 if DEBUG else cfg.TRAIN.PRETRAIN) 84 | 85 | # denorm image 86 | image_raw = self.denorm(image.clone()) 87 | 88 | # classification 89 | cls_out, cls_fg, masks, mask_logits, pseudo_gt, loss_mask = self.enc(image, image_raw, gt_labels) 90 | 91 | # classification loss 92 | loss_cls = self.criterion_cls(cls_out, gt_labels).mean() 93 | 94 | # keep track of all losses for logging 95 | losses = {"loss_cls": loss_cls.item(), 96 | "loss_fg": cls_fg.mean().item()} 97 | 98 | loss = loss_cls.clone() 99 | if "dec" in masks: 100 | loss_mask = loss_mask.mean() 101 | 102 | if not PRETRAIN: 103 | loss += cfg.NET.MASK_LOSS_BCE * loss_mask 104 | 105 | assert not "pseudo" in masks 106 | masks["pseudo"] = pseudo_gt 107 | losses["loss_mask"] = loss_mask.item() 108 | 109 | losses["loss"] = loss.item() 110 | 111 | if train: 112 | self.optim_enc.zero_grad() 113 | loss.backward() 114 | self.optim_enc.step() 115 | 116 | for mask_key, mask_val in masks.items(): 117 | masks[mask_key] = masks[mask_key].detach() 118 | 119 | mask_logits = mask_logits.detach() 120 | 121 | if visualise: 122 | self._visualise(epoch, image, masks, mask_logits, cls_out, gt_labels) 123 | 124 | # make sure to cut the return values from graph 125 | return losses, cls_out.detach(), masks, mask_logits 126 | 127 | def train_epoch(self, epoch): 128 | self.enc.train() 129 | 130 | stat = StatManager() 131 | stat.add_val("loss") 132 | stat.add_val("loss_cls") 133 | stat.add_val("loss_fg") 134 | stat.add_val("loss_bce") 135 | 136 | # adding stats for classes 137 | timer = Timer("New Epoch: ") 138 | train_step = partial(self.step, train=True, visualise=False) 139 | 140 | for i, (image, gt_labels, _) in enumerate(self.trainloader): 141 | 142 | # masks 143 | losses, _, _, _ = train_step(epoch, image, gt_labels) 144 | 145 | if self.fixed_batch is None: 146 | self.fixed_batch = {} 147 | self.fixed_batch["image"] = image.clone() 148 | self.fixed_batch["labels"] = gt_labels.clone() 149 | torch.save(self.fixed_batch, self.fixed_batch_path) 150 | 151 | for loss_key, loss_val in losses.items(): 152 | stat.update_stats(loss_key, loss_val) 153 | 154 | # intermediate logging 155 | if i % 10 == 0: 156 | msg = "Loss [{:04d}]: ".format(i) 157 | for loss_key, loss_val in losses.items(): 158 | msg += "{}: {:.4f} | ".format(loss_key, loss_val) 159 | 160 | msg += " | Im/Sec: {:.1f}".format(i * cfg.TRAIN.BATCH_SIZE / timer.get_stage_elapsed()) 161 | print(msg) 162 | sys.stdout.flush() 163 | 164 | del image, gt_labels 165 | 166 | if DEBUG and i > 100: 167 | break 168 | 169 | def publish_loss(stats, name, t, prefix='data/'): 170 | print("{}: {:4.3f}".format(name, stats.summarize_key(name))) 171 | #self.writer.add_scalar(prefix + name, stats.summarize_key(name), t) 172 | 173 | for stat_key in stat.vals.keys(): 174 | publish_loss(stat, stat_key, epoch) 175 | 176 | # plotting learning rate 177 | for ii, l in enumerate(self.optim_enc.param_groups): 178 | print("Learning rate [{}]: {:4.3e}".format(ii, l['lr'])) 179 | self.writer.add_scalar('lr/enc_group_%02d' % ii, l['lr'], epoch) 180 | 181 | #self.writer.add_scalar('lr/bg_baseline', self.enc.module.mean.item(), epoch) 182 | 183 | # visualising 184 | self.enc.eval() 185 | with torch.no_grad(): 186 | self.step(epoch, self.fixed_batch["image"], \ 187 | self.fixed_batch["labels"], \ 188 | train=False, visualise=True) 189 | 190 | def _mask_rgb(self, masks, image_norm): 191 | # visualising masks 192 | masks_conf, masks_idx = torch.max(masks, 1) 193 | masks_conf = masks_conf - F.relu(masks_conf - 1, 0) 194 | 195 | masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), masks_conf.cpu()) 196 | return 0.3 * image_norm + 0.7 * masks_idx_rgb 197 | 198 | def _init_norm(self): 199 | self.trainloader.dataset.set_norm(self.enc.normalize) 200 | self.valloader.dataset.set_norm(self.enc.normalize) 201 | self.trainloader_val.dataset.set_norm(self.enc.normalize) 202 | 203 | def _apply_cmap(self, mask_idx, mask_conf): 204 | palette = self.trainloader.dataset.get_palette() 205 | 206 | masks = [] 207 | col = Colorize() 208 | mask_conf = mask_conf.float() / 255.0 209 | for mask, conf in zip(mask_idx.split(1), mask_conf.split(1)): 210 | m = col(mask).float() 211 | m = m * conf 212 | masks.append(m[None, ...]) 213 | 214 | return torch.cat(masks, 0) 215 | 216 | def validation(self, epoch, writer, loader, checkpoint=False): 217 | 218 | stat = StatManager() 219 | 220 | # Fast test during the training 221 | def eval_batch(image, gt_labels): 222 | 223 | losses, cls, masks, mask_logits = \ 224 | self.step(epoch, image, gt_labels, train=False, visualise=False) 225 | 226 | for loss_key, loss_val in losses.items(): 227 | stat.update_stats(loss_key, loss_val) 228 | 229 | return cls.cpu(), masks, mask_logits.cpu() 230 | 231 | self.enc.eval() 232 | 233 | # class ground truth 234 | targets_all = [] 235 | 236 | # class predictions 237 | preds_all = [] 238 | 239 | def add_stats(means, stds, x): 240 | means.append(x.mean()) 241 | stds.append(x.std()) 242 | 243 | for n, (image, gt_labels, _) in enumerate(loader): 244 | 245 | with torch.no_grad(): 246 | cls_raw, masks_all, mask_logits = eval_batch(image, gt_labels) 247 | 248 | cls_sigmoid = torch.sigmoid(cls_raw).numpy() 249 | 250 | preds_all.append(cls_sigmoid) 251 | targets_all.append(gt_labels.cpu().numpy()) 252 | 253 | # 254 | # classification 255 | # 256 | targets_stacked = np.vstack(targets_all) 257 | preds_stacked = np.vstack(preds_all) 258 | aps = average_precision_score(targets_stacked, preds_stacked, average=None) 259 | 260 | # skip BG AP 261 | offset = self.nclass - aps.size 262 | assert offset == 1, 'Class number mismatch' 263 | 264 | classNames = self.classNames[offset:] 265 | for ni, className in enumerate(classNames): 266 | writer.add_scalar('%02d_%s/AP' % (ni + offset, className), aps[ni], epoch) 267 | print("AP_{}: {:4.3f}".format(className, aps[ni])) 268 | 269 | meanAP = np.mean(aps) 270 | writer.add_scalar('all_wo_BG/mAP', meanAP, epoch) 271 | print('mAP: {:4.3f}'.format(meanAP)) 272 | 273 | # total classification loss 274 | for stat_key in stat.vals.keys(): 275 | writer.add_scalar('all/{}'.format(stat_key), stat.summarize_key(stat_key), epoch) 276 | 277 | if checkpoint and epoch >= cfg.TRAIN.PRETRAIN: 278 | # we will use mAP - mask_loss as our proxy score 279 | # to save the best checkpoint so far 280 | proxy_score = 1 - stat.summarize_key("loss") 281 | writer.add_scalar('all/checkpoint_score', proxy_score, epoch) 282 | self.checkpoint_best(proxy_score, epoch) 283 | 284 | def _visualise(self, epoch, image, masks, mask_logits, cls_out, gt_labels): 285 | image_norm = self.denorm(image.clone()).cpu() 286 | visual = [image_norm] 287 | 288 | if "cam" in masks: 289 | visual.append(self._mask_rgb(masks["cam"], image_norm)) 290 | 291 | if "dec" in masks: 292 | visual.append(self._mask_rgb(masks["dec"], image_norm)) 293 | 294 | if "pseudo" in masks: 295 | pseudo_gt_rgb = self._mask_rgb(masks["pseudo"], image_norm) 296 | 297 | # cancel ambiguous 298 | ambiguous = 1 - masks["pseudo"].sum(1, keepdim=True).cpu() 299 | pseudo_gt_rgb = ambiguous * image_norm + (1 - ambiguous) * pseudo_gt_rgb 300 | visual.append(pseudo_gt_rgb) 301 | 302 | # ready to assemble 303 | visual_logits = torch.cat(visual, -1) 304 | self._visualise_grid(visual_logits, gt_labels, epoch, scores=cls_out) 305 | 306 | if __name__ == "__main__": 307 | args = get_arguments(sys.argv[1:]) 308 | 309 | # Reading the config 310 | cfg_from_file(args.cfg_file) 311 | if args.set_cfgs is not None: 312 | cfg_from_list(args.set_cfgs) 313 | 314 | print("Config: \n", cfg) 315 | 316 | trainer = DecTrainer(args) 317 | torch.manual_seed(0) 318 | 319 | timer = Timer() 320 | def time_call(func, msg, *args, **kwargs): 321 | timer.reset_stage() 322 | func(*args, **kwargs) 323 | print(msg + (" {:3.2}m".format(timer.get_stage_elapsed() / 60.))) 324 | 325 | for epoch in range(trainer.start_epoch, cfg.TRAIN.NUM_EPOCHS + 1): 326 | print("Epoch >>> ", epoch) 327 | 328 | log_int = 5 if DEBUG else 2 329 | if epoch % log_int == 0: 330 | with torch.no_grad(): 331 | if not DEBUG: 332 | time_call(trainer.validation, "Validation / Train: ", epoch, trainer.writer, trainer.trainloader_val) 333 | time_call(trainer.validation, "Validation / Val: ", epoch, trainer.writer_val, trainer.valloader, checkpoint=True) 334 | 335 | time_call(trainer.train_epoch, "Train epoch: ", epoch) 336 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/utils/__init__.py -------------------------------------------------------------------------------- /utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class Checkpoint(object): 6 | 7 | def __init__(self, path, max_n=3): 8 | self.path = path 9 | self.max_n = max_n 10 | self.models = {} 11 | self.checkpoints = [] 12 | 13 | def add_model(self, name, model, opt=None): 14 | assert not name in self.models, "Model {} already added".format(name) 15 | 16 | self.models[name] = {} 17 | self.models[name]['model'] = model 18 | self.models[name]['opt'] = opt 19 | 20 | def limit(self): 21 | return self.max_n 22 | 23 | def add_checkpoints(self, name=None): 24 | # searching for names 25 | fns = os.listdir(self.path) 26 | fns = filter(lambda x: x[-4:] == '.pth', fns) 27 | 28 | names = {} 29 | for fn in fns: 30 | sfx = fn.split("_")[-1].rstrip('.pth') 31 | path = self._get_full_path(fn) 32 | if not sfx in names: 33 | names[sfx] = os.path.getmtime(path) 34 | else: 35 | names[sfx] = max(names[sfx], os.path.getmtime(path)) 36 | 37 | # assembling 38 | names_and_time = [] 39 | for sfx, time in names.items(): 40 | exists, paths = self.find(sfx) 41 | if exists: 42 | names_and_time.append((sfx, time)) 43 | 44 | # if there are more checkpoints 45 | # than we can handle, remove the older ones 46 | # but do not remove them (for safety) 47 | if len(names_and_time) > self.max_n: 48 | names_and_time = sorted(names_and_time, \ 49 | key=lambda x: x[1], \ 50 | reverse=False) 51 | new_checkpoints = [] 52 | for key in names_and_time[-self.max_n:]: 53 | new_checkpoints.append(key[0]) 54 | 55 | self.checkpoints = new_checkpoints 56 | 57 | def __len__(self): 58 | return len(self.checkpoints) 59 | 60 | def _get_full_path(self, filename): 61 | return os.path.join(self.path, filename) 62 | 63 | def clean(self, n_remove): 64 | 65 | n_remove = min(n_remove, len(self.checkpoints)) 66 | 67 | for i in range(n_remove): 68 | sfx = self.checkpoints[i] 69 | 70 | for name, data in self.models.items(): 71 | for d in ('model', 'opt'): 72 | fn = self._filename(d, name, sfx) 73 | self._rm(fn) 74 | 75 | removed = self.checkpoints[:n_remove] 76 | self.checkpoints = self.checkpoints[n_remove:] 77 | return removed 78 | 79 | def _rm(self, fn): 80 | path = self._get_full_path(fn) 81 | if os.path.isfile(path): 82 | os.remove(path) 83 | 84 | def _filename(self, d, name, suffix): 85 | return "{}_{}_{}.pth".format(d, name, suffix) 86 | 87 | def load(self, suffix): 88 | if suffix is None: 89 | return False 90 | 91 | found, paths = self.find(suffix) 92 | if not found: 93 | return False 94 | 95 | # loading 96 | for name, data in self.models.items(): 97 | for d in ('model', 'opt'): 98 | if data[d] is not None: 99 | data[d].load_state_dict(torch.load(paths[name][d])) 100 | 101 | return True 102 | 103 | def find(self, suffix, force=False): 104 | paths = {} 105 | found = True 106 | for name, data in self.models.items(): 107 | paths[name] = {} 108 | for d in ('model', 'opt'): 109 | fn = self._filename(d, name, suffix) 110 | path = self._get_full_path(fn) 111 | paths[name][d] = path 112 | if not os.path.isfile(path): 113 | print("File not found: ", path) 114 | if d == 'model': 115 | found = False 116 | 117 | if found and not suffix in self.checkpoints: 118 | if len(self.checkpoints) < self.max_n or force: 119 | self.checkpoints.insert(0, suffix) 120 | if force: 121 | self.max_n = max(self.max_n, len(self.checkpoints)) 122 | 123 | return found, paths 124 | 125 | def checkpoint(self, suffix): 126 | assert not '_' in suffix, "Underscores are not allowed" 127 | 128 | self.checkpoints.append(suffix) 129 | 130 | for name, data in self.models.items(): 131 | for d in ('model', 'opt'): 132 | fn = self._filename(d, name, suffix) 133 | path = self._get_full_path(fn) 134 | if not os.path.isfile(path) and data[d] is not None: 135 | torch.save(data[d].state_dict(), path) 136 | 137 | # removing 138 | n_remove = max(0, len(self.checkpoints) - self.max_n) 139 | removed = self.clean(n_remove) 140 | 141 | return removed 142 | -------------------------------------------------------------------------------- /utils/collections.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | """A simple attribute dictionary used for representing configuration options.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | 23 | 24 | class AttrDict(dict): 25 | 26 | IMMUTABLE = '__immutable__' 27 | 28 | def __init__(self, *args, **kwargs): 29 | super(AttrDict, self).__init__(*args, **kwargs) 30 | self.__dict__[AttrDict.IMMUTABLE] = False 31 | 32 | def __getattr__(self, name): 33 | if name in self.__dict__: 34 | return self.__dict__[name] 35 | elif name in self: 36 | return self[name] 37 | else: 38 | raise AttributeError(name) 39 | 40 | def __setattr__(self, name, value): 41 | if not self.__dict__[AttrDict.IMMUTABLE]: 42 | if name in self.__dict__: 43 | self.__dict__[name] = value 44 | else: 45 | self[name] = value 46 | else: 47 | raise AttributeError( 48 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'. 49 | format(name, value) 50 | ) 51 | 52 | def immutable(self, is_immutable): 53 | """Set immutability to is_immutable and recursively apply the setting 54 | to all nested AttrDicts. 55 | """ 56 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable 57 | # Recursively set immutable state 58 | for v in self.__dict__.values(): 59 | if isinstance(v, AttrDict): 60 | v.immutable(is_immutable) 61 | for v in self.values(): 62 | if isinstance(v, AttrDict): 63 | v.immutable(is_immutable) 64 | 65 | def is_immutable(self): 66 | return self.__dict__[AttrDict.IMMUTABLE] 67 | -------------------------------------------------------------------------------- /utils/dcrf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pydensecrf.densecrf as dcrf 3 | from pydensecrf.utils import unary_from_softmax 4 | 5 | 6 | def crf_inference(img, probs, t=10, scale_factor=1, labels=21): 7 | 8 | h, w = img.shape[:2] 9 | n_labels = labels 10 | 11 | d = dcrf.DenseCRF2D(w, h, n_labels) 12 | 13 | unary = unary_from_softmax(probs) 14 | unary = np.ascontiguousarray(unary) 15 | 16 | d.setUnaryEnergy(unary) 17 | d.addPairwiseGaussian(sxy=3/scale_factor, compat=3) 18 | d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10) 19 | Q = d.inference(t) 20 | 21 | return np.array(Q).reshape((n_labels, h, w)) 22 | -------------------------------------------------------------------------------- /utils/inference_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import scipy.misc 5 | 6 | import torch.nn.functional as F 7 | 8 | from PIL import Image 9 | from utils.dcrf import crf_inference 10 | 11 | from datasets.pascal_voc_ms import MultiscaleLoader, CropLoader 12 | 13 | class ResultWriter: 14 | 15 | def __init__(self, cfg, palette, out_path, verbose=True): 16 | self.cfg = cfg 17 | self.palette = palette 18 | self.root = out_path 19 | self.verbose = verbose 20 | 21 | def _mask_overlay(self, mask, image, alpha=0.3): 22 | """Creates an overlayed mask visualisation""" 23 | mask_rgb = self.__mask2rgb(mask) 24 | return alpha * image + (1 - alpha) * mask_rgb 25 | 26 | def __mask2rgb(self, mask): 27 | im = Image.fromarray(mask).convert("P") 28 | im.putpalette(self.palette) 29 | mask_rgb = np.array(im.convert("RGB"), dtype=np.float) 30 | return mask_rgb / 255. 31 | 32 | def _merge_masks(self, masks, labels, pads): 33 | """Combines masks at multiple scales 34 | 35 | Args: 36 | masks: list of masks obtained at different scales 37 | (already scaled to the original) 38 | Returns: 39 | pred: combined single mask 40 | pred_crf: refined mask with CRF 41 | """ 42 | raise NotImplementedError 43 | 44 | def save(self, img_path, img_orig, all_masks, labels, pads, gt_mask): 45 | 46 | img_name = os.path.basename(img_path).rstrip(".jpg") 47 | 48 | # converting original image to [0, 255] 49 | img_orig255 = np.round(255. * img_orig).astype(np.uint8) 50 | img_orig255 = np.transpose(img_orig255, [1,2,0]) 51 | img_orig255 = np.ascontiguousarray(img_orig255) 52 | 53 | merged_mask = self._merge_masks(all_masks, pads, labels, img_orig255.shape[:2]) 54 | pred = np.argmax(merged_mask, 0) 55 | 56 | # CRF 57 | pred_crf = crf_inference(img_orig255, merged_mask, t=10, scale_factor=1, labels=21) 58 | pred_crf = np.argmax(pred_crf, 0) 59 | 60 | filepath = os.path.join(self.root, img_name + '.png') 61 | scipy.misc.imsave(filepath, pred.astype(np.uint8)) 62 | 63 | filepath = os.path.join(self.root, "crf", img_name + '.png') 64 | scipy.misc.imsave(filepath, pred_crf.astype(np.uint8)) 65 | 66 | if self.verbose: 67 | mask_gt = gt_mask.numpy() 68 | masks_all = np.concatenate([pred, pred_crf, mask_gt], 1).astype(np.uint8) 69 | images = np.concatenate([img_orig]*3, 2) 70 | images = np.transpose(images, [1,2,0]) 71 | 72 | overlay = self._mask_overlay(masks_all, images) 73 | filepath = os.path.join(self.root, "vis", img_name + '.png') 74 | overlay255 = np.round(overlay * 255.).astype(np.uint8) 75 | scipy.misc.imsave(filepath, overlay255) 76 | 77 | class MergeMultiScale(ResultWriter): 78 | 79 | def _cut(self, x_chw, pads): 80 | pad_h, pad_w, h, w = [int(p) for p in pads] 81 | return x_chw[:, pad_h:(pad_h + h), pad_w:(pad_w + w)] 82 | 83 | def _merge_masks(self, masks, labels, pads, imsize_hw): 84 | 85 | mask_list = [] 86 | for i, mask in enumerate(masks.split(1, dim=0)): 87 | 88 | # removing the padding 89 | mask_cut = self._cut(mask[0], pads[i]).unsqueeze(0) 90 | 91 | # normalising the scale 92 | mask_cut = F.interpolate(mask_cut, imsize_hw, mode='bilinear', align_corners=False)[0] 93 | 94 | # flipping if necessary 95 | if self.cfg.FLIP and i % 2 == 1: 96 | mask_cut = torch.flip(mask_cut, (-1, )) 97 | 98 | # getting the max response 99 | mask_cut[1:, ::] *= labels[:, None, None] 100 | mask_list.append(mask_cut) 101 | 102 | mean_mask = sum(mask_list).numpy() / len(mask_list) 103 | 104 | # discounting BG 105 | #mean_mask[0, ::] *= 0.5 106 | mean_mask[0, ::] = np.power(mean_mask[0, ::], self.cfg.BG_POW) 107 | 108 | return mean_mask 109 | 110 | class MergeCrops(ResultWriter): 111 | 112 | def _cut(self, x_chw, pads): 113 | pad_h, pad_w, h, w = [int(p) for p in pads] 114 | return x_chw[:, pad_h:(pad_h + h), pad_w:(pad_w + w)] 115 | 116 | def _merge_masks(self, masks, labels, coords, imsize_hw): 117 | num_classes = masks.size(1) 118 | 119 | masks_sum = torch.zeros([num_classes, *imsize_hw]).type_as(masks) 120 | counts = torch.zeros(imsize_hw).type_as(masks) 121 | 122 | for ii, (mask, pads) in enumerate(zip(masks.split(1), coords.split(1))): 123 | 124 | mask = mask[0] 125 | s_h, e_h, s_w, e_w = pads[0][:4] 126 | pad_t, pad_l = pads[0][4:] 127 | 128 | if self.cfg.FLIP and ii % 2 == 0: 129 | mask = mask.flip(-1) 130 | 131 | # crop mask, if needed 132 | m_h = 0 if s_h > 0 else pad_t 133 | m_w = 0 if s_w > 0 else pad_l 134 | 135 | # due to padding 136 | # end point is shifted 137 | s_h = max(0, s_h - pad_t) 138 | s_w = max(0, s_w - pad_l) 139 | e_h = min(e_h - pad_t, imsize_hw[0]) 140 | e_w = min(e_w - pad_l, imsize_hw[1]) 141 | 142 | m_he = m_h + e_h - s_h 143 | m_we = m_w + e_w - s_w 144 | 145 | masks_sum[:, s_h:e_h, s_w:e_w] += mask[:, m_h:m_he, m_w:m_we] 146 | counts[s_h:e_h, s_w:e_w] += 1 147 | 148 | assert torch.all(counts > 0) 149 | 150 | # removing false pasitives 151 | masks_sum[1:, ::] *= labels[:, None, None] 152 | 153 | # removing the padding 154 | return (masks_sum / counts).numpy() 155 | 156 | class PAMRWriter(ResultWriter): 157 | 158 | def save_batch(self, img_paths, imgs, all_masks, all_gt_masks): 159 | 160 | for b, img_path in enumerate(img_paths): 161 | 162 | img_name = os.path.basename(img_path).rstrip(".jpg") 163 | img_orig = imgs[b] 164 | gt_mask = all_gt_masks[b] 165 | 166 | # converting original image to [0, 255] 167 | img_orig255 = np.round(255. * img_orig).astype(np.uint8) 168 | img_orig255 = np.transpose(img_orig255, [1,2,0]) 169 | img_orig255 = np.ascontiguousarray(img_orig255) 170 | 171 | mask_gt = torch.argmax(gt_mask, 0) 172 | 173 | # cancel ambiguous 174 | ambiguous = gt_mask.sum(0) == 0 175 | mask_gt[ambiguous] = 255 176 | mask_gt = mask_gt.numpy() 177 | 178 | # saving GT 179 | image_hwc = np.transpose(img_orig, [1,2,0]) 180 | overlay_gt = self._mask_overlay(mask_gt.astype(np.uint8), image_hwc, alpha=0.5) 181 | 182 | filepath = os.path.join(self.root, img_name + '_gt.png') 183 | overlay255 = np.round(overlay_gt * 255.).astype(np.uint8) 184 | scipy.misc.imsave(filepath, overlay255) 185 | 186 | for it, mask_batch in enumerate(all_masks): 187 | 188 | mask = mask_batch[b] 189 | mask_idx = torch.argmax(mask, 0) 190 | 191 | # cancel ambiguous 192 | ambiguous = mask.sum(0) == 0 193 | mask_idx[ambiguous] = 255 194 | 195 | overlay = self._mask_overlay(mask_idx.numpy().astype(np.uint8), image_hwc, alpha=0.5) 196 | 197 | filepath = os.path.join(self.root, img_name + '_{:02d}.png'.format(it)) 198 | overlay255 = np.round(overlay * 255.).astype(np.uint8) 199 | scipy.misc.imsave(filepath, overlay255) 200 | 201 | 202 | def get_inference_io(method_name): 203 | 204 | if method_name == "multiscale": 205 | return MergeMultiScale, MultiscaleLoader 206 | elif method_name == "multicrop": 207 | return MergeCrops, CropLoader 208 | else: 209 | raise NotImplementedError("Method {} is unknown".format(method_name)) 210 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import threading 12 | import numpy as np 13 | import torch 14 | 15 | from sklearn.metrics import average_precision_score 16 | 17 | class Metric(object): 18 | 19 | # synonyms 20 | IoU = "IoU" 21 | MaskIoU = "IoU" 22 | 23 | Precision = "Precision" 24 | Recall = "Recall" 25 | ClassAP = "ClassAP" 26 | 27 | def __init__(self): 28 | self.data = {} 29 | self.count = {} 30 | self.fn = {} 31 | 32 | # initising the functions 33 | self.fn[Metric.MaskIoU] = self.mask_iou_ 34 | self.fn[Metric.Precision] = self.precision_ 35 | self.fn[Metric.Recall] = self.recall_ 36 | self.fn[Metric.ClassAP] = self.class_ap_ 37 | 38 | def add_metric(self, m): 39 | assert m in self.fn, "Unknown metric with key {}".format(m) 40 | 41 | self.data[m] = 0. 42 | self.count[m] = 0. 43 | 44 | def metrics(self): 45 | return self.data.keys() 46 | 47 | def print_summary(self): 48 | 49 | keys_sorted = sorted(self.data.keys()) 50 | 51 | for m in keys_sorted: 52 | print("{}: {:5.4f}".format(m, self.summarize(m))) 53 | 54 | def reset_stat(self, m=None): 55 | 56 | if m is None: 57 | # resetting everything 58 | for m in self.data: 59 | self.data[m] = 0. 60 | self.count[m] = 0. 61 | else: 62 | assert m in self.fn, "Unknown metric with key {}".format(m) 63 | 64 | self.data[m] = 0. 65 | self.count[m] = 0. 66 | 67 | def update_value(self, m, value, count=1.): 68 | 69 | self.data[m] += value 70 | self.count[m] += count 71 | 72 | def update(self, gt, pred): 73 | 74 | for m in self.data: 75 | self.data[m] += self.fn[m](gt, pred) 76 | self.count[m] += 1. 77 | 78 | def merge(self, metric): 79 | 80 | for m in metric.data: 81 | if not m in self.data: 82 | self.reset_stat(m) 83 | 84 | self.update_value(m, metric.data[m], metric.count[m]) 85 | 86 | def merge_summary(self, metric): 87 | 88 | for m in metric.data: 89 | if not m in self.data: 90 | self.reset_stat(m) 91 | 92 | mean_value = metric.summarize(m) 93 | self.update_value(m, mean_value, 1.) 94 | 95 | def summarize(self, m): 96 | if not m in self.count or self.count[m] == 0.: 97 | return 0. 98 | 99 | return self.data[m] / self.count[m] 100 | 101 | @staticmethod 102 | def mask_iou_(a, b): 103 | # computing the mask IoU 104 | isc = (a * b).sum() 105 | unn = (a + b).sum() 106 | z = unn - isc 107 | 108 | if z == 0.: 109 | return 0. 110 | 111 | return isc / z 112 | 113 | @staticmethod 114 | def precision_(gt, p): 115 | # computing the mask IoU 116 | acc = (gt * p).sum() 117 | sss = p.sum() 118 | 119 | if sss == 0.: 120 | return 0. 121 | 122 | return acc / sss 123 | 124 | @staticmethod 125 | def recall_(gt, p): 126 | # computing the mask IoU 127 | acc = (gt * p).sum() 128 | sss = gt.sum() 129 | 130 | if sss == 0.: 131 | return 0. 132 | 133 | return acc / sss 134 | 135 | @staticmethod 136 | def class_ap_(gt, p): 137 | 138 | # this return AP for each class 139 | ap = average_precision_score(gt, p, average=None) 140 | 141 | # return the average 142 | return np.mean(ap[gt.sum(0) > 0]) 143 | 144 | 145 | def compute_jaccard(preds_masks_all, targets_masks_all, num_classes=21): 146 | 147 | tps = np.zeros((num_classes, )) 148 | fps = np.zeros((num_classes, )) 149 | fns = np.zeros((num_classes, )) 150 | counts = np.zeros((num_classes, )) 151 | 152 | for mask_pred, mask_gt in zip(preds_masks_all, targets_masks_all): 153 | 154 | bs, h, w = mask_pred.size() 155 | assert bs == mask_gt.size(0), "Batch size mismatch" 156 | assert h == mask_gt.size(1), "Width mismatch" 157 | assert w == mask_gt.size(2), "Height mismatch" 158 | 159 | mask_pred = mask_pred.view(bs, 1, -1) 160 | mask_gt = mask_gt.view(bs, 1, -1) 161 | 162 | # ignore ambiguous 163 | mask_pred[mask_gt == 255] = 255 164 | 165 | for label in range(num_classes): 166 | mask_pred_ = (mask_pred == label).float() 167 | mask_gt_ = (mask_gt == label).float() 168 | 169 | tps[label] += (mask_pred_ * mask_gt_).sum().item() 170 | diff = mask_pred_ - mask_gt_ 171 | fps[label] += np.maximum(0., diff).float().sum().item() 172 | fns[label] += np.maximum(0., -diff).float().sum().item() 173 | 174 | jaccards = [None]*num_classes 175 | precision = [None]*num_classes 176 | recall = [None]*num_classes 177 | for i in range(num_classes): 178 | tp = tps[i] 179 | fn = fns[i] 180 | fp = fps[i] 181 | jaccards[i] = tp / max(1e-3, fn + fp + tp) 182 | precision[i] = tp / max(1e-3, tp + fp) 183 | recall[i] = tp / max(1e-3, tp + fn) 184 | 185 | return jaccards, precision, recall 186 | 187 | -------------------------------------------------------------------------------- /utils/pallete.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | from PIL import Image 12 | 13 | def get_mask_pallete(npimg, dataset='detail'): 14 | """Get image color pallete for visualizing masks""" 15 | # recovery boundary 16 | if dataset == 'pascal_voc': 17 | npimg[npimg==21] = 255 18 | # put colormap 19 | out_img = Image.fromarray(npimg.squeeze().astype('uint8')) 20 | if dataset == 'ade20k': 21 | out_img.putpalette(adepallete) 22 | elif dataset == 'cityscapes': 23 | out_img.putpalette(citypallete) 24 | elif dataset in ('detail', 'pascal_voc', 'pascal_aug'): 25 | out_img.putpalette(vocpallete) 26 | return out_img 27 | 28 | def _get_voc_pallete(num_cls): 29 | n = num_cls 30 | pallete = [0]*(n*3) 31 | for j in range(0,n): 32 | lab = j 33 | pallete[j*3+0] = 0 34 | pallete[j*3+1] = 0 35 | pallete[j*3+2] = 0 36 | i = 0 37 | while (lab > 0): 38 | pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i)) 39 | pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i)) 40 | pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i)) 41 | i = i + 1 42 | lab >>= 3 43 | return pallete 44 | 45 | vocpallete = _get_voc_pallete(256) 46 | 47 | adepallete = [0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200,3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224,5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143,255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255,6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9,92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41,10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8,0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0,163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224,0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200,200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255,163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0,255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245,255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255,255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0,122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163,255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184,255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163,0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255,0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255,20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0,255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255,255,214,0,25,194,194,102,255,0,92,0,255] 48 | 49 | citypallete = [ 50 | 128,64,128,244,35,232,70,70,70,102,102,156,190,153,153,153,153,153,250,170,30,220,220,0,107,142,35,152,251,152,70,130,180,220,20,60,255,0,0,0,0,142,0,0,70,0,60,100,0,80,100,0,0,230,119,11,32,128,192,0,0,64,128,128,64,128,0,192,128,128,192,128,64,64,0,192,64,0,64,192,0,192,192,0,64,64,128,192,64,128,64,192,128,192,192,128,0,0,64,128,0,64,0,128,64,128,128,64,0,0,192,128,0,192,0,128,192,128,128,192,64,0,64,192,0,64,64,128,64,192,128,64,64,0,192,192,0,192,64,128,192,192,128,192,0,64,64,128,64,64,0,192,64,128,192,64,0,64,192,128,64,192,0,192,192,128,192,192,64,64,64,192,64,64,64,192,64,192,192,64,64,64,192,192,64,192,64,192,192,192,192,192,32,0,0,160,0,0,32,128,0,160,128,0,32,0,128,160,0,128,32,128,128,160,128,128,96,0,0,224,0,0,96,128,0,224,128,0,96,0,128,224,0,128,96,128,128,224,128,128,32,64,0,160,64,0,32,192,0,160,192,0,32,64,128,160,64,128,32,192,128,160,192,128,96,64,0,224,64,0,96,192,0,224,192,0,96,64,128,224,64,128,96,192,128,224,192,128,32,0,64,160,0,64,32,128,64,160,128,64,32,0,192,160,0,192,32,128,192,160,128,192,96,0,64,224,0,64,96,128,64,224,128,64,96,0,192,224,0,192,96,128,192,224,128,192,32,64,64,160,64,64,32,192,64,160,192,64,32,64,192,160,64,192,32,192,192,160,192,192,96,64,64,224,64,64,96,192,64,224,192,64,96,64,192,224,64,192,96,192,192,224,192,192,0,32,0,128,32,0,0,160,0,128,160,0,0,32,128,128,32,128,0,160,128,128,160,128,64,32,0,192,32,0,64,160,0,192,160,0,64,32,128,192,32,128,64,160,128,192,160,128,0,96,0,128,96,0,0,224,0,128,224,0,0,96,128,128,96,128,0,224,128,128,224,128,64,96,0,192,96,0,64,224,0,192,224,0,64,96,128,192,96,128,64,224,128,192,224,128,0,32,64,128,32,64,0,160,64,128,160,64,0,32,192,128,32,192,0,160,192,128,160,192,64,32,64,192,32,64,64,160,64,192,160,64,64,32,192,192,32,192,64,160,192,192,160,192,0,96,64,128,96,64,0,224,64,128,224,64,0,96,192,128,96,192,0,224,192,128,224,192,64,96,64,192,96,64,64,224,64,192,224,64,64,96,192,192,96,192,64,224,192,192,224,192,32,32,0,160,32,0,32,160,0,160,160,0,32,32,128,160,32,128,32,160,128,160,160,128,96,32,0,224,32,0,96,160,0,224,160,0,96,32,128,224,32,128,96,160,128,224,160,128,32,96,0,160,96,0,32,224,0,160,224,0,32,96,128,160,96,128,32,224,128,160,224,128,96,96,0,224,96,0,96,224,0,224,224,0,96,96,128,224,96,128,96,224,128,224,224,128,32,32,64,160,32,64,32,160,64,160,160,64,32,32,192,160,32,192,32,160,192,160,160,192,96,32,64,224,32,64,96,160,64,224,160,64,96,32,192,224,32,192,96,160,192,224,160,192,32,96,64,160,96,64,32,224,64,160,224,64,32,96,192,160,96,192,32,224,192,160,224,192,96,96,64,224,96,64,96,224,64,224,224,64,96,96,192,224,96,192,96,224,192,0,0,0] 51 | 52 | -------------------------------------------------------------------------------- /utils/stat_manager.py: -------------------------------------------------------------------------------- 1 | 2 | class StatManager(object): 3 | 4 | def __init__(self): 5 | self.func_keys = {} 6 | self.vals = {} 7 | self.vals_count = {} 8 | self.formats = {} 9 | 10 | def reset(self): 11 | for k in self.vals: 12 | self.vals[k] = 0.0 13 | self.vals_count[k] = 0.0 14 | 15 | def add_val(self, key, form="{:4.3f}"): 16 | self.vals[key] = 0.0 17 | self.vals_count[key] = 0.0 18 | self.formats[key] = form 19 | 20 | def get_val(self, key): 21 | return self.vals[key], self.vals_count[key] 22 | 23 | def add_compute(self, key, func, form="{:4.3f}"): 24 | self.func_keys[key] = func 25 | self.add_val(key) 26 | self.formats[key] = form 27 | 28 | def update_stats(self, key, val, count = 1): 29 | if not key in self.vals: 30 | self.add_val(key) 31 | 32 | self.vals[key] += val 33 | self.vals_count[key] += count 34 | 35 | def compute_stats(self, a, b, size = 1): 36 | 37 | for k, func in self.func_keys.iteritems(): 38 | self.vals[k] += func(a, b) 39 | self.vals_count[k] += size 40 | 41 | def has_vals(self, k): 42 | if not k in self.vals_count: 43 | return False 44 | return self.vals_count[k] > 0 45 | 46 | def summarize_key(self, k): 47 | if self.has_vals(k): 48 | return self.vals[k] / self.vals_count[k] 49 | else: 50 | return 0 51 | 52 | def summarize(self, epoch = 0, verbose = True): 53 | 54 | if verbose: 55 | out = "\tEpoch[{:03d}]".format(epoch) 56 | for k in self.vals: 57 | if self.has_vals(k): 58 | out += (" / {} " + self.formats[k]).format(k, self.summarize_key(k)) 59 | print(out) 60 | -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer: 4 | def __init__(self, starting_msg = None): 5 | self.start = time.time() 6 | self.stage_start = self.start 7 | 8 | if starting_msg is not None: 9 | print(starting_msg, time.ctime(time.time())) 10 | 11 | 12 | def update_progress(self, progress): 13 | self.elapsed = time.time() - self.start 14 | self.est_total = self.elapsed / progress 15 | self.est_remaining = self.est_total - self.elapsed 16 | self.est_finish = int(self.start + self.est_total) 17 | 18 | 19 | def str_est_finish(self): 20 | return str(time.ctime(self.est_finish)) 21 | 22 | def get_stage_elapsed(self): 23 | return time.time() - self.stage_start 24 | 25 | def reset_stage(self): 26 | self.stage_start = time.time() 27 | --------------------------------------------------------------------------------