├── .gitignore
├── LICENSE
├── README.md
├── base_trainer.py
├── configs
│   ├── voc_resnet101.yaml
│   ├── voc_resnet38.yaml
│   ├── voc_resnet50.yaml
│   └── voc_vgg16.yaml
├── core
│   ├── __init__.py
│   └── config.py
├── data
│   ├── test.txt
│   ├── train_augvoc.txt
│   ├── train_voc.txt
│   └── val_voc.txt
├── datasets
│   ├── __init__.py
│   ├── pascal_voc.py
│   ├── pascal_voc_ms.py
│   ├── transforms.py
│   └── utils.py
├── eval_seg.py
├── figures
│   ├── results.gif
│   └── results.png
├── fonts
│   └── UbuntuMono-R.ttf
├── infer_val.py
├── launch
│   ├── eval_seg.sh
│   ├── infer_val.sh
│   ├── run_bsl_resnet101.sh
│   ├── run_bsl_resnet38.sh
│   ├── run_bsl_resnet50.sh
│   ├── run_bsl_vgg16.sh
│   ├── run_voc_resnet101.sh
│   ├── run_voc_resnet38.sh
│   ├── run_voc_resnet50.sh
│   └── run_voc_vgg16.sh
├── losses
│   └── __init__.py
├── models
│   ├── __init__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── base_net.py
│   │   ├── resnet38d.py
│   │   ├── resnets.py
│   │   └── vgg16d.py
│   ├── mods
│   │   ├── __init__.py
│   │   ├── aspp.py
│   │   ├── gci.py
│   │   ├── pamr.py
│   │   └── sg.py
│   └── stage_net.py
├── opts.py
├── requirements.txt
├── tools
│   └── convert_sbd.py
├── train.py
└── utils
    ├── __init__.py
    ├── checkpoints.py
    ├── collections.py
    ├── dcrf.py
    ├── inference_tools.py
    ├── metrics.py
    ├── pallete.py
    ├── stat_manager.py
    └── timer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | logs/
4 | data/
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction, and
10 | distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by the
13 | copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all other
16 | entities that control, are controlled by, or are under common control with
17 | that entity. For the purposes of this definition, "control" means (i) the
18 | power, direct or indirect, to cause the direction or management of such
19 | entity, whether by contract or otherwise, or (ii) ownership of
20 | fifty percent (50%) or more of the outstanding shares, or (iii) beneficial
21 | ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity exercising
24 | permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation source,
28 | and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical transformation
31 | or translation of a Source form, including but not limited to compiled
32 | object code, generated documentation, and conversions to
33 | other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or Object
36 | form, made available under the License, as indicated by a copyright notice
37 | that is included in or attached to the work (an example is provided in the
38 | Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object form,
41 | that is based on (or derived from) the Work and for which the editorial
42 | revisions, annotations, elaborations, or other modifications represent,
43 | as a whole, an original work of authorship. For the purposes of this
44 | License, Derivative Works shall not include works that remain separable
45 | from, or merely link (or bind by name) to the interfaces of, the Work and
46 | Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including the original
49 | version of the Work and any modifications or additions to that Work or
50 | Derivative Works thereof, that is intentionally submitted to Licensor for
51 | inclusion in the Work by the copyright owner or by an individual or
52 | Legal Entity authorized to submit on behalf of the copyright owner.
53 | For the purposes of this definition, "submitted" means any form of
54 | electronic, verbal, or written communication sent to the Licensor or its
55 | representatives, including but not limited to communication on electronic
56 | mailing lists, source code control systems, and issue tracking systems
57 | that are managed by, or on behalf of, the Licensor for the purpose of
58 | discussing and improving the Work, but excluding communication that is
59 | conspicuously marked or otherwise designated in writing by the copyright
60 | owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity on
63 | behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License.
67 |
68 | Subject to the terms and conditions of this License, each Contributor
69 | hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
70 | royalty-free, irrevocable copyright license to reproduce, prepare
71 | Derivative Works of, publicly display, publicly perform, sublicense,
72 | and distribute the Work and such Derivative Works in
73 | Source or Object form.
74 |
75 | 3. Grant of Patent License.
76 |
77 | Subject to the terms and conditions of this License, each Contributor
78 | hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
79 | royalty-free, irrevocable (except as stated in this section) patent
80 | license to make, have made, use, offer to sell, sell, import, and
81 | otherwise transfer the Work, where such license applies only to those
82 | patent claims licensable by such Contributor that are necessarily
83 | infringed by their Contribution(s) alone or by combination of their
84 | Contribution(s) with the Work to which such Contribution(s) was submitted.
85 | If You institute patent litigation against any entity (including a
86 | cross-claim or counterclaim in a lawsuit) alleging that the Work or a
87 | Contribution incorporated within the Work constitutes direct or
88 | contributory patent infringement, then any patent licenses granted to
89 | You under this License for that Work shall terminate as of the date such
90 | litigation is filed.
91 |
92 | 4. Redistribution.
93 |
94 | You may reproduce and distribute copies of the Work or Derivative Works
95 | thereof in any medium, with or without modifications, and in Source or
96 | Object form, provided that You meet the following conditions:
97 |
98 | 1. You must give any other recipients of the Work or Derivative Works a
99 | copy of this License; and
100 |
101 | 2. You must cause any modified files to carry prominent notices stating
102 | that You changed the files; and
103 |
104 | 3. You must retain, in the Source form of any Derivative Works that You
105 | distribute, all copyright, patent, trademark, and attribution notices from
106 | the Source form of the Work, excluding those notices that do not pertain
107 | to any part of the Derivative Works; and
108 |
109 | 4. If the Work includes a "NOTICE" text file as part of its distribution,
110 | then any Derivative Works that You distribute must include a readable copy
111 | of the attribution notices contained within such NOTICE file, excluding
112 | those notices that do not pertain to any part of the Derivative Works,
113 | in at least one of the following places: within a NOTICE text file
114 | distributed as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or, within a
116 | display generated by the Derivative Works, if and wherever such
117 | third-party notices normally appear. The contents of the NOTICE file are
118 | for informational purposes only and do not modify the License.
119 | You may add Your own attribution notices within Derivative Works that You
120 | distribute, alongside or as an addendum to the NOTICE text from the Work,
121 | provided that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and may
125 | provide additional or different license terms and conditions for use,
126 | reproduction, or distribution of Your modifications, or for any such
127 | Derivative Works as a whole, provided Your use, reproduction, and
128 | distribution of the Work otherwise complies with the conditions
129 | stated in this License.
130 |
131 | 5. Submission of Contributions.
132 |
133 | Unless You explicitly state otherwise, any Contribution intentionally
134 | submitted for inclusion in the Work by You to the Licensor shall be under
135 | the terms and conditions of this License, without any additional
136 | terms or conditions. Notwithstanding the above, nothing herein shall
137 | supersede or modify the terms of any separate license agreement you may
138 | have executed with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks.
141 |
142 | This License does not grant permission to use the trade names, trademarks,
143 | service marks, or product names of the Licensor, except as required for
144 | reasonable and customary use in describing the origin of the Work and
145 | reproducing the content of the NOTICE file.
146 |
147 | 7. Disclaimer of Warranty.
148 |
149 | Unless required by applicable law or agreed to in writing, Licensor
150 | provides the Work (and each Contributor provides its Contributions)
151 | on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
152 | either express or implied, including, without limitation, any warranties
153 | or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS
154 | FOR A PARTICULAR PURPOSE. You are solely responsible for determining the
155 | appropriateness of using or redistributing the Work and assume any risks
156 | associated with Your exercise of permissions under this License.
157 |
158 | 8. Limitation of Liability.
159 |
160 | In no event and under no legal theory, whether in tort
161 | (including negligence), contract, or otherwise, unless required by
162 | applicable law (such as deliberate and grossly negligent acts) or agreed
163 | to in writing, shall any Contributor be liable to You for damages,
164 | including any direct, indirect, special, incidental, or consequential
165 | damages of any character arising as a result of this License or out of
166 | the use or inability to use the Work (including but not limited to damages
167 | for loss of goodwill, work stoppage, computer failure or malfunction,
168 | or any and all other commercial damages or losses), even if such
169 | Contributor has been advised of the possibility of such damages.
170 |
171 | 9. Accepting Warranty or Additional Liability.
172 |
173 | While redistributing the Work or Derivative Works thereof, You may choose
174 | to offer, and charge a fee for, acceptance of support, warranty,
175 | indemnity, or other liability obligations and/or rights consistent with
176 | this License. However, in accepting such obligations, You may act only
177 | on Your own behalf and on Your sole responsibility, not on behalf of any
178 | other Contributor, and only if You agree to indemnify, defend, and hold
179 | each Contributor harmless for any liability incurred by, or claims
180 | asserted against, such Contributor by reason of your accepting any such
181 | warranty or additional liability.
182 |
183 | END OF TERMS AND CONDITIONS
184 |
185 | Copyright 2020 TU Darmstadt
186 |
187 | Author: Nikita Araslanov
188 |
189 | Licensed under the Apache License, Version 2.0 (the "License");
190 | you may not use this file except in compliance with the License.
191 | You may obtain a copy of the License at
192 |
193 | http://www.apache.org/licenses/LICENSE-2.0
194 |
195 | Unless required by applicable law or agreed to in writing, software
196 | distributed under the License is distributed on an "AS IS" BASIS,
197 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
198 | or implied. See the License for the specific language governing
199 | permissions and limitations under the License.
200 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Single-Stage Semantic Segmentation from Image Labels
2 |
3 | [License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
4 | [PyTorch](https://pytorch.org/)
5 |
6 | This repository contains the original implementation of our paper:
7 |
8 |
9 | **Single-stage Semantic Segmentation from Image Labels**
10 | *[Nikita Araslanov](https://arnike.github.io) and [Stefan Roth](https://www.visinf.tu-darmstadt.de/team_members/sroth/sroth.en.jsp)*
11 | CVPR 2020. [[pdf](https://openaccess.thecvf.com/content_CVPR_2020/papers/Araslanov_Single-Stage_Semantic_Segmentation_From_Image_Labels_CVPR_2020_paper.pdf)] [[supp](https://openaccess.thecvf.com/content_CVPR_2020/supplemental/Araslanov_Single-Stage_Semantic_Segmentation_CVPR_2020_supplemental.pdf)]
12 | [[arXiv](https://arxiv.org/abs/2005.08104)]
13 |
14 | Contact: Nikita Araslanov
15 |
16 |
17 | | ![Results](figures/results.gif) |
18 | |:---|
19 | | We attain competitive results by training a single network model for segmentation in a self-supervised fashion using only image-level annotations (one run of 20 epochs on Pascal VOC). |
20 |
21 | ### Setup
22 | 0. **Minimum requirements.** This project was originally developed with Python 3.6, PyTorch 1.0 and CUDA 9.0. The training requires at least two Titan X GPUs (12GB memory each).
23 | 1. **Setup your Python environment.** Please clone the repository and install the dependencies. We recommend using the Anaconda 3 distribution:
24 | ```
25 | conda create -n <environment_name> --file requirements.txt
26 | ```
27 | 2. **Download and link to the dataset.** We train our model on the original Pascal VOC 2012 augmented with the SBD data (10K images in total). Download the data from:
28 | - VOC: [Training/Validation (2GB .tar file)](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar)
29 | - SBD: [Training (1.4GB .tgz file)](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz)
30 |
31 | Link to the data:
32 | ```
33 | ln -s <your_path_to_voc> data/voc
34 | ln -s <your_path_to_sbd> data/sbd
35 | ```
36 | Make sure that the first directory in `data/voc` is `VOCdevkit`; the first directory in `data/sbd` is `benchmark_RELEASE`.
37 | 3. **Download pre-trained models.** Download the initial weights (pre-trained on ImageNet) for the backbones you are planning to use and place them into `./models/weights/`.
38 |
39 | | Backbone | Initial Weights | Comment |
40 | |:---:|:---:|:---:|
41 | | WideResNet38 | [ilsvrc-cls_rna-a1_cls1000_ep-0001.pth (402M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/models/ilsvrc-cls_rna-a1_cls1000_ep-0001.pth) | Converted from [mxnet](https://github.com/itijyou/ademxapp) |
42 | | VGG16 | [vgg16_20M.pth (79M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/models/vgg16_20M.pth) | Converted from [Caffe](http://liangchiehchen.com/projects/Init%20Models.html) |
43 | | ResNet50 | [resnet50-19c8e357.pth](https://download.pytorch.org/models/resnet50-19c8e357.pth) | PyTorch official |
44 | | ResNet101 | [resnet101-5d3b4d8f.pth](https://download.pytorch.org/models/resnet101-5d3b4d8f.pth) | PyTorch official |
45 |
46 |
47 | ### Training, Inference and Evaluation
48 | The directory `launch` contains template bash scripts for training, inference and evaluation.
49 |
50 | **Training.** For each run, you need to specify two variables, for example
51 | ```bash
52 | EXP=baselines
53 | RUN_ID=v01
54 | ```
55 | Running `bash ./launch/run_voc_resnet38.sh` will create a directory `./logs/pascal_voc/baselines/v01` with tensorboard events and will save snapshots into `./snapshots/pascal_voc/baselines/v01`.
56 |
57 | **Inference.** To generate the final masks, please use the script `./launch/infer_val.sh`. You will need to specify (see the example after this list):
58 | * `EXP` and `RUN_ID` you used for training;
59 | * `OUTPUT_DIR`: the path where the masks will be saved;
60 | * `FILELIST`: the file defining the data split;
61 | * `SNAPSHOT`: the model suffix in the format `e000Xs0.000`, e.g. `e020Xs0.928`;
62 | * (optionally) `EXTRA_ARGS`: additional arguments passed to the inference script.
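
A hypothetical invocation (placeholder values; depending on the template, the variables may instead be set at the top of the script):
```bash
EXP=baselines RUN_ID=v01 \
OUTPUT_DIR=./masks/val FILELIST=./data/val_voc.txt \
SNAPSHOT=e020Xs0.928 bash ./launch/infer_val.sh
```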
63 |
64 | **Evaluation.** To compute the IoU of the masks, please run `./launch/eval_seg.sh`. You will need to specify `SAVE_DIR`, which contains the masks, and `FILELIST`, which defines the split for evaluation.
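
A sketch with placeholder values (again, the template script may define these variables internally):
```bash
SAVE_DIR=./masks/val FILELIST=./data/val_voc.txt bash ./launch/eval_seg.sh
```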
65 |
66 | ### Pre-trained model
67 | For testing, we provide our pre-trained WideResNet38 model:
68 |
69 | | Backbone | Val | Val (+ CRF) | Link |
70 | |:---:|:---:|:---:|---:|
71 | | WideResNet38 | 59.7 | 62.7 | [model_enc_e020Xs0.928.pth (527M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/models/model_enc_e020Xs0.928.pth) |
72 |
73 | We also release the masks predicted by this model:
74 |
75 | | Split | IoU | IoU (+ CRF) | Link | Comment |
76 | |:---:|:---:|:---:|:---:|:---:|
77 | | train-clean (VOC+SBD) | 64.7 | 66.9 | [train_results_clean.tgz (2.9G)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/results/train_results_clean.tgz) | Reported IoU is for VOC |
78 | | val-clean | 63.4 | 65.3 | [val_results_clean.tgz (423M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/results/val_results_clean.tgz) | |
79 | | val | 59.7 | 62.7 | [val_results.tgz (427M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/results/val_results.tgz) | |
80 | | test | 62.7 | 64.3 | [test_results.tgz (368M)](https://download.visinf.tu-darmstadt.de/data/2020-cvpr-araslanov-1-stage-wseg/results/test_results.tgz) | |
81 |
82 | The suffix `-clean` means we used ground-truth image-level labels to remove masks of the categories not present in the image.
83 | These masks are commonly used as pseudo ground truth to train another segmentation model in a fully supervised regime.
84 |
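A minimal sketch (not part of this repository) of this clean-up step: predictions for classes absent from the image-level ground truth are mapped to background.

```python
import numpy as np

def clean_mask(mask, gt_classes):
    """mask: HxW array of predicted class indices (0 = background).
    gt_classes: iterable of class indices present in the image."""
    keep = set(gt_classes) | {0}             # always keep the background class
    cleaned = mask.copy()
    cleaned[~np.isin(mask, list(keep))] = 0  # absent classes -> background
    return cleaned
```
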
85 | ## Acknowledgements
86 | We thank the PyTorch team and Jiwoon Ahn for releasing his [code](https://github.com/jiwoon-ahn/psa), which helped in the early stages of this project.
87 |
88 | ## Citation
89 | We hope that you find this work useful. If you would like to acknowledge us, please, use the following citation:
90 | ```
91 | @InProceedings{Araslanov:2020:SSS,
92 | author = {Araslanov, Nikita and Roth, Stefan},
93 | title = {Single-Stage Semantic Segmentation From Image Labels},
94 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
95 | month = {June},
96 | pages = {4253--4262},
97 | year = {2020}
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/base_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import torch
4 | import math
5 | import numpy as np
6 | import torch.nn.functional as F
7 | import torchvision.utils as vutils
8 |
9 | from PIL import Image, ImageDraw, ImageFont
10 | from utils.checkpoints import Checkpoint
11 | from datasets.utils import Colorize  # used by _apply_cmap below
12 | try: # backward compatibility
13 | from tensorboardX import SummaryWriter
14 | except ImportError:
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 | from core.config import cfg, cfg_from_file, cfg_from_list
18 |
19 | class BaseTrainer(object):
20 |
21 | def __del__(self):
22 | # commented out, because hangs on exit
23 | # (presumably some bug with threading in TensorboardX)
24 | """
25 | if not self.quiet:
26 | self.writer.close()
27 | self.writer_val.close()
28 | """
29 | pass
30 |
31 | def __init__(self, args, quiet=False):
32 | self.args = args
33 | self.quiet = quiet
34 |
35 | # config
36 | # Reading the config
37 | if type(args.cfg_file) is str \
38 | and os.path.isfile(args.cfg_file):
39 |
40 | cfg_from_file(args.cfg_file)
41 | if args.set_cfgs is not None:
42 | cfg_from_list(args.set_cfgs)
43 |
44 | self.start_epoch = 0
45 | self.best_score = -1e16
46 | self.checkpoint = Checkpoint(args.snapshot_dir, max_n = 5)
47 |
48 | if not quiet:
49 | #self.model_id = "%s" % args.run
50 | logdir = os.path.join(args.logdir, 'train')
51 | logdir_val = os.path.join(args.logdir, 'val')
52 |
53 | self.writer = SummaryWriter(logdir)
54 | self.writer_val = SummaryWriter(logdir_val)
55 |
56 | def _define_checkpoint(self, name, model, optim):
57 | self.checkpoint.add_model(name, model, optim)
58 |
59 | def _load_checkpoint(self, suffix):
60 | if self.checkpoint.load(suffix):
61 | # loading the epoch and the best score
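# e.g. suffix "e020Xs0.928" parses to epoch = 20, best score = 0.928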
62 | tmpl = re.compile(r"^e(\d+)Xs([\.\d+\-]+)$")
63 | match = tmpl.match(suffix)
64 | if not match:
65 | print("Warning: epoch and score could not be recovered")
66 | return
67 | else:
68 | epoch, score = match.groups()
69 | self.start_epoch = int(epoch) + 1
70 | self.best_score = float(score)
71 |
72 | def checkpoint_epoch(self, score, epoch):
73 |
74 | if score > self.best_score:
75 | self.best_score = score
76 |
77 | print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch))
78 | suffix = "e{:03d}Xs{:4.3f}".format(epoch, score)
79 | self.checkpoint.checkpoint(suffix)
80 |
81 | return True
82 |
83 | def checkpoint_best(self, score, epoch):
84 |
85 | if score > self.best_score:
86 | print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch))
87 | self.best_score = score
88 |
89 | suffix = "e{:03d}Xs{:4.3f}".format(epoch, score)
90 | self.checkpoint.checkpoint(suffix)
91 |
92 | return True
93 |
94 | return False
95 |
96 | @staticmethod
97 | def get_optim(params, cfg):
98 |
99 | if not hasattr(torch.optim, cfg.OPT):
100 | print("Optimiser {} not supported".format(cfg.OPT))
101 | raise NotImplementedError
102 |
103 | optim = getattr(torch.optim, cfg.OPT)
104 |
105 | if cfg.OPT == 'Adam':
106 | upd = torch.optim.Adam(params, lr=cfg.LR, \
107 | betas=(cfg.BETA1, 0.999), \
108 | weight_decay=cfg.WEIGHT_DECAY)
109 | elif cfg.OPT == 'SGD':
110 | print("Using SGD >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}".format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY))
111 | upd = torch.optim.SGD(params, lr=cfg.LR, \
112 | momentum=cfg.MOMENTUM, \
113 | weight_decay=cfg.WEIGHT_DECAY)
114 |
115 | else:
116 | upd = optim(params, lr=cfg.LR)
117 |
118 | upd.zero_grad()
119 |
120 | return upd
121 |
122 | @staticmethod
123 | def set_lr(optim, lr):
124 | for param_group in optim.param_groups:
125 | param_group['lr'] = lr
126 |
127 |
128 | def _visualise_grid(self, x_all, labels, t, ious=None, tag="visualisation", scores=None):
129 |
130 | # adding the labels to images
131 | bs, ch, h, w = x_all.size()
132 | x_all_new = torch.zeros(bs, ch, h + 16, w)
133 | _, y_labels_idx = torch.max(labels, -1)
134 | classNamesOffset = len(self.classNames) - labels.size(1)
135 | classNames = self.classNames[classNamesOffset:]
136 | for b in range(bs):
137 | label_idx = labels[b]
138 | label_names = [name for i,name in enumerate(classNames) if label_idx[i].item()]
139 | label_name = ", ".join(label_names)
140 |
141 | ndarr = x_all[b].mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
142 | arr = np.zeros((16, w, ch), dtype=ndarr.dtype)
143 | ndarr = np.concatenate((arr, ndarr), 0)
144 | im = Image.fromarray(ndarr)
145 | draw = ImageDraw.Draw(im)
146 |
147 | font = ImageFont.truetype("fonts/UbuntuMono-R.ttf", 12)
148 |
149 | # draw.text((x, y),"Sample Text",(r,g,b))
150 | draw.text((5, 1), label_name, (255,255,255), font=font)
151 | im_np = np.array(im).astype(np.float32)
152 | x_all_new[b] = (torch.from_numpy(im_np)/255.0).permute(2,0,1)
153 |
154 | summary_grid = vutils.make_grid(x_all_new, nrow=1, padding=8, pad_value=0.9)
155 | self.writer.add_image(tag, summary_grid, t)
156 |
157 | def _apply_cmap(self, mask_idx, mask_conf):
158 | palette = self.trainloader.dataset.get_palette()
159 |
160 | masks = []
161 | col = Colorize()
162 | mask_conf = mask_conf.float() / 255.0
163 | for mask, conf in zip(mask_idx.split(1), mask_conf.split(1)):
164 | m = col(mask).float()
165 | m = m * conf
166 | masks.append(m[None, ...])
167 |
168 | return torch.cat(masks, 0)
169 |
170 | def _mask_rgb(self, masks, image_norm, alpha=0.3):
171 | # visualising masks
172 | masks_conf, masks_idx = torch.max(masks, 1)
173 | masks_conf = masks_conf - F.relu(masks_conf - 1)  # cap confidence at 1
174 |
175 | masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), masks_conf.cpu())
176 | return alpha * image_norm + (1 - alpha) * masks_idx_rgb
177 |
178 | def _init_norm(self):
179 | self.trainloader.dataset.set_norm(self.enc.normalize)
180 | self.valloader.dataset.set_norm(self.enc.normalize)
181 | self.trainloader_val.dataset.set_norm(self.enc.normalize)
182 |
--------------------------------------------------------------------------------
/configs/voc_resnet101.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 1
2 | DATASET:
3 | CROP_SIZE: 448
4 | SCALE_FROM: 0.9
5 | SCALE_TO: 1.0
6 | TRAIN:
7 | BATCH_SIZE: 16
8 | NUM_EPOCHS: 20
9 | NUM_WORKERS: 8
10 | PRETRAIN: 5
11 | NET:
12 | BACKBONE: "resnet101"
13 | MODEL: "ae"
14 | PRE_WEIGHTS_PATH: "./models/weights/resnet101-5d3b4d8f.pth"
15 | LR: 0.001
16 | OPT: "SGD"
17 | LOSS: "SoftMargin"
18 | WEIGHT_DECAY: 0.0005
19 | TEST:
20 | METHOD: "multiscale" # multiscale | crop
21 | DATA_ROOT: "/fastdata/naraslanov"
22 | FLIP: True
23 | BATCH_SIZE: 8 # 4 scales, +1 flip for each
24 | PAD_SIZE: [768, 768]
25 | SCALES: [1, 0.75, 1.25, 1.5]
26 | FP_CUT_SCORE: 0.3
27 | USE_GT_LABELS: True
28 |
--------------------------------------------------------------------------------
/configs/voc_resnet38.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 1
2 | DATASET:
3 | CROP_SIZE: 321
4 | SCALE_FROM: 0.9
5 | SCALE_TO: 1.0
6 | TRAIN:
7 | BATCH_SIZE: 16
8 | NUM_EPOCHS: 24
9 | NUM_WORKERS: 8
10 | PRETRAIN: 5
11 | NET:
12 | BACKBONE: "resnet38"
13 | MODEL: "ae"
14 | PRE_WEIGHTS_PATH: "./models/weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.pth"
15 | LR: 0.001
16 | OPT: "SGD"
17 | LOSS: "SoftMargin"
18 | WEIGHT_DECAY: 0.0005
19 | PAMR_ITER: 10
20 | FOCAL_LAMBDA: 0.01
21 | FOCAL_P: 3
22 | SG_PSI: 0.3
23 | TEST:
24 | METHOD: "multiscale"
25 | DATA_ROOT: "/fastdata/naraslanov"
26 | FLIP: True
27 | BATCH_SIZE: 8 # 4 scales, +1 flip for each
28 | PAD_SIZE: [1024, 1024]
29 | SCALES: [1, 0.5, 1.5, 2.0]
30 | FP_CUT_SCORE: 0.1
31 | BG_POW: 3
32 | USE_GT_LABELS: False
33 |
--------------------------------------------------------------------------------
/configs/voc_resnet50.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 1
2 | DATASET:
3 | CROP_SIZE: 448
4 | SCALE_FROM: 0.9
5 | SCALE_TO: 1.0
6 | TRAIN:
7 | BATCH_SIZE: 16
8 | NUM_EPOCHS: 20
9 | NUM_WORKERS: 8
10 | PRETRAIN: 5
11 | NET:
12 | BACKBONE: "resnet50"
13 | MODEL: "ae"
14 | PRE_WEIGHTS_PATH: "./models/weights/resnet50-19c8e357.pth"
15 | LR: 0.0005
16 | OPT: "SGD"
17 | LOSS: "SoftMargin"
18 | WEIGHT_DECAY: 0.0005
19 | TEST:
20 | METHOD: "multiscale" # multiscale | crop
21 | DATA_ROOT: "/fastdata/naraslanov"
22 | FLIP: True
23 | BATCH_SIZE: 8 # 4 scales, +1 flip for each
24 | PAD_SIZE: [768, 768]
25 | SCALES: [1, 0.75, 1.25, 1.5]
26 | FP_CUT_SCORE: 0.3
27 | USE_GT_LABELS: True
28 |
--------------------------------------------------------------------------------
/configs/voc_vgg16.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 1
2 | DATASET:
3 | CROP_SIZE: 448
4 | SCALE_FROM: 0.9
5 | SCALE_TO: 1.0
6 | TRAIN:
7 | BATCH_SIZE: 16
8 | NUM_EPOCHS: 32
9 | NUM_WORKERS: 8
10 | PRETRAIN: 5
11 | NET:
12 | BACKBONE: "vgg16"
13 | MODEL: "ae"
14 | PRE_WEIGHTS_PATH: "./models/weights/vgg16_20M.pth"
15 | LR: 0.001
16 | OPT: "SGD"
17 | LOSS: "SoftMargin"
18 | WEIGHT_DECAY: 0.0005
19 | TEST:
20 | METHOD: "multiscale" # multiscale | crop
21 | DATA_ROOT: "/fastdata/naraslanov"
22 | FLIP: True
23 | BATCH_SIZE: 8 # 4 scales, +1 flip for each
24 | PAD_SIZE: [768, 768]
25 | SCALES: [1, 0.75, 1.25, 1.5]
26 | FP_CUT_SCORE: 0.3
27 | USE_GT_LABELS: True
28 | BG_POW: 1
29 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/core/__init__.py
--------------------------------------------------------------------------------
/core/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 |
6 |
7 | import yaml
8 | import six
9 | import os
10 | import os.path as osp
11 | import copy
12 | from ast import literal_eval
13 |
14 | import numpy as np
15 | from packaging import version
16 |
17 | from utils.collections import AttrDict
18 |
19 | __C = AttrDict()
20 | # Consumers can get config by:
21 | # from core.config import cfg
22 | cfg = __C
23 |
24 | __C.NUM_GPUS = 1
25 | # Random note: avoid using '.ON' as a config key since yaml converts it to True;
26 | # prefer 'ENABLED' instead
27 |
28 | # ---------------------------------------------------------------------------- #
29 | # Training options
30 | # ---------------------------------------------------------------------------- #
31 | __C.TRAIN = AttrDict()
32 | __C.TRAIN.BATCH_SIZE = 20
33 | __C.TRAIN.NUM_EPOCHS = 15
34 | __C.TRAIN.NUM_WORKERS = 4
35 | __C.TRAIN.MASK_LOSS = 0.0
36 | __C.TRAIN.PRETRAIN = 5
37 |
38 | # ---------------------------------------------------------------------------- #
39 | # Inference options
40 | # ---------------------------------------------------------------------------- #
41 | __C.TEST = AttrDict()
42 | __C.TEST.METHOD = "multiscale" # multiscale | crop
43 | __C.TEST.DATA_ROOT = "/data/your_directory"
44 | __C.TEST.SCALES = [1, 0.5, 1.5, 2.0]
45 | __C.TEST.FLIP = True
46 | __C.TEST.PAD_SIZE = [1024, 1024]
47 | __C.TEST.CROP_SIZE = [448, 448]
48 | __C.TEST.CROP_GRID_SIZE = [2, 2]
49 | __C.TEST.BATCH_SIZE = 8
50 | __C.TEST.BG_POW = 3
51 | __C.TEST.NUM_CLASSES = 21
52 |
53 | # use ground-truth labels to remove
54 | # false positive masks
55 | __C.TEST.USE_GT_LABELS = False
56 |
57 | # if class confidence does not exceed this threshold
58 | # the mask is removed (count as false positive)
59 | # used only if MASKS.USE_GT_LABELS is False
60 | __C.TEST.FP_CUT_SCORE = 0.1
61 |
62 | # ---------------------------------------------------------------------------- #
63 | # Dataset options
64 | # ---------------------------------------------------------------------------- #
65 | __C.DATASET = AttrDict()
66 |
67 | __C.DATASET.CROP_SIZE = 321
68 | __C.DATASET.SCALE_FROM = 0.9
69 | __C.DATASET.SCALE_TO = 1.0
70 | __C.DATASET.PATH = "data/images"
71 |
72 | # ---------------------------------------------------------------------------- #
73 | # Network options
74 | # ---------------------------------------------------------------------------- #
75 | __C.NET = AttrDict()
76 | __C.NET.MODEL = 'vgg16'
77 | __C.NET.BACKBONE = 'resnet50'
78 | __C.NET.PRE_WEIGHTS_PATH = ""
79 | __C.NET.OPT = 'SGD'
80 | __C.NET.LR = 0.001
81 | __C.NET.BETA1 = 0.5
82 | __C.NET.MOMENTUM = 0.9
83 | __C.NET.WEIGHT_DECAY = 1e-5
84 | __C.NET.LOSS = 'SoftMargin'
85 | __C.NET.MASK_LOSS_BCE = 1.0
86 | __C.NET.BG_SCORE = 0.1 # background score (only for CAM)
87 | __C.NET.FOCAL_P = 3
88 | __C.NET.FOCAL_LAMBDA = 0.01
89 | __C.NET.PAMR_KERNEL = [1, 2, 4, 8, 12, 24]
90 | __C.NET.PAMR_ITER = 10
91 | __C.NET.SG_PSI = 0.3
92 |
93 | # Mask Inference
94 | __C.MASKS = AttrDict()
95 |
96 | # CRF options
97 | __C.MASKS.CRF = AttrDict()
98 | __C.MASKS.CRF.ALPHA_LOW = 4
99 | __C.MASKS.CRF.ALPHA_HIGH = 32
100 |
101 | # [Inferred value]
102 | __C.CUDA = False
103 |
104 | __C.DEBUG = False
105 |
106 | # [Inferred value]
107 | __C.PYTORCH_VERSION_LESS_THAN_040 = False
108 |
109 | def assert_and_infer_cfg(make_immutable=True):
110 | """Call this function in your script after you have finished setting all cfg
111 | values that are necessary (e.g., merging a config from a file, merging
112 | command line config options, etc.). By default, this function will also
113 | mark the global cfg as immutable to prevent changing the global cfg settings
114 | during script execution (which can lead to hard to debug errors or code
115 | that's harder to understand than is necessary).
116 | """
117 | if make_immutable:
118 | cfg.immutable(True)
119 |
120 |
121 | def merge_cfg_from_file(cfg_filename):
122 | """Load a yaml config file and merge it into the global config."""
123 | with open(cfg_filename, 'r') as f:
124 | if hasattr(yaml, "FullLoader"):
125 | yaml_cfg = AttrDict(yaml.load(f, Loader=yaml.FullLoader))
126 | else:
127 | yaml_cfg = AttrDict(yaml.load(f))
128 |
129 | _merge_a_into_b(yaml_cfg, __C)
130 |
131 | cfg_from_file = merge_cfg_from_file
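# Example (hypothetical): cfg_from_file('configs/voc_resnet38.yaml') overrides
# the defaults above with the values from that YAML file.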
132 |
133 |
134 | def merge_cfg_from_cfg(cfg_other):
135 | """Merge `cfg_other` into the global config."""
136 | _merge_a_into_b(cfg_other, __C)
137 |
138 |
139 | def merge_cfg_from_list(cfg_list):
140 | """Merge config keys, values in a list (e.g., from command line) into the
141 | global config. For example, `cfg_list = ['TEST.NMS', 0.5]`.
142 | """
143 | assert len(cfg_list) % 2 == 0
144 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
145 | key_list = full_key.split('.')
146 | d = __C
147 | for subkey in key_list[:-1]:
148 | assert subkey in d, 'Non-existent key: {}'.format(full_key)
149 | d = d[subkey]
150 | subkey = key_list[-1]
151 | assert subkey in d, 'Non-existent key: {}'.format(full_key)
152 | value = _decode_cfg_value(v)
153 | value = _check_and_coerce_cfg_value_type(
154 | value, d[subkey], subkey, full_key
155 | )
156 | d[subkey] = value
157 |
158 | cfg_from_list = merge_cfg_from_list
159 |
160 |
161 | def _merge_a_into_b(a, b, stack=None):
162 | """Merge config dictionary a into config dictionary b, clobbering the
163 | options in b whenever they are also specified in a.
164 | """
165 | assert isinstance(a, AttrDict), 'Argument `a` must be an AttrDict'
166 | assert isinstance(b, AttrDict), 'Argument `b` must be an AttrDict'
167 |
168 | for k, v_ in a.items():
169 | full_key = '.'.join(stack) + '.' + k if stack is not None else k
170 | # a must specify keys that are in b
171 | if k not in b:
172 | raise KeyError('Non-existent config key: {}'.format(full_key))
173 |
174 | v = copy.deepcopy(v_)
175 | v = _decode_cfg_value(v)
176 | v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
177 |
178 | # Recursively merge dicts
179 | if isinstance(v, AttrDict):
180 | try:
181 | stack_push = [k] if stack is None else stack + [k]
182 | _merge_a_into_b(v, b[k], stack=stack_push)
183 | except BaseException:
184 | raise
185 | else:
186 | b[k] = v
187 |
188 |
189 | def _decode_cfg_value(v):
190 | """Decodes a raw config value (e.g., from a yaml config files or command
191 | line argument) into a Python object.
192 | """
193 | # Configs parsed from raw yaml will contain dictionary keys that need to be
194 | # converted to AttrDict objects
195 | if isinstance(v, dict):
196 | return AttrDict(v)
197 | # All remaining processing is only applied to strings
198 | if not isinstance(v, six.string_types):
199 | return v
200 | # Try to interpret `v` as a:
201 | # string, number, tuple, list, dict, boolean, or None
202 | try:
203 | v = literal_eval(v)
204 | # The following two excepts allow v to pass through when it represents a
205 | # string.
206 | #
207 | # Longer explanation:
208 | # The type of v is always a string (before calling literal_eval), but
209 | # sometimes it *represents* a string and other times a data structure, like
210 | # a list. In the case that v represents a string, what we got back from the
211 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
212 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
213 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
214 | # will raise a SyntaxError.
215 | except ValueError:
216 | pass
217 | except SyntaxError:
218 | pass
219 | return v
220 |
221 |
222 | def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):
223 | """Checks that `value_a`, which is intended to replace `value_b` is of the
224 | right type. The type is correct if it matches exactly or is one of a few
225 | cases in which the type can be easily coerced.
226 | """
227 | # The types must match (with some exceptions)
228 | type_b = type(value_b)
229 | type_a = type(value_a)
230 | if type_a is type_b:
231 | return value_a
232 |
233 | # Exceptions: numpy arrays, strings, tuple<->list
234 | if isinstance(value_b, np.ndarray):
235 | value_a = np.array(value_a, dtype=value_b.dtype)
236 | elif isinstance(value_b, six.string_types):
237 | value_a = str(value_a)
238 | elif isinstance(value_a, tuple) and isinstance(value_b, list):
239 | value_a = list(value_a)
240 | elif isinstance(value_a, list) and isinstance(value_b, tuple):
241 | value_a = tuple(value_a)
242 | else:
243 | raise ValueError(
244 | 'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
245 | 'key: {}'.format(type_b, type_a, value_b, value_a, full_key)
246 | )
247 | return value_a
248 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data
2 | from .pascal_voc import VOCSegmentation
3 |
4 | datasets = {
5 | 'pascal_voc': VOCSegmentation
6 | }
7 |
8 | def get_num_classes(args):
9 | return datasets[args.dataset.lower()].NUM_CLASS
10 |
11 | def get_class_names(args):
12 | return datasets[args.dataset.lower()].CLASSES
13 |
14 | def get_dataloader(args, cfg, split, batch_size=None, test_mode=None):
15 | assert split in ('train', 'train_voc', 'val'), "Unknown split '{}'".format(split)
16 |
17 | dataset_name = args.dataset.lower()
18 | dataset_cls = datasets[dataset_name]
19 | dataset = dataset_cls(cfg, split, test_mode)
20 |
21 | kwargs = {'num_workers': args.workers, 'pin_memory': True}
22 | shuffle, drop_last = [True, True] if split == 'train' else [False, False]
23 |
24 | if batch_size is None:
25 | batch_size = cfg.TRAIN.BATCH_SIZE
26 |
27 | return data.DataLoader(dataset, batch_size=batch_size,
28 | drop_last=drop_last, shuffle=shuffle,
29 | **kwargs)
30 |
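# Example usage (hypothetical `args` with .dataset = 'pascal_voc' and .workers set):
#   train_loader = get_dataloader(args, cfg, 'train')
#   for image, labels, name in train_loader:
#       ...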
--------------------------------------------------------------------------------
/datasets/pascal_voc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | from PIL import Image, ImagePalette
6 | from .utils import colormap
7 | import datasets.transforms as tf
8 |
9 | class PascalVOC(Dataset):
10 |
11 | CLASSES = [
12 | 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
13 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
14 | 'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
15 | 'tv/monitor', 'ambiguous'
16 | ]
17 |
18 | CLASS_IDX = {
19 | 'background': 0,
20 | 'aeroplane': 1,
21 | 'bicycle': 2,
22 | 'bird': 3,
23 | 'boat': 4,
24 | 'bottle': 5,
25 | 'bus': 6,
26 | 'car': 7,
27 | 'cat': 8,
28 | 'chair': 9,
29 | 'cow': 10,
30 | 'diningtable': 11,
31 | 'dog': 12,
32 | 'horse': 13,
33 | 'motorbike': 14,
34 | 'person': 15,
35 | 'potted-plant': 16,
36 | 'sheep': 17,
37 | 'sofa': 18,
38 | 'train': 19,
39 | 'tv/monitor': 20,
40 | 'ambiguous': 255
41 | }
42 |
43 | CLASS_IDX_INV = {
44 | 0: 'background',
45 | 1: 'aeroplane',
46 | 2: 'bicycle',
47 | 3: 'bird',
48 | 4: 'boat',
49 | 5: 'bottle',
50 | 6: 'bus',
51 | 7: 'car',
52 | 8: 'cat',
53 | 9: 'chair',
54 | 10: 'cow',
55 | 11: 'diningtable',
56 | 12: 'dog',
57 | 13: 'horse',
58 | 14: 'motorbike',
59 | 15: 'person',
60 | 16: 'potted-plant',
61 | 17: 'sheep',
62 | 18: 'sofa',
63 | 19: 'train',
64 | 20: 'tv/monitor',
65 | 255: 'ambiguous'}
66 |
67 | NUM_CLASS = 21
68 |
69 | MEAN = (0.485, 0.456, 0.406)
70 | STD = (0.229, 0.224, 0.225)
71 |
72 | def __init__(self):
73 | super().__init__()
74 | self._init_palette()
75 |
76 | def _init_palette(self):
77 | self.cmap = colormap()
78 | self.palette = ImagePalette.ImagePalette()
79 | for rgb in self.cmap:
80 | self.palette.getcolor(rgb)
81 |
82 | def get_palette(self):
83 | return self.palette
84 |
85 | def denorm(self, image):
86 |
87 | if image.dim() == 3:
88 | assert image.dim() == 3, "Expected image [CxHxW]"
89 | assert image.size(0) == 3, "Expected RGB image [3xHxW]"
90 |
91 | for t, m, s in zip(image, self.MEAN, self.STD):
92 | t.mul_(s).add_(m)
93 | elif image.dim() == 4:
94 | # batch mode
95 | assert image.size(1) == 3, "Expected RGB image [Bx3xHxW]"
96 |
97 | for t, m, s in zip((0,1,2), self.MEAN, self.STD):
98 | image[:, t, :, :].mul_(s).add_(m)
99 |
100 | return image
101 |
102 |
103 | class VOCSegmentation(PascalVOC):
104 |
105 | def __init__(self, cfg, split, test_mode, root=os.path.expanduser('./data')):
106 | super(VOCSegmentation, self).__init__()
107 |
108 | self.cfg = cfg
109 | self.root = root
110 | self.split = split
111 | self.test_mode = test_mode
112 |
113 | # train/val/test splits are pre-cut
114 | if self.split == 'train':
115 | _split_f = os.path.join(self.root, 'train_augvoc.txt')
116 | elif self.split == 'train_voc':
117 | _split_f = os.path.join(self.root, 'train_voc.txt')
118 | elif self.split == 'val':
119 | _split_f = os.path.join(self.root, 'val_voc.txt')
120 | elif self.split == 'test':
121 | _split_f = os.path.join(self.root, 'test.txt')
122 | else:
123 | raise RuntimeError('Unknown dataset split.')
124 |
125 | assert os.path.isfile(_split_f), "%s not found" % _split_f
126 |
127 | self.images = []
128 | self.masks = []
129 | with open(_split_f, "r") as lines:
130 | for line in lines:
131 | _image, _mask = line.strip("\n").split(' ')
132 | _image = os.path.join(self.root, _image)
133 | assert os.path.isfile(_image), '%s not found' % _image
134 | self.images.append(_image)
135 |
136 | if self.split != 'test':
137 | _mask = os.path.join(self.root, _mask.lstrip('/'))
138 | assert os.path.isfile(_mask), '%s not found' % _mask
139 | self.masks.append(_mask)
140 |
141 | if self.split != 'test':
142 | assert (len(self.images) == len(self.masks))
143 | if self.split == 'train':
144 | assert len(self.images) == 10582
145 | elif self.split == 'val':
146 | assert len(self.images) == 1449
147 |
148 | self.transform = tf.Compose([tf.MaskRandResizedCrop(self.cfg.DATASET), \
149 | tf.MaskHFlip(), \
150 | tf.MaskColourJitter(p = 1.0), \
151 | tf.MaskNormalise(self.MEAN, self.STD), \
152 | tf.MaskToTensor()])
153 |
154 | def __len__(self):
155 | return len(self.images)
156 |
157 | def __getitem__(self, index):
158 |
159 | image = Image.open(self.images[index]).convert('RGB')
160 | mask = Image.open(self.masks[index])
161 |
162 | unique_labels = np.unique(mask)
163 |
164 | # ambiguous
165 | if unique_labels[-1] == self.CLASS_IDX['ambiguous']:
166 | unique_labels = unique_labels[:-1]
167 |
168 | # ignoring BG
169 | labels = torch.zeros(self.NUM_CLASS - 1)
170 | if unique_labels[0] == self.CLASS_IDX['background']:
171 | unique_labels = unique_labels[1:]
172 | unique_labels -= 1 # shifting since no BG class
173 |
174 | assert unique_labels.size > 0, 'No labels found in %s' % self.masks[index]
175 | labels[unique_labels.tolist()] = 1
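# e.g. a mask containing {0, 12, 15} (background, dog, person)
# yields a 20-dim label vector with ones at indices 11 and 14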
176 |
177 | # general resize, normalize and toTensor
178 | image, mask = self.transform(image, mask)
179 |
180 | return image, labels, os.path.basename(self.images[index])
181 |
182 | @property
183 | def pred_offset(self):
184 | return 0
185 |
--------------------------------------------------------------------------------
/datasets/pascal_voc_ms.py:
--------------------------------------------------------------------------------
1 | """
2 | Multi-scale dataloader
3 | Credit PSA
4 | """
5 |
6 | from PIL import Image
7 | from .pascal_voc import PascalVOC
8 |
9 | import math
10 | import numpy as np
11 | import torch
12 | import os.path
13 | import torchvision.transforms.functional as F
14 |
15 |
16 | def load_img_name_list(dataset_path, index = 0):
17 |
18 | img_gt_name_list = open(dataset_path).read().splitlines()
19 | img_name_list = [img_gt_name.split(' ')[index].strip('/') for img_gt_name in img_gt_name_list]
20 |
21 | return img_name_list
22 |
23 | def load_label_name_list(dataset_path):
24 | return load_img_name_list(dataset_path, index = 1)
25 |
26 | class VOC12ImageDataset(PascalVOC):
27 |
28 | def __init__(self, img_name_list_path, voc12_root):
29 | super().__init__()
30 |
31 | self.img_name_list = load_img_name_list(img_name_list_path)
32 | self.voc12_root = voc12_root
33 | self.batch_size = 1
34 |
35 | def __len__(self):
36 | return len(self.img_name_list)
37 |
38 | def __getitem__(self, idx):
39 | fullpath = os.path.join(self.voc12_root, self.img_name_list[idx])
40 | img = Image.open(fullpath).convert("RGB")
41 | return fullpath, img
42 |
43 | class VOC12ClsDataset(VOC12ImageDataset):
44 |
45 | def __init__(self, img_name_list_path, voc12_root):
46 | super(VOC12ClsDataset, self).__init__(img_name_list_path, voc12_root)
47 | self.label_list = load_label_name_list(img_name_list_path)
48 |
49 | def __len__(self):
50 | return self.batch_size * len(self.img_name_list)
51 |
52 | def _pad(self, image):
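# pads the image centrally to self.pad_size (H, W); returns the padded image,
# a mask marking the padded pixels, and the top/left padding offsets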
53 | w, h = image.size
54 |
55 | pad_mask = Image.new("L", image.size)
56 | pad_height = self.pad_size[0] - h
57 | pad_width = self.pad_size[1] - w
58 |
59 | assert pad_height >= 0 and pad_width >= 0
60 |
61 | pad_l = max(0, pad_width // 2)
62 | pad_r = max(0, pad_width - pad_l)
63 | pad_t = max(0, pad_height // 2)
64 | pad_b = max(0, pad_height - pad_t)
65 |
66 | image = F.pad(image, (pad_l, pad_t, pad_r, pad_b), fill=0, padding_mode="constant")
67 | pad_mask = F.pad(pad_mask, (pad_l, pad_t, pad_r, pad_b), fill=1, padding_mode="constant")
68 |
69 | return image, pad_mask, [pad_t, pad_l]
70 |
71 | def __getitem__(self, idx):
72 | name, img = super(VOC12ClsDataset, self).__getitem__(idx)
73 |
74 | label_fullpath = self.label_list[idx]
75 | assert len(label_fullpath) < 256, "Expected label path less than 256 for padding"
76 |
77 | mask = Image.open(os.path.join(self.voc12_root, label_fullpath))
78 | mask = np.array(mask)
79 |
80 | labels = torch.zeros(self.NUM_CLASS - 1)
81 |
82 | # it will also be sorted
83 | unique_labels = np.unique(mask)
84 |
85 | # ambiguous
86 | if unique_labels[-1] == self.CLASS_IDX['ambiguous']:
87 | unique_labels = unique_labels[:-1]
88 |
89 | # background
90 | if unique_labels[0] == self.CLASS_IDX['background']:
91 | unique_labels = unique_labels[1:]
92 |
93 | assert unique_labels.size > 0, 'No labels found in %s' % label_fullpath
94 | unique_labels -= 1 # shifting since no BG class
95 | labels[unique_labels.tolist()] = 1
96 |
97 | return name, img, labels, mask.astype(np.int64)
98 |
99 |
100 | class MultiscaleLoader(VOC12ClsDataset):
101 |
102 | def __init__(self, img_list, cfg, transform):
103 | super().__init__(img_list, cfg.DATA_ROOT)
104 |
105 | self.scales = cfg.SCALES
106 | self.pad_size = cfg.PAD_SIZE
107 | self.use_flips = cfg.FLIP
108 | self.transform = transform
109 |
110 | self.batch_size = len(self.scales)
111 | if self.use_flips:
112 | self.batch_size *= 2
113 |
114 | print("Inference batch size: ", self.batch_size)
115 | assert self.batch_size == cfg.BATCH_SIZE
116 |
117 | def __getitem__(self, idx):
118 | im_idx = idx // self.batch_size
119 | sub_idx = idx % self.batch_size
120 |
121 | scale = self.scales[sub_idx // (2 if self.use_flips else 1)]
122 | flip = self.use_flips and sub_idx % 2
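# e.g. with 4 scales and flipping enabled (batch of 8):
# sub_idx 0,1 -> scales[0] (original, flipped); 2,3 -> scales[1]; ...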
123 |
124 | name, img, label, mask = super().__getitem__(im_idx)
125 |
126 | target_size = (int(round(img.size[0]*scale)),
127 | int(round(img.size[1]*scale)))
128 |
129 | s_img = img.resize(target_size, resample=Image.CUBIC)
130 |
131 | if flip:
132 | s_img = F.hflip(s_img)
133 |
134 | w, h = s_img.size
135 | im_msc, ignore, pads_tl = self._pad(s_img)
136 | pad_t, pad_l = pads_tl
137 |
138 | im_msc = self.transform(im_msc)
139 | img = F.to_tensor(self.transform(img))
140 |
141 | pads = torch.Tensor([pad_t, pad_l, h, w])
142 |
143 | ignore = np.array(ignore).astype(im_msc.dtype)[..., np.newaxis]
144 | im_msc = F.to_tensor(im_msc * (1 - ignore))
145 |
146 | return name, img, im_msc, pads, label, mask
147 |
148 |
149 | class CropLoader(VOC12ClsDataset):
150 |
151 | def __init__(self, img_list, cfg, transform):
152 | super().__init__(img_list, cfg.DATA_ROOT)
153 |
154 | self.use_flips = cfg.FLIP
155 | self.transform = transform
156 |
157 | self.grid_h, self.grid_w = cfg.CROP_GRID_SIZE
158 | self.crop_h, self.crop_w = cfg.CROP_SIZE
159 | self.pad_size = cfg.PAD_SIZE
160 |
161 | self.stride_h = int(math.ceil(self.pad_size[0] / self.grid_h))
162 | self.stride_w = int(math.ceil(self.pad_size[1] / self.grid_w))
163 |
164 | assert self.stride_h <= self.crop_h and \
165 | self.stride_w <= self.crop_w
166 |
167 | self.batch_size = self.grid_h * self.grid_w
168 | if self.use_flips:
169 | self.batch_size *= 2
170 |
171 | print("Inference batch size: ", self.batch_size)
172 | assert self.batch_size == cfg.BATCH_SIZE
173 |
174 |
175 | def __getitem__(self, index):
176 | image_index = index // self.batch_size
177 | batch_index = index % self.batch_size
178 | grid_index = batch_index // (2 if self.use_flips else 1)
179 |
180 | index_h = grid_index // self.grid_w
181 | index_w = grid_index % self.grid_w
182 | flip = self.use_flips and batch_index % 2 == 0
183 |
184 | name, image, label, mask = super().__getitem__(image_index)
185 |
186 | image_pad, pad_mask, pads = self._pad(image)
187 | assert image_pad.size[0] == self.pad_size[1] and \
188 | image_pad.size[1] == self.pad_size[0]
189 |
190 | s_h = index_h * self.stride_h
191 | e_h = min(s_h + self.crop_h, self.pad_size[0])
192 | s_h = e_h - self.crop_h
193 |
194 | s_w = index_w * self.stride_w
195 | e_w = min(s_w + self.crop_w, self.pad_size[1])
196 | s_w = e_w - self.crop_w
197 |
198 | image_pad = self.transform(image_pad)
199 | pad_mask = np.array(pad_mask).astype(image_pad.dtype)[..., np.newaxis]
200 | image_pad *= 1 - pad_mask
201 |
202 | image_pad = F.to_tensor(image_pad)
203 | image_crop = image_pad[:, s_h:e_h, s_w:e_w].clone()
204 |
205 | pads = torch.LongTensor([s_h, e_h, s_w, e_w] + pads)
206 |
207 | if flip:
208 | image_crop = image_crop.flip(-1)
209 |
210 | image = F.to_tensor(self.transform(image))
211 |
212 | return name, image, image_crop, pads, label, mask
213 |
--------------------------------------------------------------------------------
/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 |
5 | from PIL import Image
6 | import torchvision.transforms as tf
7 | import torchvision.transforms.functional as F
8 |
9 | from functools import partial
10 |
11 | class Compose:
12 | def __init__(self, segtransform):
13 | self.segtransform = segtransform
14 |
15 | def __call__(self, image, label):
16 | # allow for intermediate representations
17 | result = (image, label)
18 | for t in self.segtransform:
19 | result = t(*result)
20 |
21 | # ensure we have just the image
22 | # and the label in the end
23 | image, label = result
24 | return image, label
25 |
26 | class MaskRandResizedCrop:
27 |
28 | def __init__(self, cfg):
29 | self.rnd_crop = tf.RandomResizedCrop(cfg.CROP_SIZE, \
30 | scale=(cfg.SCALE_FROM, \
31 | cfg.SCALE_TO))
32 |
33 | def get_params(self, image):
34 | return self.rnd_crop.get_params(image, \
35 | self.rnd_crop.scale, \
36 | self.rnd_crop.ratio)
37 |
38 | def __call__(self, image, labels):
39 |
40 | i, j, h, w = self.get_params(image)
41 |
42 | image = F.resized_crop(image, i, j, h, w, self.rnd_crop.size, Image.CUBIC)
43 | labels = F.resized_crop(labels, i, j, h, w, self.rnd_crop.size, Image.NEAREST)
44 |
45 | return image, labels
46 |
47 | class MaskHFlip:
48 |
49 | def __init__(self, p=0.5):
50 | self.p = p
51 |
52 | def __call__(self, image, mask):
53 |
54 | if random.random() < self.p:
55 | image = F.hflip(image)
56 | mask = F.hflip(mask)
57 |
58 | return image, mask
59 |
60 | class MaskNormalise:
61 |
62 | def __init__(self, mean, std):
63 | self.norm = tf.Normalize(mean, std)
64 |
65 | def __toByteTensor(self, pic):
66 | return torch.from_numpy(np.array(pic, np.int32, copy=False))
67 |
68 | def __call__(self, image, labels):
69 |
70 | image = F.to_tensor(image)
71 | image = self.norm(image)
72 | labels = self.__toByteTensor(labels)
73 |
74 | return image, labels
75 |
76 | class MaskToTensor:
77 |
78 | def __call__(self, image, mask):
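# converts an HxW mask of class ids into a 21xHxW one-hot float tensor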
79 | gt_labels = torch.arange(0, 21)
80 | gt_labels = gt_labels.unsqueeze(-1).unsqueeze(-1)
81 | mask = mask.unsqueeze(0).type_as(gt_labels)
82 | mask = torch.eq(mask, gt_labels).float()
83 | return image, mask
84 |
85 | class MaskColourJitter:
86 |
87 | def __init__(self, p=0.5, brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1):
88 | self.p = p
89 | self.jitter = tf.ColorJitter(brightness=brightness, \
90 | contrast=contrast, \
91 | saturation=saturation, \
92 | hue=hue)
93 |
94 | def __call__(self, image, mask):
95 |
96 | if random.random() < self.p:
97 | image = self.jitter(image)
98 |
99 | return image, mask
100 |
--------------------------------------------------------------------------------
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def colormap(N=256):
5 | def bitget(byteval, idx):
6 | return ((byteval & (1 << idx)) != 0)
7 |
8 | dtype = 'uint8'
9 | cmap = []
10 | for i in range(N):
11 | r = g = b = 0
12 | c = i
13 | for j in range(8):
14 | r = r | (bitget(c, 0) << 7-j)
15 | g = g | (bitget(c, 1) << 7-j)
16 | b = b | (bitget(c, 2) << 7-j)
17 | c = c >> 3
18 |
19 | cmap.append((r, g, b))
20 |
21 | return cmap
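# e.g. colormap()[0] == (0, 0, 0) (background), colormap()[1] == (128, 0, 0) (aeroplane)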
22 |
23 | """
24 | Python implementation of the color map function for the PASCAL VOC data set.
25 | Official Matlab version can be found in the PASCAL VOC devkit
26 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit
27 | """
28 |
29 | def uint82bin(n, count=8):
30 | """returns the binary of integer n, count refers to amount of bits"""
31 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
32 |
33 | def labelcolormap(N):
34 | cmap = np.zeros((N, 3), dtype=np.uint8)
35 | for i in range(N):
36 | r = 0
37 | g = 0
38 | b = 0
39 | id = i
40 | for j in range(7):
41 | str_id = uint82bin(id)
42 | r = r ^ (np.uint8(str_id[-1]) << (7-j))
43 | g = g ^ (np.uint8(str_id[-2]) << (7-j))
44 | b = b ^ (np.uint8(str_id[-3]) << (7-j))
45 | id = id >> 3
46 | cmap[i, 0] = r
47 | cmap[i, 1] = g
48 | cmap[i, 2] = b
49 | return cmap
50 |
51 | class Colorize(object):
52 |
53 | def __init__(self, n=22):
54 | self.cmap = labelcolormap(n)
55 | self.cmap = torch.from_numpy(self.cmap[:n])
56 |
57 | def __call__(self, gray_image):
58 | size = gray_image.size()
59 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
60 |
61 | for label in range(0, len(self.cmap)):
62 | mask = (label == gray_image[0]).cpu()
63 | color_image[0][mask] = self.cmap[label][0]
64 | color_image[1][mask] = self.cmap[label][1]
65 | color_image[2][mask] = self.cmap[label][2]
66 |
67 | return color_image
68 |
69 |
70 |
--------------------------------------------------------------------------------
/eval_seg.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluating the masks
3 |
4 | TODO:
5 | Parallelise with
6 |
7 | from multiprocessing import Pool
8 |
9 | ConfM = ConfusionMatrix(class_num)
10 | f = ConfM.generateM
11 | pool = Pool()
12 | m_list = pool.map(f, data_list)
13 | pool.close()
14 | pool.join()
15 |
16 | """
17 |
18 | import sys
19 | import os
20 | import numpy as np
21 | import argparse
22 | import scipy
23 |
24 | from tqdm import tqdm
25 | from datasets.pascal_voc import PascalVOC
26 | from PIL import Image
27 | from utils.metrics import Metric
28 |
29 | # Defining the command-line arguments
30 | parser = argparse.ArgumentParser(description="Mask Evaluation")
31 |
32 | parser.add_argument("--data", type=str, default='./data/annotation',
33 | help="The prefix for data directory")
34 | parser.add_argument("--filelist", type=str, default='./data/val.txt',
35 | help="A text file containing the paths to masks")
36 | parser.add_argument("--masks", type=str, default='./masks',
37 | help="A path to generated masks")
38 | parser.add_argument("--oracle-from", type=str, default="",
39 | help="Use GT mask but down- then upscale them")
40 | parser.add_argument("--log-scores", type=str, default='./scores.log',
41 | help="Logging scores for invididual images")
42 |
43 | def check_args(args):
44 | """Check the files/directories exist"""
45 |
46 | assert os.path.isdir(args.data), \
47 | "Directory {} does not exist".format(args.data)
48 | assert os.path.isfile(args.filelist), \
49 | "File {} does not exist".format(args.filelist)
50 | if len(args.oracle_from) > 0:
51 | vals = args.oracle_from.split('x')
52 | assert len(vals) == 2, "HxW expected"
53 | h, w = vals
54 | assert int(h) > 2, "Meaningless resolution"
55 | assert int(w) > 2, "Meaningless resolution"
56 | else:
57 | assert os.path.isdir(args.masks), \
58 | "Directory {} does not exist".format(args.masks)
59 |
60 | def format_num(x):
61 | return round(x*100., 1)
62 |
63 |
64 | def get_stats(M, i):
65 |
66 | TP = M[i, i]
67 | FN = np.sum(M[i, :]) - TP # false negatives
68 | FP = np.sum(M[:, i]) - TP # false positives
69 |
70 | return TP, FN, FP
71 |
72 | def summarise_one(class_stats, M, name, labels):
73 |
74 | for i in labels:
75 |
76 | # skipping the ambiguous
77 | if i == 255:
78 | continue
79 |
80 |         # per-class counts
81 | TP, FN, FP = get_stats(M, i)
82 | score = TP - FN - FP
83 |
84 | class_stats[i].append((name, score))
85 |
86 | def summarise_per_class(class_stats, filename):
87 |
88 | data = ""
89 | for cat in PascalVOC.CLASSES:
90 |
91 | if cat == "ambiguous":
92 | continue
93 |
94 | i = PascalVOC.CLASS_IDX[cat]
95 | sorted_by_score = sorted(class_stats[i], key=lambda x: -x[1])
96 | data += cat + "\n"
97 | for name, score in sorted_by_score:
98 | data += "{:05d} | {}\n".format(int(score), name)
99 |
100 | with open(filename, 'w') as f:
101 | f.write(data)
102 |
103 | def summarise_stats(M):
104 |
105 | eps = 1e-20
106 |
107 | mean = Metric()
108 | mean.add_metric(Metric.IoU)
109 | mean.add_metric(Metric.Precision)
110 | mean.add_metric(Metric.Recall)
111 |
112 | mean_bkg = Metric()
113 | mean_bkg.add_metric(Metric.IoU)
114 | mean_bkg.add_metric(Metric.Precision)
115 | mean_bkg.add_metric(Metric.Recall)
116 |
117 | head_fmt = "{:>12} | {:>5}" + " | {:>5}"*3
118 | row_fmt = "{:>12} | {:>5}" + " | {:>5.1f}"*3
119 | split = "-"*44
120 |
121 | def print_row(fmt, row):
122 | print(fmt.format(*row))
123 |
124 | print_row(head_fmt, ("Class", "#", "IoU", "Pr", "Re"))
125 | print(split)
126 |
127 | for cat in PascalVOC.CLASSES:
128 |
129 | if cat == "ambiguous":
130 | continue
131 |
132 | i = PascalVOC.CLASS_IDX[cat]
133 |
134 | TP, FN, FP = get_stats(M, i)
135 |
136 | iou = 100. * TP / (eps + FN + FP + TP)
137 | pr = 100. * TP / (eps + TP + FP)
138 | re = 100. * TP / (eps + TP + FN)
139 |
140 | mean_bkg.update_value(Metric.IoU, iou)
141 | mean_bkg.update_value(Metric.Precision, pr)
142 | mean_bkg.update_value(Metric.Recall, re)
143 |
144 | if cat != "background":
145 | mean.update_value(Metric.IoU, iou)
146 | mean.update_value(Metric.Precision, pr)
147 | mean.update_value(Metric.Recall, re)
148 |
149 | count = int(np.sum(M[i, :]))
150 | print_row(row_fmt, (cat, count, iou, pr, re))
151 |
152 |
153 | print(split)
154 | sys.stdout.write("mIoU: {:.2f}\t".format(mean.summarize(Metric.IoU)))
155 | sys.stdout.write(" Pr: {:.2f}\t".format(mean.summarize(Metric.Precision)))
156 | sys.stdout.write(" Re: {:.2f}\n".format(mean.summarize(Metric.Recall)))
157 |
158 | print(split)
159 | print("With background: ")
160 | sys.stdout.write("mIoU: {:.2f}\t".format(mean_bkg.summarize(Metric.IoU)))
161 | sys.stdout.write(" Pr: {:.2f}\t".format(mean_bkg.summarize(Metric.Precision)))
162 | sys.stdout.write(" Re: {:.2f}\n".format(mean_bkg.summarize(Metric.Recall)))
163 |
164 |
165 | def evaluate_one(conf_mat, mask_gt, mask):
166 |
167 | gt = mask_gt.reshape(-1)
168 | pred = mask.reshape(-1)
169 | conf_mat_one = np.zeros_like(conf_mat)
170 |
171 | assert(len(gt) == len(pred))
172 |
173 | for i in range(len(gt)):
174 | if gt[i] < conf_mat.shape[0]:
175 | conf_mat[gt[i], pred[i]] += 1.0
176 | conf_mat_one[gt[i], pred[i]] += 1.0
177 |
178 | return conf_mat_one
179 |
180 | def read_mask_file(filepath):
181 | return np.array(Image.open(filepath))
182 |
183 | def oracle_lower(mask, h, w, alpha):
184 |
185 | mask_dict = {}
186 | labels = np.unique(mask)
187 | new_mask = np.zeros_like(mask)
188 | H, W = mask.shape
189 |
190 | # skipping background
191 | for l in labels:
192 | if l in (0, 255):
193 | continue
194 |
195 |         mask_l = (mask == l).astype(float)  # np.float is deprecated in NumPy
196 |         mask_down = scipy.misc.imresize(mask_l, (h, w), interp='bilinear')  # requires SciPy < 1.3
197 |         mask_up = scipy.misc.imresize(mask_down, (H, W), interp='bilinear')
198 | new_mask[mask_up > alpha] = l
199 |
200 | return new_mask
201 |
202 | def get_image_name(name):
203 | base = os.path.basename(name)
204 | base = base.replace(".jpg", "")
205 | return base
206 |
207 | def evaluate_all(args):
208 |
209 | with_oracle = False
210 | if len(args.oracle_from) > 0:
211 | oh, ow = [int(x) for x in args.oracle_from.split("x")]
212 | with_oracle = (oh > 1 and ow > 1)
213 |
214 | if with_oracle:
215 | print(">>> Using oracle {}x{}".format(oh, ow))
216 |
217 | # initialising the confusion matrix
218 | conf_mat = np.zeros((21, 21))
219 | class_stats = {}
220 | for class_idx in range(21):
221 | class_stats[class_idx] = []
222 |
223 | # count of the images
224 | num_im = 0
225 |
226 | # opening the filelist
227 | with open(args.filelist) as fd:
228 |
229 | for line in tqdm(fd.readlines()):
230 |
231 | files = [x.strip('/ \n') for x in line.split(' ')]
232 |
233 | if len(files) < 2:
234 | print("No path to GT mask found in line\n")
235 | print("\t{}".format(line))
236 | continue
237 |
238 | filepath_gt = os.path.join(args.data, files[1])
239 | if not os.path.isfile(filepath_gt):
240 | print("File not found (GT): {}".format(filepath_gt))
241 | continue
242 |
243 | mask_gt = read_mask_file(filepath_gt)
244 |
245 | if with_oracle:
246 | mask = oracle_lower(mask_gt, oh, ow, alpha=0.5)
247 | else:
248 | basename = os.path.basename(files[1])
249 | filepath = os.path.join(args.masks, basename)
250 | if not os.path.isfile(filepath):
251 | print("File not found: {}".format(filepath))
252 | continue
253 |
254 | mask = read_mask_file(filepath)
255 |
256 | if mask.shape != mask_gt.shape:
257 | print("Mask shape mismatch in {}: ".format(basename), \
258 | mask.shape, " vs ", mask_gt.shape)
259 | continue
260 |
261 | conf_mat_one = evaluate_one(conf_mat, mask_gt, mask)
262 |
263 | image_name = get_image_name(files[0])
264 | image_labels = np.unique(mask_gt)
265 | summarise_one(class_stats, conf_mat_one, image_name, image_labels)
266 |
267 | num_im += 1
268 |
269 |
270 | print("# of images: {}".format(num_im))
271 | summarise_per_class(class_stats, args.log_scores)
272 |
273 | return conf_mat
274 |
275 | if __name__ == "__main__":
276 |
277 | args = parser.parse_args(sys.argv[1:])
278 | check_args(args)
279 | stats = evaluate_all(args)
280 | summarise_stats(stats)
281 |
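282 | # A sketch for the parallelisation TODO above (ours, untested): compute
283 | # per-image confusion matrices in worker processes and sum them, e.g.
284 | #
285 | #   from multiprocessing import Pool
286 | #
287 | #   def one_pair(pair):
288 | #       mask_gt, mask = pair
289 | #       return evaluate_one(np.zeros((21, 21)), mask_gt, mask)
290 | #
291 | #   with Pool() as pool:
292 | #       conf_mat = sum(pool.map(one_pair, pairs))  # pairs: list of (gt, pred) arrays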
--------------------------------------------------------------------------------
/figures/results.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/figures/results.gif
--------------------------------------------------------------------------------
/figures/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/figures/results.png
--------------------------------------------------------------------------------
/fonts/UbuntuMono-R.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/fonts/UbuntuMono-R.ttf
--------------------------------------------------------------------------------
/infer_val.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluating class activation maps from a given snapshot
3 | Supports multi-scale fusion of the masks
4 | Based on PSA
5 | """
6 |
7 | import os
8 | import sys
9 | import numpy as np
10 | import scipy
11 | import torch.multiprocessing as mp
12 | from tqdm import tqdm
13 | from PIL import Image
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torchvision
18 | import torchvision.transforms as tf
19 | from torch.utils.data import DataLoader
20 | from torch.backends import cudnn
21 | cudnn.enabled = True
22 | cudnn.benchmark = False
23 | cudnn.deterministic = True
24 |
25 | from opts import get_arguments
26 | from core.config import cfg, cfg_from_file, cfg_from_list
27 | from models import get_model
28 |
29 | from utils.checkpoints import Checkpoint
30 | from utils.timer import Timer
31 | from utils.dcrf import crf_inference
32 | from utils.inference_tools import get_inference_io
33 |
34 | def check_dir(base_path, name):
35 |
36 | # create the directory
37 | fullpath = os.path.join(base_path, name)
38 | if not os.path.exists(fullpath):
39 | os.makedirs(fullpath)
40 |
41 | return fullpath
42 |
43 | def HWC_to_CHW(img):
44 | return np.transpose(img, (2, 0, 1))
45 |
46 | if __name__ == '__main__':
47 |
48 |     # parsing the command-line arguments
49 | args = get_arguments(sys.argv[1:])
50 |
51 | # reading the config
52 | cfg_from_file(args.cfg_file)
53 | if args.set_cfgs is not None:
54 | cfg_from_list(args.set_cfgs)
55 |
56 | # initialising the dirs
57 | check_dir(args.mask_output_dir, "vis")
58 | check_dir(args.mask_output_dir, "crf")
59 |
60 | # Loading the model
61 | model = get_model(cfg.NET, num_classes=cfg.TEST.NUM_CLASSES)
62 | checkpoint = Checkpoint(args.snapshot_dir, max_n = 5)
63 | checkpoint.add_model('enc', model)
64 | checkpoint.load(args.resume)
65 |
66 | for p in model.parameters():
67 | p.requires_grad = False
68 |
69 | # setting the evaluation mode
70 | model.eval()
71 |
72 | assert hasattr(model, 'normalize')
73 | transform = tf.Compose([np.asarray, model.normalize])
74 |
75 | WriterClass, DatasetClass = get_inference_io(cfg.TEST.METHOD)
76 |
77 | dataset = DatasetClass(args.infer_list, cfg.TEST, transform=transform)
78 |
79 | dataloader = DataLoader(dataset, shuffle=False, num_workers=args.workers, \
80 | pin_memory=True, batch_size=cfg.TEST.BATCH_SIZE)
81 |
82 | model = nn.DataParallel(model).cuda()
83 |
84 | timer = Timer()
85 | N = len(dataloader)
86 |
87 | palette = dataset.get_palette()
88 | pool = mp.Pool(processes=args.workers)
89 | writer = WriterClass(cfg.TEST, palette, args.mask_output_dir)
90 |
91 | for iter, (img_name, img_orig, images_in, pads, labels, gt_mask) in enumerate(tqdm(dataloader)):
92 |
93 | # cutting the masks
94 | masks = []
95 |
96 | with torch.no_grad():
97 | cls_raw, masks_pred = model(images_in)
98 |
99 | if not cfg.TEST.USE_GT_LABELS:
100 | cls_sigmoid = torch.sigmoid(cls_raw)
101 | cls_sigmoid, _ = cls_sigmoid.max(0)
102 | #cls_sigmoid = cls_sigmoid.mean(0)
103 | # threshold class scores
104 | labels = (cls_sigmoid > cfg.TEST.FP_CUT_SCORE)
105 | else:
106 | labels = labels[0]
107 |
108 | # saving the raw npy
109 | image = dataset.denorm(img_orig[0]).numpy()
110 | masks_pred = masks_pred.cpu()
111 | labels = labels.type_as(masks_pred)
112 |
113 | #writer.save(img_name[0], image, masks_pred, pads, labels, gt_mask[0])
114 | pool.apply_async(writer.save, args=(img_name[0], image, masks_pred, pads, labels, gt_mask[0]))
115 |
116 | timer.update_progress(float(iter + 1) / N)
117 | if iter % 100 == 0:
118 | msg = "Finish time: {}".format(timer.str_est_finish())
119 | tqdm.write(msg)
120 | sys.stdout.flush()
121 |
122 | pool.close()
123 | pool.join()
124 |
--------------------------------------------------------------------------------
/launch/eval_seg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATASET=pascal_voc
4 | FILELIST=data/val_voc.txt # validation
5 |
6 | ## Your values here:
7 | #
8 | OUTPUT_DIR=
9 | EXP=
10 | RUN_ID=
11 | #
12 | ##
13 |
14 |
15 | LISTNAME=`basename $FILELIST .txt`
16 |
17 | # without CRF
18 | SAVE_DIR=$OUTPUT_DIR/$DATASET/$EXP/$RUN_ID/$LISTNAME
19 | nohup python eval_seg.py --data ./data --filelist $FILELIST --masks $SAVE_DIR > $SAVE_DIR.eval 2>&1 &
20 |
21 | # with CRF
22 | SAVE_DIR=$OUTPUT_DIR/$DATASET/$EXP/$RUN_ID/$LISTNAME/crf
23 | nohup python eval_seg.py --data ./data --filelist $FILELIST --masks $SAVE_DIR > $SAVE_DIR.eval 2>&1 &
24 |
25 | sleep 1
26 |
27 | echo "Log: ${SAVE_DIR}.eval"
28 | tail -f $SAVE_DIR.eval
29 |
--------------------------------------------------------------------------------
/launch/infer_val.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #
4 | # Set your arguments here
5 | #
6 | CONFIG=configs/voc_resnet38.yaml
7 | DATASET=pascal_voc
8 | FILELIST=data/val_voc.txt
9 |
10 | ## Your values here (see below how they're used)
11 | #
12 | OUTPUT_DIR=
13 | EXP=
14 | RUN_ID=
15 | SNAPSHOT=
16 | EXTRA_ARGS=
17 | SAVE_ID=
18 | #
19 | ##
20 |
21 | # limiting threads
22 | NUM_THREADS=6
23 |
24 | # NB: 'set VAR=value' is csh syntax; in bash the export below suffices
25 | export OMP_NUM_THREADS=$NUM_THREADS
26 |
27 | #
28 | # Code goes here
29 | #
30 | LISTNAME=`basename $FILELIST .txt`
31 | SAVE_DIR=$OUTPUT_DIR/$DATASET/$EXP/$SAVE_ID/$LISTNAME
32 | LOG_FILE=$OUTPUT_DIR/$DATASET/$EXP/$SAVE_ID/$LISTNAME.log
33 |
34 | CMD="python infer_val.py --dataset $DATASET \
35 | --cfg $CONFIG \
36 | --exp $EXP \
37 | --run $RUN_ID \
38 | --resume $SNAPSHOT \
39 | --infer-list $FILELIST \
40 | --workers $NUM_THREADS \
41 | --mask-output-dir $SAVE_DIR \
42 | $EXTRA_ARGS"
43 |
44 | if [ ! -d $SAVE_DIR ]; then
45 | echo "Creating directory: $SAVE_DIR"
46 | mkdir -p $SAVE_DIR
47 | else
48 | echo "Saving to: $SAVE_DIR"
49 | fi
50 |
51 | git rev-parse HEAD > ${SAVE_DIR}.head
52 | git diff > ${SAVE_DIR}.diff
53 | echo $CMD > ${SAVE_DIR}.cmd
54 |
55 | echo $CMD
56 | nohup $CMD > $LOG_FILE 2>&1 &
57 |
58 | sleep 1
59 | tail -f $LOG_FILE
60 |
--------------------------------------------------------------------------------
/launch/run_bsl_resnet101.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## Your values here:
4 | #
5 | DS=pascal_voc
6 | EXP=
7 | RUN_ID=
8 | #
9 | ##
10 |
11 | #
12 | # Script
13 | #
14 |
15 | LOG_DIR=logs/${DS}/${EXP}
16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet101.yaml --exp $EXP --run $RUN_ID --set NET.MODEL bsl TRAIN.NUM_EPOCHS 6"
17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log
18 |
19 | if [ ! -d "$LOG_DIR" ]; then
20 | echo "Creating directory $LOG_DIR"
21 | mkdir -p $LOG_DIR
22 | fi
23 |
24 | echo $CMD
25 | echo "LOG: $LOG_FILE"
26 |
27 | nohup $CMD > $LOG_FILE 2>&1 &
28 | sleep 1
29 | tail -f $LOG_FILE
30 |
--------------------------------------------------------------------------------
/launch/run_bsl_resnet38.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## Your values here:
4 | #
5 | DS=pascal_voc
6 | EXP=
7 | RUN_ID=
8 | #
9 | ##
10 |
11 | #
12 | # Script
13 | #
14 |
15 | LOG_DIR=logs/${DS}/${EXP}
16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet38.yaml --exp $EXP --run $RUN_ID --set NET.MODEL bsl TRAIN.NUM_EPOCHS 6"
17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log
18 |
19 | if [ ! -d "$LOG_DIR" ]; then
20 | echo "Creating directory $LOG_DIR"
21 | mkdir -p $LOG_DIR
22 | fi
23 |
24 | echo $CMD
25 | echo "LOG: $LOG_FILE"
26 |
27 | nohup $CMD > $LOG_FILE 2>&1 &
28 | sleep 1
29 | tail -f $LOG_FILE
30 |
--------------------------------------------------------------------------------
/launch/run_bsl_resnet50.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## Your values here:
4 | #
5 | DS=pascal_voc
6 | EXP=
7 | RUN_ID=
8 | #
9 | ##
10 |
11 | #
12 | # Script
13 | #
14 |
15 | LOG_DIR=logs/${DS}/${EXP}
16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet50.yaml --exp $EXP --run $RUN_ID --set NET.MODEL bsl TRAIN.NUM_EPOCHS 6"
17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log
18 |
19 | if [ ! -d "$LOG_DIR" ]; then
20 | echo "Creating directory $LOG_DIR"
21 | mkdir -p $LOG_DIR
22 | fi
23 |
24 | echo $CMD
25 | echo "LOG: $LOG_FILE"
26 |
27 | nohup $CMD > $LOG_FILE 2>&1 &
28 | sleep 1
29 | tail -f $LOG_FILE
30 |
--------------------------------------------------------------------------------
/launch/run_bsl_vgg16.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## Your values here:
4 | #
5 | DS=pascal_voc
6 | EXP=
7 | RUN_ID=
8 | #
9 | ##
10 |
11 | #
12 | # Script
13 | #
14 |
15 | LOG_DIR=logs/${DS}/${EXP}
16 | CMD="python train.py --dataset $DS --cfg configs/voc_vgg16.yaml --exp $EXP --run $RUN_ID --set NET.MODEL bsl TRAIN.NUM_EPOCHS 6"
17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log
18 |
19 | if [ ! -d "$LOG_DIR" ]; then
20 | echo "Creating directory $LOG_DIR"
21 | mkdir -p $LOG_DIR
22 | fi
23 |
24 | echo $CMD
25 | echo "LOG: $LOG_FILE"
26 |
27 | nohup $CMD > $LOG_FILE 2>&1 &
28 | sleep 1
29 | tail -f $LOG_FILE
30 |
--------------------------------------------------------------------------------
/launch/run_voc_resnet101.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## Your values here:
4 | #
5 | DS=pascal_voc
6 | EXP=
7 | RUN_ID=
8 | #
9 | ##
10 |
11 | #
12 | # Script
13 | #
14 |
15 | LOG_DIR=logs/${DS}/${EXP}
16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet101.yaml --exp $EXP --run $RUN_ID"
17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log
18 |
19 | if [ ! -d "$LOG_DIR" ]; then
20 | echo "Creating directory $LOG_DIR"
21 | mkdir -p $LOG_DIR
22 | fi
23 |
24 | echo $CMD
25 | echo "LOG: $LOG_FILE"
26 |
27 | nohup $CMD > $LOG_FILE 2>&1 &
28 | sleep 1
29 | tail -f $LOG_FILE
30 |
--------------------------------------------------------------------------------
/launch/run_voc_resnet38.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## Your values here:
4 | #
5 | DS=pascal_voc
6 | EXP=
7 | RUN_ID=
8 | #
9 | ##
10 |
11 | #
12 | # Script
13 | #
14 |
15 | LOG_DIR=logs/${DS}/${EXP}
16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet38.yaml --exp $EXP --run $RUN_ID"
17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log
18 |
19 | if [ ! -d "$LOG_DIR" ]; then
20 | echo "Creating directory $LOG_DIR"
21 | mkdir -p $LOG_DIR
22 | fi
23 |
24 | echo $CMD
25 | echo "LOG: $LOG_FILE"
26 |
27 | nohup $CMD > $LOG_FILE 2>&1 &
28 | sleep 1
29 | tail -f $LOG_FILE
30 |
--------------------------------------------------------------------------------
/launch/run_voc_resnet50.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## Your values here:
4 | #
5 | DS=pascal_voc
6 | EXP=
7 | RUN_ID=
8 | #
9 | ##
10 |
11 | #
12 | # Script
13 | #
14 |
15 | LOG_DIR=logs/${DS}/${EXP}
16 | CMD="python train.py --dataset $DS --cfg configs/voc_resnet50.yaml --exp $EXP --run $RUN_ID"
17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log
18 |
19 | if [ ! -d "$LOG_DIR" ]; then
20 | echo "Creating directory $LOG_DIR"
21 | mkdir -p $LOG_DIR
22 | fi
23 |
24 | echo $CMD
25 | echo "LOG: $LOG_FILE"
26 |
27 | nohup $CMD > $LOG_FILE 2>&1 &
28 | sleep 1
29 | tail -f $LOG_FILE
30 |
--------------------------------------------------------------------------------
/launch/run_voc_vgg16.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## Your values here:
4 | #
5 | DS=pascal_voc
6 | EXP=
7 | RUN_ID=
8 | #
9 | ##
10 |
11 | #
12 | # Script
13 | #
14 |
15 | LOG_DIR=logs/${DS}/${EXP}
16 | CMD="python train.py --dataset $DS --cfg configs/voc_vgg16.yaml --exp $EXP --run $RUN_ID"
17 | LOG_FILE=$LOG_DIR/${RUN_ID}.log
18 |
19 | if [ ! -d "$LOG_DIR" ]; then
20 | echo "Creating directory $LOG_DIR"
21 | mkdir -p $LOG_DIR
22 | fi
23 |
24 | echo $CMD
25 | echo "LOG: $LOG_FILE"
26 |
27 | nohup $CMD > $LOG_FILE 2>&1 &
28 | sleep 1
29 | tail -f $LOG_FILE
30 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 |
6 | class MLHingeLoss(nn.Module):
7 |
8 | def forward(self, x, y, reduction='mean'):
9 | """
10 | y: labels have standard {0,1} form and will be converted to indices
11 | """
12 | b, c = x.size()
13 | idx = (torch.arange(c) + 1).type_as(x)
14 | y_idx, _ = (idx * y).sort(-1, descending=True)
15 | y_idx = (y_idx - 1).long()
16 |
17 | return F.multilabel_margin_loss(x, y_idx, reduction=reduction)
18 |
19 | def get_criterion(loss_name, **kwargs):
20 |
21 | losses = {
22 | "SoftMargin": nn.MultiLabelSoftMarginLoss,
23 | "Hinge": MLHingeLoss
24 | }
25 |
26 | return losses[loss_name](**kwargs)
27 |
28 |
29 | #
30 | # Mask self-supervision
31 | #
32 | def mask_loss_ce(mask, pseudo_gt, ignore_index=255):
33 | mask = F.interpolate(mask, size=pseudo_gt.size()[-2:], mode="bilinear", align_corners=True)
34 |
35 | # indices of the max classes
36 | mask_gt = torch.argmax(pseudo_gt, 1)
37 |
38 | # for each pixel there should be at least one 1
39 | # otherwise, ignore
40 | weight = pseudo_gt.sum(1).type_as(mask_gt)
41 | mask_gt += (1 - weight) * ignore_index
42 |
43 |     # cross-entropy loss
44 | loss = F.cross_entropy(mask, mask_gt, ignore_index=ignore_index)
45 | return loss.mean()
46 |
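47 | if __name__ == "__main__":
48 |     # Usage sketch (illustrative only): both criteria take raw logits and
49 |     # multi-hot {0,1} labels of shape [batch, num_classes].
50 |     logits = torch.randn(2, 20)
51 |     labels = torch.zeros(2, 20)
52 |     labels[0, 3] = labels[1, 7] = 1.
53 |     print(get_criterion("SoftMargin")(logits, labels))
54 |     print(get_criterion("Hinge")(logits, labels))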
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from .stage_net import network_factory
3 |
4 | def get_model(cfg, *args, **kwargs):
5 | net = partial(network_factory(cfg), config=cfg, pre_weights=cfg.PRE_WEIGHTS_PATH)
6 | return net(*args, **kwargs)
7 |
--------------------------------------------------------------------------------
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/models/backbones/__init__.py
--------------------------------------------------------------------------------
/models/backbones/base_net.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | class Normalize():
7 | def __init__(self, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)):
8 |
9 | self.mean = mean
10 | self.std = std
11 |
12 | def undo(self, imgarr):
13 | proc_img = imgarr.copy()
14 |
15 | proc_img[..., 0] = (self.std[0] * imgarr[..., 0] + self.mean[0]) * 255.
16 | proc_img[..., 1] = (self.std[1] * imgarr[..., 1] + self.mean[1]) * 255.
17 | proc_img[..., 2] = (self.std[2] * imgarr[..., 2] + self.mean[2]) * 255.
18 |
19 | return proc_img
20 |
21 | def __call__(self, img):
22 | imgarr = np.asarray(img)
23 | proc_img = np.empty_like(imgarr, np.float32)
24 |
25 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0]
26 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1]
27 | proc_img[..., 2] = (imgarr[..., 2] / 255. - self.mean[2]) / self.std[2]
28 |
29 | return proc_img
30 |
31 | class BaseNet(nn.Module):
32 |
33 | def __init__(self):
34 | super().__init__()
35 | self.normalize = Normalize()
36 | self.NormLayer = nn.BatchNorm2d
37 |
38 | self.not_training = [] # freezing parameters
39 | self.bn_frozen = [] # freezing running stats
40 | self.from_scratch_layers = [] # new layers -> higher LR
41 |
42 | def _init_weights(self, path_to_weights):
43 | print("Loading weights from: ", path_to_weights)
44 | weights_dict = torch.load(path_to_weights)
45 | self.load_state_dict(weights_dict, strict=False)
46 |
47 | def fan_out(self):
48 | raise NotImplementedError
49 |
50 | def fixed_layers(self):
51 | return self.not_training
52 |
53 | def _fix_running_stats(self, layer, fix_params=False):
54 |
55 | if isinstance(layer, self.NormLayer):
56 | self.bn_frozen.append(layer)
57 | if fix_params and not layer in self.not_training:
58 | self.not_training.append(layer)
59 | elif isinstance(layer, list):
60 | for m in layer:
61 | self._fix_running_stats(m, fix_params)
62 | else:
63 | for m in layer.children():
64 | self._fix_running_stats(m, fix_params)
65 |
66 | def _fix_params(self, layer):
67 |
68 | if isinstance(layer, nn.Conv2d) or \
69 | isinstance(layer, self.NormLayer) or \
70 | isinstance(layer, nn.Linear):
71 | self.not_training.append(layer)
72 | if isinstance(layer, self.NormLayer):
73 | self.bn_frozen.append(layer)
74 | elif isinstance(layer, list):
75 | for m in layer:
76 | self._fix_params(m)
77 | elif isinstance(layer, nn.Module):
78 | if hasattr(layer, "weight") or hasattr(layer, "bias"):
79 | print("Ignoring fixed weight/bias layer: ", layer)
80 |
81 | for m in layer.children():
82 | self._fix_params(m)
83 |
84 | def _freeze_bn(self, layer):
85 |
86 | if isinstance(layer, self.NormLayer):
87 | # freezing the layer
88 | layer.eval()
89 | elif isinstance(layer, nn.Module):
90 | for m in layer.children():
91 | self._freeze_bn(m)
92 |
93 | def train(self, mode=True):
94 |
95 | super().train(mode)
96 |
97 | for layer in self.not_training:
98 |
99 | if hasattr(layer, "weight") and not layer.weight is None:
100 | layer.weight.requires_grad = False
101 |
102 | if hasattr(layer, "bias") and not layer.bias is None:
103 | layer.bias.requires_grad = False
104 |
105 | elif isinstance(layer, torch.nn.Module):
106 | print("Unkown layer to fix: ", layer)
107 |
108 | for bn_layer in self.bn_frozen:
109 | self._freeze_bn(bn_layer)
110 |
111 | def _lr_mult(self):
112 |         return 1., 2., 10., 20.
113 |
114 | def parameter_groups(self, base_lr, wd):
115 |
116 | w_old, b_old, w_new, b_new = self._lr_mult()
117 |
118 | groups = ({"params": [], "weight_decay": wd, "lr": w_old*base_lr}, # weight learning
119 | {"params": [], "weight_decay": 0.0, "lr": b_old*base_lr}, # bias finetuning
120 | {"params": [], "weight_decay": wd, "lr": w_new*base_lr}, # weight finetuning
121 | {"params": [], "weight_decay": 0.0, "lr": b_new*base_lr}) # bias learning
122 |
123 | fixed_layers = self.fixed_layers()
124 |
125 | for m in self.modules():
126 |
127 | if m in fixed_layers:
128 | # skipping fixed layers
129 | continue
130 |
131 | if isinstance(m, nn.Conv2d) or \
132 | isinstance(m, nn.Linear) or \
133 | isinstance(m, self.NormLayer):
134 |
135 | if not m.weight is None:
136 | if m in self.from_scratch_layers:
137 | groups[2]["params"].append(m.weight)
138 | else:
139 | groups[0]["params"].append(m.weight)
140 |
141 | if not m.bias is None:
142 | if m in self.from_scratch_layers:
143 | groups[3]["params"].append(m.bias)
144 | else:
145 | groups[1]["params"].append(m.bias)
146 |
147 | elif hasattr(m, "weight"):
148 | print("! Skipping learnable: ", m)
149 |
150 | for i, g in enumerate(groups):
151 | print("Group {}: #{}, LR={:4.3e}".format(i, len(g["params"]), g["lr"]))
152 |
153 | return groups
154 |
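155 | if __name__ == "__main__":
156 |     # Usage sketch (ours, illustrative): feed the four groups into SGD;
157 |     # the base_lr/wd values below are assumptions, not the paper's settings.
158 |     import torch.optim as optim
159 |
160 |     class _Demo(BaseNet):
161 |         def __init__(self):
162 |             super().__init__()
163 |             self.conv = nn.Conv2d(3, 8, 3)
164 |             self.from_scratch_layers = [self.conv]  # gets the higher LR
165 |
166 |     opt = optim.SGD(_Demo().parameter_groups(base_lr=1e-3, wd=5e-4), lr=1e-3, momentum=0.9)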
--------------------------------------------------------------------------------
/models/backbones/resnet38d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 | import torch.nn.functional as F
6 |
7 | from models.backbones.base_net import BaseNet
8 |
9 | class ResBlock(nn.Module):
10 | def __init__(self, in_channels, mid_channels, out_channels, stride=1, first_dilation=None, dilation=1):
11 | super(ResBlock, self).__init__()
12 |
13 | self.same_shape = (in_channels == out_channels and stride == 1)
14 |
15 |         if first_dilation is None: first_dilation = dilation
16 |
17 | self.bn_branch2a = nn.BatchNorm2d(in_channels)
18 |
19 | self.conv_branch2a = nn.Conv2d(in_channels, mid_channels, 3, stride,
20 | padding=first_dilation, dilation=first_dilation, bias=False)
21 |
22 | self.bn_branch2b1 = nn.BatchNorm2d(mid_channels)
23 |
24 | self.conv_branch2b1 = nn.Conv2d(mid_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False)
25 |
26 | if not self.same_shape:
27 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
28 |
29 | def forward(self, x, get_x_bn_relu=False):
30 |
31 | branch2 = self.bn_branch2a(x)
32 | branch2 = F.relu(branch2)
33 |
34 | x_bn_relu = branch2
35 |
36 | if not self.same_shape:
37 | branch1 = self.conv_branch1(branch2)
38 | else:
39 | branch1 = x
40 |
41 | branch2 = self.conv_branch2a(branch2)
42 | branch2 = self.bn_branch2b1(branch2)
43 | branch2 = F.relu(branch2)
44 | branch2 = self.conv_branch2b1(branch2)
45 |
46 | x = branch1 + branch2
47 |
48 | if get_x_bn_relu:
49 | return x, x_bn_relu
50 |
51 | return x
52 |
53 | def __call__(self, x, get_x_bn_relu=False):
54 | return self.forward(x, get_x_bn_relu=get_x_bn_relu)
55 |
56 | class ResBlock_bot(nn.Module):
57 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, dropout=0.):
58 | super(ResBlock_bot, self).__init__()
59 |
60 | self.same_shape = (in_channels == out_channels and stride == 1)
61 |
62 | self.bn_branch2a = nn.BatchNorm2d(in_channels)
63 | self.conv_branch2a = nn.Conv2d(in_channels, out_channels//4, 1, stride, bias=False)
64 |
65 | self.bn_branch2b1 = nn.BatchNorm2d(out_channels//4)
66 | self.dropout_2b1 = torch.nn.Dropout2d(dropout)
67 | self.conv_branch2b1 = nn.Conv2d(out_channels//4, out_channels//2, 3, padding=dilation, dilation=dilation, bias=False)
68 |
69 | self.bn_branch2b2 = nn.BatchNorm2d(out_channels//2)
70 | self.dropout_2b2 = torch.nn.Dropout2d(dropout)
71 | self.conv_branch2b2 = nn.Conv2d(out_channels//2, out_channels, 1, bias=False)
72 |
73 | if not self.same_shape:
74 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
75 |
76 | def forward(self, x, get_x_bn_relu=False):
77 |
78 | branch2 = self.bn_branch2a(x)
79 | branch2 = F.relu(branch2)
80 | x_bn_relu = branch2
81 |
82 | branch1 = self.conv_branch1(branch2)
83 |
84 | branch2 = self.conv_branch2a(branch2)
85 |
86 | branch2 = self.bn_branch2b1(branch2)
87 | branch2 = F.relu(branch2)
88 | branch2 = self.dropout_2b1(branch2)
89 | branch2 = self.conv_branch2b1(branch2)
90 |
91 | branch2 = self.bn_branch2b2(branch2)
92 | branch2 = F.relu(branch2)
93 | branch2 = self.dropout_2b2(branch2)
94 | branch2 = self.conv_branch2b2(branch2)
95 |
96 | x = branch1 + branch2
97 |
98 | if get_x_bn_relu:
99 | return x, x_bn_relu
100 |
101 | return x
102 |
103 | def __call__(self, x, get_x_bn_relu=False):
104 | return self.forward(x, get_x_bn_relu=get_x_bn_relu)
105 |
106 | class ResNet38(BaseNet):
107 |
108 | def __init__(self):
109 | super(ResNet38, self).__init__()
110 |
111 | self.conv1a = nn.Conv2d(3, 64, 3, padding=1, bias=False)
112 |
113 | self.b2 = ResBlock(64, 128, 128, stride=2)
114 | self.b2_1 = ResBlock(128, 128, 128)
115 | self.b2_2 = ResBlock(128, 128, 128)
116 |
117 | self.b3 = ResBlock(128, 256, 256, stride=2)
118 | self.b3_1 = ResBlock(256, 256, 256)
119 | self.b3_2 = ResBlock(256, 256, 256)
120 |
121 | self.b4 = ResBlock(256, 512, 512, stride=2)
122 | self.b4_1 = ResBlock(512, 512, 512)
123 | self.b4_2 = ResBlock(512, 512, 512)
124 | self.b4_3 = ResBlock(512, 512, 512)
125 | self.b4_4 = ResBlock(512, 512, 512)
126 | self.b4_5 = ResBlock(512, 512, 512)
127 |
128 | self.b5 = ResBlock(512, 512, 1024, stride=1, first_dilation=1, dilation=2)
129 | self.b5_1 = ResBlock(1024, 512, 1024, dilation=2)
130 | self.b5_2 = ResBlock(1024, 512, 1024, dilation=2)
131 |
132 | self.b6 = ResBlock_bot(1024, 2048, stride=1, dilation=4, dropout=0.3)
133 |
134 | self.b7 = ResBlock_bot(2048, 4096, dilation=4, dropout=0.5)
135 |
136 | self.bn7 = nn.BatchNorm2d(4096)
137 |
138 | # fixing the parameters
139 | self._fix_params([self.conv1a, self.b2, self.b2_1, self.b2_2])
140 |
141 | def fan_out(self):
142 | return 4096
143 |
144 | def forward(self, x):
145 | return self.forward_as_dict(x)['conv6']
146 |
147 | def forward_as_dict(self, x):
148 |
149 | x = self.conv1a(x)
150 |
151 | x = self.b2(x)
152 | x = self.b2_1(x)
153 | x = self.b2_2(x)
154 |
155 | x = self.b3(x)
156 | x = self.b3_1(x)
157 | x = self.b3_2(x)
158 | conv3 = x
159 |
160 | x = self.b4(x)
161 | x = self.b4_1(x)
162 | x = self.b4_2(x)
163 | x = self.b4_3(x)
164 | x = self.b4_4(x)
165 | x = self.b4_5(x)
166 |
167 | x, conv4 = self.b5(x, get_x_bn_relu=True)
168 | x = self.b5_1(x)
169 | x = self.b5_2(x)
170 |
171 | x, conv5 = self.b6(x, get_x_bn_relu=True)
172 |
173 | x = self.b7(x)
174 | conv6 = F.relu(self.bn7(x))
175 |
176 | return dict({'conv3': conv3, 'conv6': conv6})
177 |
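178 | if __name__ == "__main__":
179 |     # Shape sketch (illustrative): the dilated blocks keep an output
180 |     # stride of 8, so 3x64x64 -> conv6: 4096x8x8, conv3: 256x16x16.
181 |     net = ResNet38().eval()
182 |     with torch.no_grad():
183 |         out = net.forward_as_dict(torch.randn(1, 3, 64, 64))
184 |     print(out['conv3'].shape, out['conv6'].shape)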
--------------------------------------------------------------------------------
/models/backbones/resnets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.backbones.base_net import BaseNet
6 |
7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8 | """3x3 convolution with padding"""
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=dilation, groups=groups, bias=False, dilation=dilation)
11 |
12 |
13 | def conv1x1(in_planes, out_planes, stride=1):
14 | """1x1 convolution"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
16 |
17 | class Bottleneck(nn.Module):
18 | expansion = 4
19 | __constants__ = ['downsample']
20 |
21 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
22 | base_width=64, dilation=1, norm_layer=None):
23 | super(Bottleneck, self).__init__()
24 | if norm_layer is None:
25 | norm_layer = nn.BatchNorm2d
26 | width = int(planes * (base_width / 64.)) * groups
27 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
28 | self.conv1 = conv1x1(inplanes, width)
29 | self.bn1 = norm_layer(width)
30 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
31 | self.bn2 = norm_layer(width)
32 | self.conv3 = conv1x1(width, planes * self.expansion)
33 | self.bn3 = norm_layer(planes * self.expansion)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | identity = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv3(out)
50 | out = self.bn3(out)
51 |
52 | if self.downsample is not None:
53 | identity = self.downsample(x)
54 |
55 | out += identity
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 | class ResNet(BaseNet):
61 |
62 | def __init__(self, block, layers, zero_init_residual=False,
63 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
64 | norm_layer=None, deep_base=False):
65 | super(ResNet, self).__init__()
66 | if norm_layer is None:
67 | norm_layer = nn.BatchNorm2d
68 | self._norm_layer = norm_layer
69 |
70 | self.dilation = 1
71 | if replace_stride_with_dilation is None:
72 | # each element in the tuple indicates if we should replace
73 | # the 2x2 stride with a dilated convolution instead
74 | replace_stride_with_dilation = [False, False, False]
75 | if len(replace_stride_with_dilation) != 3:
76 | raise ValueError("replace_stride_with_dilation should be None "
77 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
78 | self.groups = groups
79 | self.base_width = width_per_group
80 | self.deep_base = deep_base
81 | if not self.deep_base: # see PSPNet implementation
82 | self.inplanes = 64
83 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
84 | bias=False)
85 | self.bn1 = norm_layer(self.inplanes)
86 | else:
87 | self.inplanes = 128
88 | self.conv1 = conv3x3(3, 64, stride=2)
89 | self.bn1 = norm_layer(64)
90 | self.conv2 = conv3x3(64, 64)
91 | self.bn2 = norm_layer(64)
92 | self.conv3 = conv3x3(64, 128)
93 | self.bn3 = norm_layer(128)
94 |
95 | self.relu = nn.ReLU(inplace=True)
96 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
97 | self.layer1 = self._make_layer(block, 64, layers[0])
98 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
99 | dilate=replace_stride_with_dilation[0])
100 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
101 | dilate=replace_stride_with_dilation[1])
102 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
103 | dilate=replace_stride_with_dilation[2])
104 |
105 |         # note: no global pooling or fully-connected layers
106 |
107 | for m in self.modules():
108 | if isinstance(m, nn.Conv2d):
109 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111 | nn.init.constant_(m.weight, 1)
112 | nn.init.constant_(m.bias, 0)
113 |
114 | # Zero-initialize the last BN in each residual branch,
115 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
116 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
117 | if zero_init_residual:
118 | for m in self.modules():
119 | if isinstance(m, Bottleneck):
120 | nn.init.constant_(m.bn3.weight, 0)
121 |                 # (no BasicBlock branch: BasicBlock is not defined in this
122 |                 #  file; only Bottleneck-based variants are built here)
123 |
124 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
125 | norm_layer = self._norm_layer
126 | downsample = None
127 | previous_dilation = self.dilation
128 | if dilate:
129 | self.dilation *= stride
130 | stride = 1
131 | if stride != 1 or self.inplanes != planes * block.expansion:
132 | downsample = nn.Sequential(
133 | conv1x1(self.inplanes, planes * block.expansion, stride),
134 | norm_layer(planes * block.expansion),
135 | )
136 |
137 | layers = []
138 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
139 | self.base_width, previous_dilation, norm_layer))
140 | self.inplanes = planes * block.expansion
141 | for _ in range(1, blocks):
142 | layers.append(block(self.inplanes, planes, groups=self.groups,
143 | base_width=self.base_width, dilation=self.dilation,
144 | norm_layer=norm_layer))
145 |
146 | return nn.Sequential(*layers)
147 |
148 | def _forward_impl(self, x):
149 | x = self.conv1(x)
150 | x = self.bn1(x)
151 | x = self.relu(x)
152 |
153 | if self.deep_base:
154 | x = self.relu(self.bn2(self.conv2(x)))
155 | x = self.relu(self.bn3(self.conv3(x)))
156 |
157 | x = self.maxpool(x)
158 |
159 | x = self.layer1(x)
160 | x = self.layer2(x)
161 | x = self.layer3(x)
162 | x = self.layer4(x)
163 |
164 | return x
165 |
166 | def forward_as_dict(self, x):
167 | # See note [TorchScript super()]
168 | x = self.conv1(x)
169 | x = self.bn1(x)
170 | x = self.relu(x)
171 |
172 | if self.deep_base:
173 | x = self.relu(self.bn2(self.conv2(x)))
174 | x = self.relu(self.bn3(self.conv3(x)))
175 |
176 | x = self.maxpool(x)
177 |
178 | x = self.layer1(x)
179 | conv3 = x
180 |
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 | x = self.layer4(x)
184 |
185 | return {"conv6": x, "conv3": conv3}
186 |
187 | def forward(self, x):
188 | return self.forward_as_dict(x)["conv6"]
189 |
190 | def _lr_mult(self):
191 | return 1., 1., 10., 10.
192 |
193 | class ResNet50(ResNet):
194 |
195 | def __init__(self):
196 | super(ResNet50, self).__init__(Bottleneck, [3, 4, 6, 3], \
197 | replace_stride_with_dilation=[False, False, False])
198 |
199 | # fixing the parameters
200 | self._fix_params([self.conv1, self.bn1])
201 |
202 | assert not self.deep_base
203 |
204 | def fan_out(self):
205 | return 2048
206 |
207 | class ResNet101(ResNet):
208 |
209 | def __init__(self):
210 | super(ResNet101, self).__init__(Bottleneck, [3, 4, 23, 3], \
211 | replace_stride_with_dilation=[False, False, False])
212 |
213 | # fixing the parameters
214 | self._fix_params([self.conv1, self.bn1])
215 |
216 | assert not self.deep_base
217 |
218 | def fan_out(self):
219 | return 2048
220 |
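221 | if __name__ == "__main__":
222 |     # Shape sketch (illustrative): since layer4 is built with stride 1,
223 |     # the output stride is 16, i.e. 3x64x64 -> conv6: 2048x4x4 and
224 |     # conv3 (after layer1): 256x16x16.
225 |     net = ResNet50().eval()
226 |     with torch.no_grad():
227 |         out = net.forward_as_dict(torch.randn(1, 3, 64, 64))
228 |     print(out["conv3"].shape, out["conv6"].shape)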
--------------------------------------------------------------------------------
/models/backbones/vgg16d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 |
7 | from models.backbones.base_net import BaseNet
8 |
9 | class VGG16(BaseNet):
10 | def __init__(self, fc6_dilation = 1):
11 | super(VGG16, self).__init__()
12 |
13 | self.conv1_1 = nn.Conv2d(3,64,3,padding = 1)
14 | self.conv1_2 = nn.Conv2d(64,64,3,padding = 1)
15 | self.pool1 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1)
16 | self.conv2_1 = nn.Conv2d(64,128,3,padding = 1)
17 | self.conv2_2 = nn.Conv2d(128,128,3,padding = 1)
18 | self.pool2 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1)
19 | self.conv3_1 = nn.Conv2d(128,256,3,padding = 1)
20 | self.conv3_2 = nn.Conv2d(256,256,3,padding = 1)
21 | self.conv3_3 = nn.Conv2d(256,256,3,padding = 1)
22 | self.pool3 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding=1)
23 | self.conv4_1 = nn.Conv2d(256,512,3,padding = 1)
24 | self.conv4_2 = nn.Conv2d(512,512,3,padding = 1)
25 | self.conv4_3 = nn.Conv2d(512,512,3,padding = 1)
26 | self.pool4 = nn.MaxPool2d(kernel_size = 3, stride = 1, padding=1)
27 | self.conv5_1 = nn.Conv2d(512,512,3,padding = 2, dilation = 2)
28 | self.conv5_2 = nn.Conv2d(512,512,3,padding = 2, dilation = 2)
29 | self.conv5_3 = nn.Conv2d(512,512,3,padding = 2, dilation = 2)
30 |
31 | self.fc6 = nn.Conv2d(512,1024, 3, padding = fc6_dilation, dilation = fc6_dilation)
32 |
33 | self.drop6 = nn.Dropout2d(p=0.5)
34 | self.fc7 = nn.Conv2d(1024, 1024, 1)
35 |
36 | # fixing the parameters
37 | self._fix_params([self.conv1_1, self.conv1_2])
38 |
39 | def fan_out(self):
40 | return 1024
41 |
42 | def forward(self, x):
43 | return self.forward_as_dict(x)['conv6']
44 |
45 | def forward_as_dict(self, x):
46 |
47 | x = F.relu(self.conv1_1(x), inplace=True)
48 | x = F.relu(self.conv1_2(x), inplace=True)
49 | x = self.pool1(x)
50 |
51 | x = F.relu(self.conv2_1(x), inplace=True)
52 | x = F.relu(self.conv2_2(x), inplace=True)
53 | x = self.pool2(x)
54 |
55 | x = F.relu(self.conv3_1(x), inplace=True)
56 | x = F.relu(self.conv3_2(x), inplace=True)
57 | x = F.relu(self.conv3_3(x), inplace=True)
58 | conv3 = x
59 |
60 | x = self.pool3(x)
61 |
62 | x = F.relu(self.conv4_1(x), inplace=True)
63 | x = F.relu(self.conv4_2(x), inplace=True)
64 | x = F.relu(self.conv4_3(x), inplace=True)
65 |
66 | x = self.pool4(x)
67 |
68 | x = F.relu(self.conv5_1(x), inplace=True)
69 | x = F.relu(self.conv5_2(x), inplace=True)
70 | x = F.relu(self.conv5_3(x), inplace=True)
71 |
72 | x = F.relu(self.fc6(x), inplace=True)
73 | x = self.drop6(x)
74 | x = F.relu(self.fc7(x), inplace=True)
75 |
76 | conv6 = x
77 |
78 | return dict({'conv3': conv3, 'conv6': conv6})
79 |
--------------------------------------------------------------------------------
/models/mods/__init__.py:
--------------------------------------------------------------------------------
1 | from .sg import StochasticGate
2 | from .pamr import PAMR
3 | from .aspp import ASPP
4 | from .gci import GCI
5 |
--------------------------------------------------------------------------------
/models/mods/aspp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class _ASPPModule(nn.Module):
7 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
8 | super(_ASPPModule, self).__init__()
9 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
10 | stride=1, padding=padding, dilation=dilation, bias=False)
11 | self.bn = BatchNorm(planes)
12 | self.relu = nn.ReLU()
13 |
14 | def forward(self, x):
15 | x = self.atrous_conv(x)
16 | x = self.bn(x)
17 |
18 | return self.relu(x)
19 |
20 | class ASPP(nn.Module):
21 | def __init__(self, inplanes, output_stride, BatchNorm):
22 | super(ASPP, self).__init__()
23 |
24 | if output_stride == 16:
25 | dilations = [1, 6, 12, 18]
26 | elif output_stride == 8:
27 | dilations = [1, 12, 24, 36]
28 | else:
29 | raise NotImplementedError
30 |
31 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
32 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
33 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
34 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
35 |
36 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
37 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
38 | BatchNorm(256),
39 | nn.ReLU())
40 |
41 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
42 | self.bn1 = BatchNorm(256)
43 | self.relu = nn.ReLU()
44 | self.dropout = nn.Dropout(0.5)
45 | self._init_weight()
46 |
47 | def forward(self, x):
48 | x1 = self.aspp1(x)
49 | x2 = self.aspp2(x)
50 | x3 = self.aspp3(x)
51 | x4 = self.aspp4(x)
52 | x5 = self.global_avg_pool(x)
53 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
54 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
55 |
56 | x = self.conv1(x)
57 | x = self.bn1(x)
58 | x = self.relu(x)
59 |
60 | return self.dropout(x)
61 |
62 | def _init_weight(self):
63 | for m in self.modules():
64 | if isinstance(m, nn.Conv2d):
65 | torch.nn.init.kaiming_normal_(m.weight)
66 | elif isinstance(m, nn.BatchNorm2d):
67 | if not m.weight is None:
68 | m.weight.data.fill_(1)
69 | else:
70 | print("ASPP has not weight: ", m)
71 |
72 | if not m.bias is None:
73 | m.bias.data.zero_()
74 | else:
75 | print("ASPP has not bias: ", m)
76 |
77 |
78 | def build_aspp(inplanes, output_stride, BatchNorm):
79 |     return ASPP(inplanes, output_stride, BatchNorm)
80 |
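81 | if __name__ == "__main__":
82 |     # Usage sketch (illustrative): with output_stride=8 the branches use
83 |     # dilations [1, 12, 24, 36]; five 256-channel maps are fused back to
84 |     # a single 256-channel map of the input's spatial size.
85 |     aspp = build_aspp(2048, output_stride=8, BatchNorm=nn.BatchNorm2d).eval()
86 |     with torch.no_grad():
87 |         print(aspp(torch.randn(1, 2048, 8, 8)).shape)  # -> [1, 256, 8, 8]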
--------------------------------------------------------------------------------
/models/mods/gci.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class GCI(nn.Module):
7 | """Global Cue Injection
8 |     Takes shallow features with a low receptive
9 |     field and augments them with global info via
10 | adaptive instance normalisation"""
11 |
12 | def __init__(self, NormLayer=nn.BatchNorm2d):
13 | super(GCI, self).__init__()
14 |
15 | self.NormLayer = NormLayer
16 | self.from_scratch_layers = []
17 |
18 | self._init_params()
19 |
20 | def _conv2d(self, *args, **kwargs):
21 | conv = nn.Conv2d(*args, **kwargs)
22 | self.from_scratch_layers.append(conv)
23 | torch.nn.init.kaiming_normal_(conv.weight)
24 | return conv
25 |
26 | def _bnorm(self, *args, **kwargs):
27 | bn = self.NormLayer(*args, **kwargs)
28 | #self.bn_learn.append(bn)
29 | self.from_scratch_layers.append(bn)
30 | if not bn.weight is None:
31 | bn.weight.data.fill_(1)
32 | bn.bias.data.zero_()
33 | return bn
34 |
35 | def _init_params(self):
36 |
37 | self.fc_deep = nn.Sequential(self._conv2d(256, 512, 1, bias=False), \
38 | self._bnorm(512), nn.ReLU())
39 |
40 | self.fc_skip = nn.Sequential(self._conv2d(256, 256, 1, bias=False), \
41 | self._bnorm(256, affine=False))
42 |
43 | self.fc_cls = nn.Sequential(self._conv2d(256, 256, 1, bias=False), \
44 | self._bnorm(256), nn.ReLU())
45 |
46 | def forward(self, x, y):
47 | """Forward pass
48 |
49 | Args:
50 |             x: shallow features
51 | y: deep features
52 | """
53 |
54 | # extract global attributes
55 | y = self.fc_deep(y)
56 | attrs, _ = y.view(y.size(0), y.size(1), -1).max(-1)
57 |
58 | # pre-process shallow features
59 | x = self.fc_skip(x)
60 | x = F.relu(self._adin_conv(x, attrs))
61 |
62 | return self.fc_cls(x)
63 |
64 | def _adin_conv(self, x, y):
65 |
66 | bs, num_c, _, _ = x.size()
67 | assert 2*num_c == y.size(1), "AdIN: dimension mismatch"
68 |
69 | y = y.view(bs, 2, num_c)
70 | gamma, beta = y[:, 0], y[:, 1]
71 |
72 | gamma = gamma.unsqueeze(-1).unsqueeze(-1)
73 | beta = beta.unsqueeze(-1).unsqueeze(-1)
74 |
75 | return x * (gamma + 1) + beta
76 |
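77 | if __name__ == "__main__":
78 |     # Shape sketch (illustrative): 256-channel shallow features are
79 |     # modulated by (gamma, beta) pooled from the deep input; fc_deep
80 |     # lifts it to 512 = 2 * 256 attributes for the AdIN step.
81 |     gci = GCI().eval()
82 |     with torch.no_grad():
83 |         out = gci(torch.randn(1, 256, 32, 32), torch.randn(1, 256, 8, 8))
84 |     print(out.shape)  # -> torch.Size([1, 256, 32, 32])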
--------------------------------------------------------------------------------
/models/mods/pamr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 | from functools import partial
6 |
7 | #
8 | # Helper modules
9 | #
10 | class LocalAffinity(nn.Module):
11 |
12 | def __init__(self, dilations=[1]):
13 | super(LocalAffinity, self).__init__()
14 | self.dilations = dilations
15 | weight = self._init_aff()
16 | self.register_buffer('kernel', weight)
17 |
18 | def _init_aff(self):
19 | # initialising the shift kernel
20 | weight = torch.zeros(8, 1, 3, 3)
21 |
22 | for i in range(weight.size(0)):
23 | weight[i, 0, 1, 1] = 1
24 |
25 | weight[0, 0, 0, 0] = -1
26 | weight[1, 0, 0, 1] = -1
27 | weight[2, 0, 0, 2] = -1
28 |
29 | weight[3, 0, 1, 0] = -1
30 | weight[4, 0, 1, 2] = -1
31 |
32 | weight[5, 0, 2, 0] = -1
33 | weight[6, 0, 2, 1] = -1
34 | weight[7, 0, 2, 2] = -1
35 |
36 | self.weight_check = weight.clone()
37 |
38 | return weight
39 |
40 | def forward(self, x):
41 |
42 | self.weight_check = self.weight_check.type_as(x)
43 | assert torch.all(self.weight_check.eq(self.kernel))
44 |
45 | B,K,H,W = x.size()
46 | x = x.view(B*K,1,H,W)
47 |
48 | x_affs = []
49 | for d in self.dilations:
50 | x_pad = F.pad(x, [d]*4, mode='replicate')
51 | x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
52 | x_affs.append(x_aff)
53 |
54 | x_aff = torch.cat(x_affs, 1)
55 | return x_aff.view(B,K,-1,H,W)
56 |
57 | class LocalAffinityCopy(LocalAffinity):
58 |
59 | def _init_aff(self):
60 | # initialising the shift kernel
61 | weight = torch.zeros(8, 1, 3, 3)
62 |
63 | weight[0, 0, 0, 0] = 1
64 | weight[1, 0, 0, 1] = 1
65 | weight[2, 0, 0, 2] = 1
66 |
67 | weight[3, 0, 1, 0] = 1
68 | weight[4, 0, 1, 2] = 1
69 |
70 | weight[5, 0, 2, 0] = 1
71 | weight[6, 0, 2, 1] = 1
72 | weight[7, 0, 2, 2] = 1
73 |
74 | self.weight_check = weight.clone()
75 | return weight
76 |
77 | class LocalStDev(LocalAffinity):
78 |
79 | def _init_aff(self):
80 | weight = torch.zeros(9, 1, 3, 3)
81 | weight.zero_()
82 |
83 | weight[0, 0, 0, 0] = 1
84 | weight[1, 0, 0, 1] = 1
85 | weight[2, 0, 0, 2] = 1
86 |
87 | weight[3, 0, 1, 0] = 1
88 | weight[4, 0, 1, 1] = 1
89 | weight[5, 0, 1, 2] = 1
90 |
91 | weight[6, 0, 2, 0] = 1
92 | weight[7, 0, 2, 1] = 1
93 | weight[8, 0, 2, 2] = 1
94 |
95 | self.weight_check = weight.clone()
96 | return weight
97 |
98 | def forward(self, x):
99 | # returns (B,K,P,H,W), where P is the number
100 | # of locations
101 | x = super(LocalStDev, self).forward(x)
102 |
103 | return x.std(2, keepdim=True)
104 |
105 | class LocalAffinityAbs(LocalAffinity):
106 |
107 | def forward(self, x):
108 | x = super(LocalAffinityAbs, self).forward(x)
109 | return torch.abs(x)
110 |
111 | #
112 | # PAMR module
113 | #
114 | class PAMR(nn.Module):
115 |
116 | def __init__(self, num_iter=1, dilations=[1]):
117 | super(PAMR, self).__init__()
118 |
119 | self.num_iter = num_iter
120 | self.aff_x = LocalAffinityAbs(dilations)
121 | self.aff_m = LocalAffinityCopy(dilations)
122 | self.aff_std = LocalStDev(dilations)
123 |
124 | def forward(self, x, mask):
125 | mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)
126 |
127 | # x: [BxKxHxW]
128 | # mask: [BxCxHxW]
129 | B,K,H,W = x.size()
130 | _,C,_,_ = mask.size()
131 |
132 | x_std = self.aff_std(x)
133 |
134 | x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
135 | x = x.mean(1, keepdim=True)
136 | x = F.softmax(x, 2)
137 |
138 | for _ in range(self.num_iter):
139 | m = self.aff_m(mask) # [BxCxPxHxW]
140 | mask = (m * x).sum(2)
141 |
142 | # xvals: [BxCxHxW]
143 | return mask
144 |
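145 | if __name__ == "__main__":
146 |     # Usage sketch (illustrative): refine a coarse soft mask with local
147 |     # affinities computed from the (here random) image features.
148 |     pamr = PAMR(num_iter=3, dilations=[1, 2, 4])
149 |     img = torch.randn(1, 3, 32, 32)
150 |     mask = torch.softmax(torch.randn(1, 21, 32, 32), dim=1)
151 |     with torch.no_grad():
152 |         print(pamr(img, mask).shape)  # -> torch.Size([1, 21, 32, 32])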
--------------------------------------------------------------------------------
/models/mods/sg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class StochasticGate(nn.Module):
7 | """Stochastically merges features from two levels
8 | with varying size of the receptive field
9 | """
10 |
11 | def __init__(self):
12 | super(StochasticGate, self).__init__()
13 | self._mask_drop = None
14 |
15 | def forward(self, x1, x2, alpha_rate=0.3):
16 | """Stochastic Gate (SG)
17 |
18 | SG stochastically mixes deep and shallow features
19 | at training time and deterministically combines
20 | them at test time with a hyperparam. alpha
21 | """
22 |
23 | if self.training: # training time
24 | # dropout: selecting either x1 or x2
25 | if self._mask_drop is None:
26 | bs, c, h, w = x1.size()
27 | assert c == x2.size(1), "Number of features is different"
28 | self._mask_drop = torch.ones_like(x1)
29 |
30 | # a mask of {0,1}
31 | mask_drop = (1 - alpha_rate) * F.dropout(self._mask_drop, alpha_rate)
32 |
33 | # shift and scale deep features
34 | # at train time: E[x] = x1
35 | x1 = (x1 - alpha_rate * x2) / max(1e-8, 1 - alpha_rate)
36 |
37 | # combine the features
38 | x = mask_drop * x1 + (1 - mask_drop) * x2
39 | else:
40 | # inference time: deterministic
41 | x = (1 - alpha_rate) * x1 + alpha_rate * x2
42 |
43 | return x
44 |
45 |
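46 | if __name__ == "__main__":
47 |     # Sketch (illustrative) of the deterministic test-time path: with the
48 |     # default alpha_rate=0.3 the output is 0.7 * x1 + 0.3 * x2.
49 |     sg = StochasticGate().eval()
50 |     x1, x2 = torch.ones(1, 4, 2, 2), torch.zeros(1, 4, 2, 2)
51 |     print(sg(x1, x2).unique())  # -> tensor([0.7000])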
--------------------------------------------------------------------------------
/models/stage_net.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | # backbone nets
7 | from models.backbones.resnet38d import ResNet38
8 | from models.backbones.vgg16d import VGG16
9 | from models.backbones.resnets import ResNet101, ResNet50
10 |
11 | # modules
12 | from models.mods import ASPP
13 | from models.mods import PAMR
14 | from models.mods import StochasticGate
15 | from models.mods import GCI
16 |
17 | #
18 | # Helper classes
19 | #
20 | def rescale_as(x, y, mode="bilinear", align_corners=True):
21 | h, w = y.size()[2:]
22 | x = F.interpolate(x, size=[h, w], mode=mode, align_corners=align_corners)
23 | return x
24 |
25 | def focal_loss(x, p = 1, c = 0.1):
26 | return torch.pow(1 - x, p) * torch.log(c + x)
27 |
28 | def pseudo_gtmask(mask, cutoff_top=0.6, cutoff_low=0.2, eps=1e-8):
29 | """Convert continuous mask into binary mask"""
30 | bs,c,h,w = mask.size()
31 | mask = mask.view(bs,c,-1)
32 |
33 | # for each class extract the max confidence
34 | mask_max, _ = mask.max(-1, keepdim=True)
35 | mask_max[:, :1] *= 0.7
36 | mask_max[:, 1:] *= cutoff_top
37 | #mask_max *= cutoff_top
38 |
39 | # if the top score is too low, ignore it
40 | lowest = torch.Tensor([cutoff_low]).type_as(mask_max)
41 | mask_max = mask_max.max(lowest)
42 |
43 | pseudo_gt = (mask > mask_max).type_as(mask)
44 |
45 | # remove ambiguous pixels
46 | ambiguous = (pseudo_gt.sum(1, keepdim=True) > 1).type_as(mask)
47 | pseudo_gt = (1 - ambiguous) * pseudo_gt
48 |
49 | return pseudo_gt.view(bs,c,h,w)
50 |
51 | def balanced_mask_loss_ce(mask, pseudo_gt, gt_labels, ignore_index=255):
52 | """Class-balanced CE loss
53 | - cancel loss if only one class in pseudo_gt
54 | - weight loss equally between classes
55 | """
56 |
57 | mask = F.interpolate(mask, size=pseudo_gt.size()[-2:], mode="bilinear", align_corners=True)
58 |
59 | # indices of the max classes
60 | mask_gt = torch.argmax(pseudo_gt, 1)
61 |
62 | # for each pixel there should be at least one 1
63 | # otherwise, ignore
64 | ignore_mask = pseudo_gt.sum(1) < 1.
65 | mask_gt[ignore_mask] = ignore_index
66 |
67 | # class weight balances the loss w.r.t. number of pixels
68 | # because we are equally interested in all classes
69 | bs,c,h,w = pseudo_gt.size()
70 | num_pixels_per_class = pseudo_gt.view(bs,c,-1).sum(-1)
71 | num_pixels_total = num_pixels_per_class.sum(-1, keepdim=True)
72 | class_weight = (num_pixels_total - num_pixels_per_class) / (1 + num_pixels_total)
73 | class_weight = (pseudo_gt * class_weight[:,:,None,None]).sum(1).view(bs, -1)
74 |
75 |     # cross-entropy loss
76 | loss = F.cross_entropy(mask, mask_gt, ignore_index=ignore_index, reduction="none")
77 | loss = loss.view(bs, -1)
78 |
79 | # we will have the loss only for batch indices
80 | # which have all classes in pseudo mask
81 | gt_num_labels = gt_labels.sum(-1).type_as(loss) + 1 # + BG
82 | ps_num_labels = (num_pixels_per_class > 0).type_as(loss).sum(-1)
83 | batch_weight = (gt_num_labels == ps_num_labels).type_as(loss)
84 |
85 | loss = batch_weight * (class_weight * loss).mean(-1)
86 | return loss
87 |
88 | class Flatten(nn.Module):
89 | def forward(self, input):
90 | return input.view(input.size(0), -1)
91 |
92 | #
93 | # Dynamic change of the base class
94 | #
95 | def network_factory(cfg):
96 |
97 | if cfg.BACKBONE == "resnet38":
98 | print("Backbone: ResNet38")
99 | backbone = ResNet38
100 | elif cfg.BACKBONE == "vgg16":
101 | print("Backbone: VGG16")
102 | backbone = VGG16
103 | elif cfg.BACKBONE == "resnet50":
104 | print("Backbone: ResNet50")
105 | backbone = ResNet50
106 | elif cfg.BACKBONE == "resnet101":
107 | print("Backbone: ResNet101")
108 | backbone = ResNet101
109 | else:
110 | raise NotImplementedError("No backbone found for '{}'".format(cfg.BACKBONE))
111 |
112 | #
113 | # Class definitions
114 | #
115 | class BaselineCAM(backbone):
116 |
117 | def __init__(self, config, pre_weights=None, num_classes=21, dropout=True):
118 | super().__init__()
119 |
120 | self.cfg = config
121 |
122 | self.fc8 = nn.Conv2d(self.fan_out(), num_classes - 1, 1, bias=False)
123 | nn.init.xavier_uniform_(self.fc8.weight)
124 |
125 | cls_modules = [nn.AdaptiveAvgPool2d((1, 1)), self.fc8, Flatten()]
126 | if dropout:
127 | cls_modules.insert(0, nn.Dropout2d(0.5))
128 |
129 | self.cls_branch = nn.Sequential(*cls_modules)
130 | self.mask_branch = nn.Sequential(self.fc8, nn.ReLU())
131 |
132 | self.from_scratch_layers = [self.fc8]
133 | self._init_weights(pre_weights)
134 | self._mask_logits = None
135 |
136 | self._fix_running_stats(self, fix_params=True) # freeze backbone BNs
137 |
138 | def forward_backbone(self, x):
139 | self._mask_logits = super().forward(x)
140 | return self._mask_logits
141 |
142 | def forward_cls(self, x):
143 | return self.cls_branch(x)
144 |
145 | def forward_mask(self, x, size):
146 | logits = self.fc8(x)
147 | masks = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
148 | masks = F.relu(masks)
149 |
150 |             # CAMs are unbounded,
151 |             # so let's normalise them first
152 |             # (see jiwoon-ahn/psa)
153 | b,c,h,w = masks.size()
154 | masks_ = masks.view(b,c,-1)
155 | z, _ = masks_.max(-1, keepdim=True)
156 | masks_ /= (1e-5 + z)
157 | masks = masks.view(b,c,h,w)
158 |
159 | bg = torch.ones_like(masks[:, :1])
160 | masks = torch.cat([self.cfg.BG_SCORE * bg, masks], 1)
161 |
162 |             # note that the masks contain the background as the first channel
163 | return logits, masks
164 |
165 | def forward(self, y, _, labels=None):
166 | test_mode = labels is None
167 |
168 | x = self.forward_backbone(y)
169 |
170 | cls = self.forward_cls(x)
171 | logits, masks = self.forward_mask(x, y.size()[-2:])
172 |
173 | if test_mode:
174 | return cls, masks
175 |
176 | # foreground stats
177 | b,c,h,w = masks.size()
178 | masks_ = masks.view(b,c,-1)
179 | masks_ = masks_[:, 1:]
180 | cls_fg = (masks_.mean(-1) * labels).sum(-1) / labels.sum(-1)
181 |
182 | # upscale the masks & clean
183 | masks = self._rescale_and_clean(masks, y, labels)
184 |
185 | return cls, cls_fg, {"cam": masks}, logits, None, None
186 |
187 | def _rescale_and_clean(self, masks, image, labels):
188 | masks = F.interpolate(masks, size=image.size()[-2:], mode='bilinear', align_corners=True)
189 | masks[:, 1:] *= labels[:, :, None, None].type_as(masks)
190 | return masks
191 |
192 | #
193 | # Softmax unit
194 | #
195 | class SoftMaxAE(backbone):
196 |
197 | def __init__(self, config, pre_weights=None, num_classes=21, dropout=True):
198 | super().__init__()
199 |
200 | self.cfg = config
201 | self.num_classes = num_classes
202 |
203 | self._init_weights(pre_weights) # initialise backbone weights
204 | self._fix_running_stats(self, fix_params=True) # freeze backbone BNs
205 |
206 | # Decoder
207 | self._init_aspp()
208 | self._init_decoder(num_classes)
209 |
210 | self._backbone = None
211 | self._mask_logits = None
212 |
213 | def _init_aspp(self):
214 | self.aspp = ASPP(self.fan_out(), 8, self.NormLayer)
215 |
216 | for m in self.aspp.modules():
217 | if isinstance(m, nn.Conv2d) or isinstance(m, self.NormLayer):
218 | self.from_scratch_layers.append(m)
219 |
220 | self._fix_running_stats(self.aspp) # freeze BN
221 |
222 | def _init_decoder(self, num_classes):
223 |
224 | self._aff = PAMR(self.cfg.PAMR_ITER, self.cfg.PAMR_KERNEL)
225 |
226 | def conv2d(*args, **kwargs):
227 | conv = nn.Conv2d(*args, **kwargs)
228 | self.from_scratch_layers.append(conv)
229 | torch.nn.init.kaiming_normal_(conv.weight)
230 | return conv
231 |
232 | def bnorm(*args, **kwargs):
233 | bn = self.NormLayer(*args, **kwargs)
234 | self.from_scratch_layers.append(bn)
235 |                 if bn.weight is not None:
236 | bn.weight.data.fill_(1)
237 | bn.bias.data.zero_()
238 | return bn
239 |
240 | # pre-processing for shallow features
241 | self.shallow_mask = GCI(self.NormLayer)
242 | self.from_scratch_layers += self.shallow_mask.from_scratch_layers
243 |
244 | # Stochastic Gate
245 | self.sg = StochasticGate()
246 | self.fc8_skip = nn.Sequential(conv2d(256, 48, 1, bias=False), bnorm(48), nn.ReLU())
247 | self.fc8_x = nn.Sequential(conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
248 | bnorm(256), nn.ReLU())
249 |
250 | # decoder
251 | self.last_conv = nn.Sequential(conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
252 | bnorm(256), nn.ReLU(),
253 | nn.Dropout(0.5),
254 | conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
255 | bnorm(256), nn.ReLU(),
256 | nn.Dropout(0.1),
257 | conv2d(256, num_classes - 1, kernel_size=1, stride=1))
258 |
259 | def run_pamr(self, im, mask):
260 | im = F.interpolate(im, mask.size()[-2:], mode="bilinear", align_corners=True)
261 | masks_dec = self._aff(im, mask)
262 | return masks_dec
263 |
264 | def forward_backbone(self, x):
265 | self._backbone = super().forward_as_dict(x)
266 | return self._backbone['conv6']
267 |
268 | def forward(self, y, y_raw=None, labels=None):
269 | test_mode = y_raw is None and labels is None
270 |
271 | # 1. backbone pass
272 | x = self.forward_backbone(y)
273 |
274 | # 2. ASPP modules
275 | x = self.aspp(x)
276 |
277 | #
278 | # 3. merging deep and shallow features
279 | #
280 |
281 |         # 3.1 skip connection: merge upsampled deep features with shallow conv3 features
282 | x2_x = self.fc8_skip(self._backbone['conv3'])
283 | x_up = rescale_as(x, x2_x)
284 | x = self.fc8_x(torch.cat([x_up, x2_x], 1))
285 |
286 | # 3.2 deep feature context for shallow features
287 | x2 = self.shallow_mask(self._backbone['conv3'], x)
288 |
289 | # 3.3 stochastically merging the masks
290 | x = self.sg(x, x2, alpha_rate=self.cfg.SG_PSI)
291 |
292 | # 4. final convs to get the masks
293 | x = self.last_conv(x)
294 |
295 | #
296 | # 5. Finalising the masks and scores
297 | #
298 |
299 | # constant BG scores
300 | bg = torch.ones_like(x[:, :1])
301 | x = torch.cat([bg, x], 1)
302 |
303 | bs, c, h, w = x.size()
304 |
305 | masks = F.softmax(x, dim=1)
306 |
307 | # reshaping
308 | features = x.view(bs, c, -1)
309 | masks_ = masks.view(bs, c, -1)
310 |
311 | # classification loss
312 | cls_1 = (features * masks_).sum(-1) / (1.0 + masks_.sum(-1))
313 |
314 | # focal penalty loss
315 | cls_2 = focal_loss(masks_.mean(-1), \
316 | p=self.cfg.FOCAL_P, \
317 | c=self.cfg.FOCAL_LAMBDA)
318 |
319 | # adding the losses together
320 | cls = cls_1[:, 1:] + cls_2[:, 1:]
321 |
322 | if test_mode:
323 |                 # in test mode, no mask
324 |                 # cleaning is performed
325 | return cls, rescale_as(masks, y)
326 |
327 | self._mask_logits = x
328 |
329 | # foreground stats
330 | masks_ = masks_[:, 1:]
331 | cls_fg = (masks_.mean(-1) * labels).sum(-1) / labels.sum(-1)
332 |
333 | # mask refinement with PAMR
334 | masks_dec = self.run_pamr(y_raw, masks.detach())
335 |
336 | # upscale the masks & clean
337 | masks = self._rescale_and_clean(masks, y, labels)
338 | masks_dec = self._rescale_and_clean(masks_dec, y, labels)
339 |
340 | # create pseudo GT
341 | pseudo_gt = pseudo_gtmask(masks_dec).detach()
342 | loss_mask = balanced_mask_loss_ce(self._mask_logits, pseudo_gt, labels)
343 |
344 | return cls, cls_fg, {"cam": masks, "dec": masks_dec}, self._mask_logits, pseudo_gt, loss_mask
345 |
346 | def _rescale_and_clean(self, masks, image, labels):
347 | """Rescale to fit the image size and remove any masks
348 | of labels that are not present"""
349 | masks = F.interpolate(masks, size=image.size()[-2:], mode='bilinear', align_corners=True)
350 | masks[:, 1:] *= labels[:, :, None, None].type_as(masks)
351 | return masks
352 |
353 |
354 | if cfg.MODEL == 'ae':
355 | print("Model: AE")
356 | return SoftMaxAE
357 | elif cfg.MODEL == 'bsl':
358 | print("Model: Baseline")
359 | return BaselineCAM
360 | else:
361 | raise NotImplementedError("Unknown model '{}'".format(cfg.MODEL))
362 |
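363 | # --- Usage sketch (illustrative; not part of the original file) ---
364 | # A minimal, self-contained check of the pseudo-ground-truth pipeline:
365 | # threshold a batch of soft masks into a hard pseudo mask, then score
366 | # random logits against it with the class-balanced CE loss. All shapes
367 | # and label counts below are assumptions made for the demo.
368 | if __name__ == "__main__":
369 |     torch.manual_seed(0)
370 |     bs, c, h, w = 2, 21, 8, 8
371 |     soft_masks = F.softmax(torch.randn(bs, c, h, w), dim=1)
372 |     pseudo_gt = pseudo_gtmask(soft_masks)      # (bs, c, h, w), binary
373 |     gt_labels = torch.zeros(bs, c - 1)
374 |     gt_labels[:, :3] = 1.                      # pretend 3 FG classes per image
375 |     logits = torch.randn(bs, c, h, w)
376 |     loss = balanced_mask_loss_ce(logits, pseudo_gt, gt_labels)
377 |     print("per-image mask loss:", loss)        # shape (bs,)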
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import torch
5 | import argparse
6 | from core.config import cfg
7 |
8 | def add_global_arguments(parser):
9 |
10 | parser.add_argument("--dataset", type=str,
11 | help="Determines dataloader to use (only Pascal VOC supported)")
12 | parser.add_argument("--exp", type=str, default="main",
13 | help="ID of the experiment (multiple runs)")
14 | parser.add_argument("--resume", type=str, default=None,
15 | help="Snapshot \"ID,iter\" to load")
16 | parser.add_argument("--run", type=str, help="ID of the run")
17 | parser.add_argument('--workers', type=int, default=8,
18 | metavar='N', help='dataloader threads')
19 | parser.add_argument("--snapshot-dir", type=str, default='./snapshots',
20 | help="Where to save snapshots of the model.")
21 | parser.add_argument("--logdir", type=str, default='./logs',
22 | help="Where to save log files of the model.")
23 |
24 | # used at inference only
25 | parser.add_argument("--infer-list", type=str, default='data/val_augvoc.txt',
26 | help="Path to a file list")
27 | parser.add_argument("--mask-output-dir", type=str, default='results/',
28 | help="Path where to save masks")
29 |
30 | #
31 | # Configuration
32 | #
33 | parser.add_argument(
34 | '--cfg', dest='cfg_file', required=True,
35 | help='Config file for training (and optionally testing)')
36 | parser.add_argument(
37 | '--set', dest='set_cfgs',
38 |         help='Set config keys. Key value sequence separated by whitespace, '
39 |              'e.g. [key] [value] [key] [value]',
40 | default=[], nargs='+')
41 |
42 | parser.add_argument("--random-seed", type=int, default=64, help="Random seed")
43 |
44 |
45 | def maybe_create_dir(path):
46 | if not os.path.exists(path):
47 | os.makedirs(path)
48 |
49 | def check_global_arguments(args):
50 |
51 | torch.set_num_threads(args.workers)
52 | if args.workers != torch.get_num_threads():
53 | print("Warning: # of threads is only ", torch.get_num_threads())
54 |
55 | setattr(args, "fixed_batch_path", os.path.join(args.logdir, args.dataset, args.exp, "fixed_batch.pt"))
56 | args.logdir = os.path.join(args.logdir, args.dataset, args.exp, args.run)
57 | maybe_create_dir(args.logdir)
58 | #print("Saving events in: {}".format(args.logdir))
59 |
60 | #
61 | # Model directories
62 | #
63 | args.snapshot_dir = os.path.join(args.snapshot_dir, args.dataset, args.exp, args.run)
64 | maybe_create_dir(args.snapshot_dir)
65 | #print("Saving snapshots in: {}".format(args.snapshot_dir))
66 |
67 | def get_arguments(args_in):
68 | """Parse all the arguments provided from the CLI.
69 |
70 | Returns:
71 |       The parsed arguments as an argparse.Namespace.
72 | """
73 | parser = argparse.ArgumentParser(description="Model Evaluation")
74 |
75 | add_global_arguments(parser)
76 | args = parser.parse_args(args_in)
77 | check_global_arguments(args)
78 |
79 | return args
80 |
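81 | # --- Usage sketch (illustrative; flag values are assumptions) ---
82 | # A typical invocation, mirroring the launch/*.sh scripts:
83 | #
84 | #   python train.py --dataset pascal_voc --cfg configs/voc_resnet38.yaml \
85 | #                   --exp main --run v01 --set TRAIN.BATCH_SIZE 16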
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | python=3.6.3
5 | numpy=1.15.4
6 | pillow=5.4.1
7 | pydensecrf=1.0rc3
8 | pytorch=1.0.1
9 | pyyaml=3.13
10 | torchvision=0.2.1
11 | tqdm=4.31.1
12 | tensorboardX=1.6
13 | packaging=19.0
14 | scikit-learn=0.20.2
15 | scikit-image=0.16.1
16 | six=1.12.0
17 | scipy=1.2.1
18 |
--------------------------------------------------------------------------------
/tools/convert_sbd.py:
--------------------------------------------------------------------------------
1 | """Convert .mat segmentation mask of SBD to .png
2 | See: https://github.com/visinf/1-stage-wseg/issues/9
3 | """
4 |
5 | import os
6 | import sys
7 | import glob
8 | import argparse
9 | from PIL import Image
10 | from scipy.io import loadmat
11 |
12 | # load tqdm optionally
13 | try:
14 | from tqdm import tqdm
15 | except ImportError:
16 | tqdm = lambda x: x
17 |
18 |
19 | def args():
20 | parser = argparse.ArgumentParser(description="Convert SBD .mat to .png")
21 | parser.add_argument("--inp", type=str, default='./dataset/cls/',
22 | help="Directory with .mat files")
23 | parser.add_argument("--out", type=str, default='./dataset/cls_png/',
24 | help="Directory where to save .png files")
25 | return parser.parse_args(sys.argv[1:])
26 |
27 |
28 | def convert(opts):
29 |
30 |     # searching for .mat files
31 | opts.inp = opts.inp + ("" if opts.inp[-1] == "/" else "/")
32 | filelist = glob.glob(opts.inp + "*.mat")
33 | print("Found {:d} files".format(len(filelist)))
34 |
35 | if len(filelist) == 0:
36 | return
37 |
38 | # check output directory
39 | if not os.path.isdir(opts.out):
40 | print("Creating {}".format(opts.out))
41 | os.makedirs(opts.out)
42 |
43 |
44 | for filepath in tqdm(filelist):
45 |
46 | x = loadmat(filepath)
47 | y = x['GTcls']['Segmentation'][0][0]
48 |
49 | # converting to PIL image
50 | png = Image.fromarray(y)
51 |
52 | name = os.path.basename(filepath).replace(".mat", ".png")
53 | png.save(os.path.join(opts.out, name))
54 |
55 |
56 | if __name__ == "__main__":
57 | opts = args()
58 | convert(opts)
59 |
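60 | # --- Usage sketch (illustrative) ---
61 | # Convert every .mat mask under --inp into an indexed .png under --out
62 | # (paths below are the script's defaults):
63 | #
64 | #   python tools/convert_sbd.py --inp ./dataset/cls/ --out ./dataset/cls_png/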
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import math
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 | from sklearn.metrics import average_precision_score
11 |
12 | from datasets import get_dataloader, get_num_classes, get_class_names
13 | from models import get_model
14 |
15 | from base_trainer import BaseTrainer
16 | from functools import partial
17 |
18 | from opts import get_arguments
19 | from core.config import cfg, cfg_from_file, cfg_from_list
20 | from datasets.utils import Colorize
21 | from losses import get_criterion, mask_loss_ce
22 |
23 | from utils.timer import Timer
24 | from utils.stat_manager import StatManager
25 | from utils.metrics import compute_jaccard
26 |
27 | # specific to pytorch-v1 cuda-9.0
28 | # see: https://github.com/pytorch/pytorch/issues/15054#issuecomment-450191923
29 | # and: https://github.com/pytorch/pytorch/issues/14456
30 | torch.backends.cudnn.benchmark = True
31 | #torch.backends.cudnn.deterministic = True
32 | DEBUG = False
33 |
34 | def rescale_as(x, y, mode="bilinear", align_corners=True):
35 | h, w = y.size()[2:]
36 | x = F.interpolate(x, size=[h, w], mode=mode, align_corners=align_corners)
37 | return x
38 |
39 | class DecTrainer(BaseTrainer):
40 |
41 | def __init__(self, args, **kwargs):
42 | super(DecTrainer, self).__init__(args, **kwargs)
43 |
44 | # dataloader
45 | self.trainloader = get_dataloader(args, cfg, 'train')
46 | self.trainloader_val = get_dataloader(args, cfg, 'train_voc')
47 | self.valloader = get_dataloader(args, cfg, 'val')
48 | self.denorm = self.trainloader.dataset.denorm
49 |
50 | self.nclass = get_num_classes(args)
51 | self.classNames = get_class_names(args)[:-1]
52 | assert self.nclass == len(self.classNames)
53 |
54 | self.classIndex = {}
55 | for i, cname in enumerate(self.classNames):
56 | self.classIndex[cname] = i
57 |
58 | # model
59 | self.enc = get_model(cfg.NET, num_classes=self.nclass)
60 | self.criterion_cls = get_criterion(cfg.NET.LOSS)
61 | print(self.enc)
62 |
63 | # optimizer using different LR
64 | enc_params = self.enc.parameter_groups(cfg.NET.LR, cfg.NET.WEIGHT_DECAY)
65 | self.optim_enc = self.get_optim(enc_params, cfg.NET)
66 |
67 | # checkpoint management
68 | self._define_checkpoint('enc', self.enc, self.optim_enc)
69 | self._load_checkpoint(args.resume)
70 |
71 | self.fixed_batch = None
72 | self.fixed_batch_path = args.fixed_batch_path
73 | if os.path.isfile(self.fixed_batch_path):
74 | print("Loading fixed batch from {}".format(self.fixed_batch_path))
75 | self.fixed_batch = torch.load(self.fixed_batch_path)
76 |
77 | # using cuda
78 | self.enc = nn.DataParallel(self.enc).cuda()
79 | self.criterion_cls = nn.DataParallel(self.criterion_cls).cuda()
80 |
81 | def step(self, epoch, image, gt_labels, train=False, visualise=False):
82 |
83 | PRETRAIN = epoch < (11 if DEBUG else cfg.TRAIN.PRETRAIN)
84 |
85 | # denorm image
86 | image_raw = self.denorm(image.clone())
87 |
88 | # classification
89 | cls_out, cls_fg, masks, mask_logits, pseudo_gt, loss_mask = self.enc(image, image_raw, gt_labels)
90 |
91 | # classification loss
92 | loss_cls = self.criterion_cls(cls_out, gt_labels).mean()
93 |
94 | # keep track of all losses for logging
95 | losses = {"loss_cls": loss_cls.item(),
96 | "loss_fg": cls_fg.mean().item()}
97 |
98 | loss = loss_cls.clone()
99 | if "dec" in masks:
100 | loss_mask = loss_mask.mean()
101 |
102 | if not PRETRAIN:
103 | loss += cfg.NET.MASK_LOSS_BCE * loss_mask
104 |
105 | assert not "pseudo" in masks
106 | masks["pseudo"] = pseudo_gt
107 | losses["loss_mask"] = loss_mask.item()
108 |
109 | losses["loss"] = loss.item()
110 |
111 | if train:
112 | self.optim_enc.zero_grad()
113 | loss.backward()
114 | self.optim_enc.step()
115 |
116 | for mask_key, mask_val in masks.items():
117 | masks[mask_key] = masks[mask_key].detach()
118 |
119 | mask_logits = mask_logits.detach()
120 |
121 | if visualise:
122 | self._visualise(epoch, image, masks, mask_logits, cls_out, gt_labels)
123 |
124 | # make sure to cut the return values from graph
125 | return losses, cls_out.detach(), masks, mask_logits
126 |
127 | def train_epoch(self, epoch):
128 | self.enc.train()
129 |
130 | stat = StatManager()
131 | stat.add_val("loss")
132 | stat.add_val("loss_cls")
133 | stat.add_val("loss_fg")
134 |         stat.add_val("loss_mask")
135 |
136 | # adding stats for classes
137 | timer = Timer("New Epoch: ")
138 | train_step = partial(self.step, train=True, visualise=False)
139 |
140 | for i, (image, gt_labels, _) in enumerate(self.trainloader):
141 |
142 | # masks
143 | losses, _, _, _ = train_step(epoch, image, gt_labels)
144 |
145 | if self.fixed_batch is None:
146 | self.fixed_batch = {}
147 | self.fixed_batch["image"] = image.clone()
148 | self.fixed_batch["labels"] = gt_labels.clone()
149 | torch.save(self.fixed_batch, self.fixed_batch_path)
150 |
151 | for loss_key, loss_val in losses.items():
152 | stat.update_stats(loss_key, loss_val)
153 |
154 | # intermediate logging
155 | if i % 10 == 0:
156 | msg = "Loss [{:04d}]: ".format(i)
157 | for loss_key, loss_val in losses.items():
158 | msg += "{}: {:.4f} | ".format(loss_key, loss_val)
159 |
160 | msg += " | Im/Sec: {:.1f}".format(i * cfg.TRAIN.BATCH_SIZE / timer.get_stage_elapsed())
161 | print(msg)
162 | sys.stdout.flush()
163 |
164 | del image, gt_labels
165 |
166 | if DEBUG and i > 100:
167 | break
168 |
169 | def publish_loss(stats, name, t, prefix='data/'):
170 | print("{}: {:4.3f}".format(name, stats.summarize_key(name)))
171 | #self.writer.add_scalar(prefix + name, stats.summarize_key(name), t)
172 |
173 | for stat_key in stat.vals.keys():
174 | publish_loss(stat, stat_key, epoch)
175 |
176 | # plotting learning rate
177 | for ii, l in enumerate(self.optim_enc.param_groups):
178 | print("Learning rate [{}]: {:4.3e}".format(ii, l['lr']))
179 | self.writer.add_scalar('lr/enc_group_%02d' % ii, l['lr'], epoch)
180 |
181 | #self.writer.add_scalar('lr/bg_baseline', self.enc.module.mean.item(), epoch)
182 |
183 | # visualising
184 | self.enc.eval()
185 | with torch.no_grad():
186 | self.step(epoch, self.fixed_batch["image"], \
187 | self.fixed_batch["labels"], \
188 | train=False, visualise=True)
189 |
190 | def _mask_rgb(self, masks, image_norm):
191 | # visualising masks
192 | masks_conf, masks_idx = torch.max(masks, 1)
193 |         masks_conf = masks_conf - F.relu(masks_conf - 1)  # clamp confidence to [0, 1]
194 |
195 | masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), masks_conf.cpu())
196 | return 0.3 * image_norm + 0.7 * masks_idx_rgb
197 |
198 | def _init_norm(self):
199 | self.trainloader.dataset.set_norm(self.enc.normalize)
200 | self.valloader.dataset.set_norm(self.enc.normalize)
201 | self.trainloader_val.dataset.set_norm(self.enc.normalize)
202 |
203 | def _apply_cmap(self, mask_idx, mask_conf):
204 | palette = self.trainloader.dataset.get_palette()
205 |
206 | masks = []
207 | col = Colorize()
208 | mask_conf = mask_conf.float() / 255.0
209 | for mask, conf in zip(mask_idx.split(1), mask_conf.split(1)):
210 | m = col(mask).float()
211 | m = m * conf
212 | masks.append(m[None, ...])
213 |
214 | return torch.cat(masks, 0)
215 |
216 | def validation(self, epoch, writer, loader, checkpoint=False):
217 |
218 | stat = StatManager()
219 |
220 | # Fast test during the training
221 | def eval_batch(image, gt_labels):
222 |
223 | losses, cls, masks, mask_logits = \
224 | self.step(epoch, image, gt_labels, train=False, visualise=False)
225 |
226 | for loss_key, loss_val in losses.items():
227 | stat.update_stats(loss_key, loss_val)
228 |
229 | return cls.cpu(), masks, mask_logits.cpu()
230 |
231 | self.enc.eval()
232 |
233 | # class ground truth
234 | targets_all = []
235 |
236 | # class predictions
237 | preds_all = []
238 |
239 | def add_stats(means, stds, x):
240 | means.append(x.mean())
241 | stds.append(x.std())
242 |
243 | for n, (image, gt_labels, _) in enumerate(loader):
244 |
245 | with torch.no_grad():
246 | cls_raw, masks_all, mask_logits = eval_batch(image, gt_labels)
247 |
248 | cls_sigmoid = torch.sigmoid(cls_raw).numpy()
249 |
250 | preds_all.append(cls_sigmoid)
251 | targets_all.append(gt_labels.cpu().numpy())
252 |
253 | #
254 | # classification
255 | #
256 | targets_stacked = np.vstack(targets_all)
257 | preds_stacked = np.vstack(preds_all)
258 | aps = average_precision_score(targets_stacked, preds_stacked, average=None)
259 |
260 | # skip BG AP
261 | offset = self.nclass - aps.size
262 | assert offset == 1, 'Class number mismatch'
263 |
264 | classNames = self.classNames[offset:]
265 | for ni, className in enumerate(classNames):
266 | writer.add_scalar('%02d_%s/AP' % (ni + offset, className), aps[ni], epoch)
267 | print("AP_{}: {:4.3f}".format(className, aps[ni]))
268 |
269 | meanAP = np.mean(aps)
270 | writer.add_scalar('all_wo_BG/mAP', meanAP, epoch)
271 | print('mAP: {:4.3f}'.format(meanAP))
272 |
273 | # total classification loss
274 | for stat_key in stat.vals.keys():
275 | writer.add_scalar('all/{}'.format(stat_key), stat.summarize_key(stat_key), epoch)
276 |
277 | if checkpoint and epoch >= cfg.TRAIN.PRETRAIN:
278 |             # we use (1 - mean validation loss) as the proxy score
279 |             # for keeping the best checkpoint so far
280 | proxy_score = 1 - stat.summarize_key("loss")
281 | writer.add_scalar('all/checkpoint_score', proxy_score, epoch)
282 | self.checkpoint_best(proxy_score, epoch)
283 |
284 | def _visualise(self, epoch, image, masks, mask_logits, cls_out, gt_labels):
285 | image_norm = self.denorm(image.clone()).cpu()
286 | visual = [image_norm]
287 |
288 | if "cam" in masks:
289 | visual.append(self._mask_rgb(masks["cam"], image_norm))
290 |
291 | if "dec" in masks:
292 | visual.append(self._mask_rgb(masks["dec"], image_norm))
293 |
294 | if "pseudo" in masks:
295 | pseudo_gt_rgb = self._mask_rgb(masks["pseudo"], image_norm)
296 |
297 | # cancel ambiguous
298 | ambiguous = 1 - masks["pseudo"].sum(1, keepdim=True).cpu()
299 | pseudo_gt_rgb = ambiguous * image_norm + (1 - ambiguous) * pseudo_gt_rgb
300 | visual.append(pseudo_gt_rgb)
301 |
302 | # ready to assemble
303 | visual_logits = torch.cat(visual, -1)
304 | self._visualise_grid(visual_logits, gt_labels, epoch, scores=cls_out)
305 |
306 | if __name__ == "__main__":
307 | args = get_arguments(sys.argv[1:])
308 |
309 | # Reading the config
310 | cfg_from_file(args.cfg_file)
311 | if args.set_cfgs is not None:
312 | cfg_from_list(args.set_cfgs)
313 |
314 | print("Config: \n", cfg)
315 |
316 | trainer = DecTrainer(args)
317 | torch.manual_seed(0)
318 |
319 | timer = Timer()
320 | def time_call(func, msg, *args, **kwargs):
321 | timer.reset_stage()
322 | func(*args, **kwargs)
323 | print(msg + (" {:3.2}m".format(timer.get_stage_elapsed() / 60.)))
324 |
325 | for epoch in range(trainer.start_epoch, cfg.TRAIN.NUM_EPOCHS + 1):
326 | print("Epoch >>> ", epoch)
327 |
328 | log_int = 5 if DEBUG else 2
329 | if epoch % log_int == 0:
330 | with torch.no_grad():
331 | if not DEBUG:
332 | time_call(trainer.validation, "Validation / Train: ", epoch, trainer.writer, trainer.trainloader_val)
333 | time_call(trainer.validation, "Validation / Val: ", epoch, trainer.writer_val, trainer.valloader, checkpoint=True)
334 |
335 | time_call(trainer.train_epoch, "Train epoch: ", epoch)
336 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/visinf/1-stage-wseg/d905fca1134d5c33551422a76d82d8b0f00c48cc/utils/__init__.py
--------------------------------------------------------------------------------
/utils/checkpoints.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | class Checkpoint(object):
6 |
7 | def __init__(self, path, max_n=3):
8 | self.path = path
9 | self.max_n = max_n
10 | self.models = {}
11 | self.checkpoints = []
12 |
13 | def add_model(self, name, model, opt=None):
14 | assert not name in self.models, "Model {} already added".format(name)
15 |
16 | self.models[name] = {}
17 | self.models[name]['model'] = model
18 | self.models[name]['opt'] = opt
19 |
20 | def limit(self):
21 | return self.max_n
22 |
23 | def add_checkpoints(self, name=None):
24 | # searching for names
25 | fns = os.listdir(self.path)
26 | fns = filter(lambda x: x[-4:] == '.pth', fns)
27 |
28 | names = {}
29 | for fn in fns:
30 | sfx = fn.split("_")[-1].rstrip('.pth')
31 | path = self._get_full_path(fn)
32 | if not sfx in names:
33 | names[sfx] = os.path.getmtime(path)
34 | else:
35 | names[sfx] = max(names[sfx], os.path.getmtime(path))
36 |
37 | # assembling
38 | names_and_time = []
39 | for sfx, time in names.items():
40 | exists, paths = self.find(sfx)
41 | if exists:
42 | names_and_time.append((sfx, time))
43 |
44 | # if there are more checkpoints
45 | # than we can handle, remove the older ones
46 | # but do not remove them (for safety)
47 | if len(names_and_time) > self.max_n:
48 | names_and_time = sorted(names_and_time, \
49 | key=lambda x: x[1], \
50 | reverse=False)
51 | new_checkpoints = []
52 | for key in names_and_time[-self.max_n:]:
53 | new_checkpoints.append(key[0])
54 |
55 | self.checkpoints = new_checkpoints
56 |
57 | def __len__(self):
58 | return len(self.checkpoints)
59 |
60 | def _get_full_path(self, filename):
61 | return os.path.join(self.path, filename)
62 |
63 | def clean(self, n_remove):
64 |
65 | n_remove = min(n_remove, len(self.checkpoints))
66 |
67 | for i in range(n_remove):
68 | sfx = self.checkpoints[i]
69 |
70 | for name, data in self.models.items():
71 | for d in ('model', 'opt'):
72 | fn = self._filename(d, name, sfx)
73 | self._rm(fn)
74 |
75 | removed = self.checkpoints[:n_remove]
76 | self.checkpoints = self.checkpoints[n_remove:]
77 | return removed
78 |
79 | def _rm(self, fn):
80 | path = self._get_full_path(fn)
81 | if os.path.isfile(path):
82 | os.remove(path)
83 |
84 | def _filename(self, d, name, suffix):
85 | return "{}_{}_{}.pth".format(d, name, suffix)
86 |
87 | def load(self, suffix):
88 | if suffix is None:
89 | return False
90 |
91 | found, paths = self.find(suffix)
92 | if not found:
93 | return False
94 |
95 | # loading
96 | for name, data in self.models.items():
97 | for d in ('model', 'opt'):
98 | if data[d] is not None:
99 | data[d].load_state_dict(torch.load(paths[name][d]))
100 |
101 | return True
102 |
103 | def find(self, suffix, force=False):
104 | paths = {}
105 | found = True
106 | for name, data in self.models.items():
107 | paths[name] = {}
108 | for d in ('model', 'opt'):
109 | fn = self._filename(d, name, suffix)
110 | path = self._get_full_path(fn)
111 | paths[name][d] = path
112 | if not os.path.isfile(path):
113 | print("File not found: ", path)
114 | if d == 'model':
115 | found = False
116 |
117 | if found and not suffix in self.checkpoints:
118 | if len(self.checkpoints) < self.max_n or force:
119 | self.checkpoints.insert(0, suffix)
120 | if force:
121 | self.max_n = max(self.max_n, len(self.checkpoints))
122 |
123 | return found, paths
124 |
125 | def checkpoint(self, suffix):
126 | assert not '_' in suffix, "Underscores are not allowed"
127 |
128 | self.checkpoints.append(suffix)
129 |
130 | for name, data in self.models.items():
131 | for d in ('model', 'opt'):
132 | fn = self._filename(d, name, suffix)
133 | path = self._get_full_path(fn)
134 | if not os.path.isfile(path) and data[d] is not None:
135 | torch.save(data[d].state_dict(), path)
136 |
137 | # removing
138 | n_remove = max(0, len(self.checkpoints) - self.max_n)
139 | removed = self.clean(n_remove)
140 |
141 | return removed
142 |
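143 | # --- Usage sketch (illustrative; not part of the original file) ---
144 | # Register a model/optimizer pair, write a snapshot, then restore it.
145 | # The directory and suffix are assumptions made for the demo.
146 | if __name__ == "__main__":
147 |     import torch.nn as nn
148 |     ckpt = Checkpoint("./snapshots-demo", max_n=3)
149 |     os.makedirs(ckpt.path, exist_ok=True)
150 |     model = nn.Linear(4, 2)
151 |     opt = torch.optim.SGD(model.parameters(), lr=0.1)
152 |     ckpt.add_model("enc", model, opt)
153 |     ckpt.checkpoint("e001")    # writes model_enc_e001.pth / opt_enc_e001.pth
154 |     assert ckpt.load("e001")   # reloads both state dicts from disk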
--------------------------------------------------------------------------------
/utils/collections.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | """A simple attribute dictionary used for representing configuration options."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from __future__ import unicode_literals
22 |
23 |
24 | class AttrDict(dict):
25 |
26 | IMMUTABLE = '__immutable__'
27 |
28 | def __init__(self, *args, **kwargs):
29 | super(AttrDict, self).__init__(*args, **kwargs)
30 | self.__dict__[AttrDict.IMMUTABLE] = False
31 |
32 | def __getattr__(self, name):
33 | if name in self.__dict__:
34 | return self.__dict__[name]
35 | elif name in self:
36 | return self[name]
37 | else:
38 | raise AttributeError(name)
39 |
40 | def __setattr__(self, name, value):
41 | if not self.__dict__[AttrDict.IMMUTABLE]:
42 | if name in self.__dict__:
43 | self.__dict__[name] = value
44 | else:
45 | self[name] = value
46 | else:
47 | raise AttributeError(
48 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'.
49 | format(name, value)
50 | )
51 |
52 | def immutable(self, is_immutable):
53 | """Set immutability to is_immutable and recursively apply the setting
54 | to all nested AttrDicts.
55 | """
56 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable
57 | # Recursively set immutable state
58 | for v in self.__dict__.values():
59 | if isinstance(v, AttrDict):
60 | v.immutable(is_immutable)
61 | for v in self.values():
62 | if isinstance(v, AttrDict):
63 | v.immutable(is_immutable)
64 |
65 | def is_immutable(self):
66 | return self.__dict__[AttrDict.IMMUTABLE]
67 |
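68 | # --- Usage sketch (illustrative; not part of the original file) ---
69 | # Entries behave as attributes, and freezing guards nested configs:
70 | if __name__ == "__main__":
71 |     cfg_demo = AttrDict(NET=AttrDict(LR=0.01))
72 |     print(cfg_demo.NET.LR)        # 0.01 -- attribute access into nested dicts
73 |     cfg_demo.immutable(True)
74 |     try:
75 |         cfg_demo.NET.LR = 0.1     # raises once the config is frozen
76 |     except AttributeError as err:
77 |         print("blocked:", err)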
--------------------------------------------------------------------------------
/utils/dcrf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pydensecrf.densecrf as dcrf
3 | from pydensecrf.utils import unary_from_softmax
4 |
5 |
6 | def crf_inference(img, probs, t=10, scale_factor=1, labels=21):
7 |
8 | h, w = img.shape[:2]
9 | n_labels = labels
10 |
11 | d = dcrf.DenseCRF2D(w, h, n_labels)
12 |
13 | unary = unary_from_softmax(probs)
14 | unary = np.ascontiguousarray(unary)
15 |
16 | d.setUnaryEnergy(unary)
17 | d.addPairwiseGaussian(sxy=3/scale_factor, compat=3)
18 | d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10)
19 | Q = d.inference(t)
20 |
21 | return np.array(Q).reshape((n_labels, h, w))
22 |
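23 | # --- Usage sketch (illustrative; not part of the original file) ---
24 | # Refine a (num_labels, H, W) softmax map over an RGB uint8 image; the
25 | # random inputs below stand in for a real image and network output.
26 | if __name__ == "__main__":
27 |     h, w, n = 60, 80, 21
28 |     img = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
29 |     probs = np.random.rand(n, h, w).astype(np.float32)
30 |     probs /= probs.sum(0, keepdims=True)    # valid per-pixel distribution
31 |     refined = crf_inference(img, probs, t=10, labels=n)
32 |     print(refined.shape, refined.argmax(0).shape)   # (21, 60, 80) (60, 80)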
--------------------------------------------------------------------------------
/utils/inference_tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import scipy.misc
5 |
6 | import torch.nn.functional as F
7 |
8 | from PIL import Image
9 | from utils.dcrf import crf_inference
10 |
11 | from datasets.pascal_voc_ms import MultiscaleLoader, CropLoader
12 |
13 | class ResultWriter:
14 |
15 | def __init__(self, cfg, palette, out_path, verbose=True):
16 | self.cfg = cfg
17 | self.palette = palette
18 | self.root = out_path
19 | self.verbose = verbose
20 |
21 | def _mask_overlay(self, mask, image, alpha=0.3):
22 | """Creates an overlayed mask visualisation"""
23 | mask_rgb = self.__mask2rgb(mask)
24 | return alpha * image + (1 - alpha) * mask_rgb
25 |
26 | def __mask2rgb(self, mask):
27 | im = Image.fromarray(mask).convert("P")
28 | im.putpalette(self.palette)
29 | mask_rgb = np.array(im.convert("RGB"), dtype=np.float)
30 | return mask_rgb / 255.
31 |
32 |     def _merge_masks(self, masks, labels, pads, imsize_hw):
33 | """Combines masks at multiple scales
34 |
35 | Args:
36 | masks: list of masks obtained at different scales
37 | (already scaled to the original)
38 | Returns:
39 | pred: combined single mask
40 | pred_crf: refined mask with CRF
41 | """
42 | raise NotImplementedError
43 |
44 | def save(self, img_path, img_orig, all_masks, labels, pads, gt_mask):
45 |
46 |         img_name = os.path.splitext(os.path.basename(img_path))[0]
47 |
48 | # converting original image to [0, 255]
49 | img_orig255 = np.round(255. * img_orig).astype(np.uint8)
50 | img_orig255 = np.transpose(img_orig255, [1,2,0])
51 | img_orig255 = np.ascontiguousarray(img_orig255)
52 |
53 |         merged_mask = self._merge_masks(all_masks, labels, pads, img_orig255.shape[:2])
54 | pred = np.argmax(merged_mask, 0)
55 |
56 | # CRF
57 | pred_crf = crf_inference(img_orig255, merged_mask, t=10, scale_factor=1, labels=21)
58 | pred_crf = np.argmax(pred_crf, 0)
59 |
60 | filepath = os.path.join(self.root, img_name + '.png')
61 | scipy.misc.imsave(filepath, pred.astype(np.uint8))
62 |
63 | filepath = os.path.join(self.root, "crf", img_name + '.png')
64 | scipy.misc.imsave(filepath, pred_crf.astype(np.uint8))
65 |
66 | if self.verbose:
67 | mask_gt = gt_mask.numpy()
68 | masks_all = np.concatenate([pred, pred_crf, mask_gt], 1).astype(np.uint8)
69 | images = np.concatenate([img_orig]*3, 2)
70 | images = np.transpose(images, [1,2,0])
71 |
72 | overlay = self._mask_overlay(masks_all, images)
73 | filepath = os.path.join(self.root, "vis", img_name + '.png')
74 | overlay255 = np.round(overlay * 255.).astype(np.uint8)
75 | scipy.misc.imsave(filepath, overlay255)
76 |
77 | class MergeMultiScale(ResultWriter):
78 |
79 | def _cut(self, x_chw, pads):
80 | pad_h, pad_w, h, w = [int(p) for p in pads]
81 | return x_chw[:, pad_h:(pad_h + h), pad_w:(pad_w + w)]
82 |
83 | def _merge_masks(self, masks, labels, pads, imsize_hw):
84 |
85 | mask_list = []
86 | for i, mask in enumerate(masks.split(1, dim=0)):
87 |
88 | # removing the padding
89 | mask_cut = self._cut(mask[0], pads[i]).unsqueeze(0)
90 |
91 | # normalising the scale
92 | mask_cut = F.interpolate(mask_cut, imsize_hw, mode='bilinear', align_corners=False)[0]
93 |
94 | # flipping if necessary
95 | if self.cfg.FLIP and i % 2 == 1:
96 | mask_cut = torch.flip(mask_cut, (-1, ))
97 |
98 | # getting the max response
99 | mask_cut[1:, ::] *= labels[:, None, None]
100 | mask_list.append(mask_cut)
101 |
102 | mean_mask = sum(mask_list).numpy() / len(mask_list)
103 |
104 | # discounting BG
105 | #mean_mask[0, ::] *= 0.5
106 | mean_mask[0, ::] = np.power(mean_mask[0, ::], self.cfg.BG_POW)
107 |
108 | return mean_mask
109 |
110 | class MergeCrops(ResultWriter):
111 |
112 | def _cut(self, x_chw, pads):
113 | pad_h, pad_w, h, w = [int(p) for p in pads]
114 | return x_chw[:, pad_h:(pad_h + h), pad_w:(pad_w + w)]
115 |
116 | def _merge_masks(self, masks, labels, coords, imsize_hw):
117 | num_classes = masks.size(1)
118 |
119 | masks_sum = torch.zeros([num_classes, *imsize_hw]).type_as(masks)
120 | counts = torch.zeros(imsize_hw).type_as(masks)
121 |
122 | for ii, (mask, pads) in enumerate(zip(masks.split(1), coords.split(1))):
123 |
124 | mask = mask[0]
125 | s_h, e_h, s_w, e_w = pads[0][:4]
126 | pad_t, pad_l = pads[0][4:]
127 |
128 | if self.cfg.FLIP and ii % 2 == 0:
129 | mask = mask.flip(-1)
130 |
131 | # crop mask, if needed
132 | m_h = 0 if s_h > 0 else pad_t
133 | m_w = 0 if s_w > 0 else pad_l
134 |
135 | # due to padding
136 | # end point is shifted
137 | s_h = max(0, s_h - pad_t)
138 | s_w = max(0, s_w - pad_l)
139 | e_h = min(e_h - pad_t, imsize_hw[0])
140 | e_w = min(e_w - pad_l, imsize_hw[1])
141 |
142 | m_he = m_h + e_h - s_h
143 | m_we = m_w + e_w - s_w
144 |
145 | masks_sum[:, s_h:e_h, s_w:e_w] += mask[:, m_h:m_he, m_w:m_we]
146 | counts[s_h:e_h, s_w:e_w] += 1
147 |
148 | assert torch.all(counts > 0)
149 |
150 |         # removing false positives
151 | masks_sum[1:, ::] *= labels[:, None, None]
152 |
153 |         # averaging the overlapping crops
154 | return (masks_sum / counts).numpy()
155 |
156 | class PAMRWriter(ResultWriter):
157 |
158 | def save_batch(self, img_paths, imgs, all_masks, all_gt_masks):
159 |
160 | for b, img_path in enumerate(img_paths):
161 |
162 |             img_name = os.path.splitext(os.path.basename(img_path))[0]
163 | img_orig = imgs[b]
164 | gt_mask = all_gt_masks[b]
165 |
166 | # converting original image to [0, 255]
167 | img_orig255 = np.round(255. * img_orig).astype(np.uint8)
168 | img_orig255 = np.transpose(img_orig255, [1,2,0])
169 | img_orig255 = np.ascontiguousarray(img_orig255)
170 |
171 | mask_gt = torch.argmax(gt_mask, 0)
172 |
173 | # cancel ambiguous
174 | ambiguous = gt_mask.sum(0) == 0
175 | mask_gt[ambiguous] = 255
176 | mask_gt = mask_gt.numpy()
177 |
178 | # saving GT
179 | image_hwc = np.transpose(img_orig, [1,2,0])
180 | overlay_gt = self._mask_overlay(mask_gt.astype(np.uint8), image_hwc, alpha=0.5)
181 |
182 | filepath = os.path.join(self.root, img_name + '_gt.png')
183 | overlay255 = np.round(overlay_gt * 255.).astype(np.uint8)
184 | scipy.misc.imsave(filepath, overlay255)
185 |
186 | for it, mask_batch in enumerate(all_masks):
187 |
188 | mask = mask_batch[b]
189 | mask_idx = torch.argmax(mask, 0)
190 |
191 | # cancel ambiguous
192 | ambiguous = mask.sum(0) == 0
193 | mask_idx[ambiguous] = 255
194 |
195 | overlay = self._mask_overlay(mask_idx.numpy().astype(np.uint8), image_hwc, alpha=0.5)
196 |
197 | filepath = os.path.join(self.root, img_name + '_{:02d}.png'.format(it))
198 | overlay255 = np.round(overlay * 255.).astype(np.uint8)
199 | scipy.misc.imsave(filepath, overlay255)
200 |
201 |
202 | def get_inference_io(method_name):
203 |
204 | if method_name == "multiscale":
205 | return MergeMultiScale, MultiscaleLoader
206 | elif method_name == "multicrop":
207 | return MergeCrops, CropLoader
208 | else:
209 | raise NotImplementedError("Method {} is unknown".format(method_name))
210 |
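211 | # --- Usage sketch (illustrative; not part of the original file) ---
212 | # get_inference_io pairs a mask-merging writer with the matching loader;
213 | # inference code is expected to wire them up roughly like this (the cfg
214 | # fields and variables below are assumptions):
215 | #
216 | #   WriterClass, LoaderClass = get_inference_io("multiscale")
217 | #   writer = WriterClass(cfg.TEST, palette, args.mask_output_dir)
218 | #   ...
219 | #   writer.save(img_path, img_orig, all_masks, labels, pads, gt_mask)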
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | import threading
12 | import numpy as np
13 | import torch
14 |
15 | from sklearn.metrics import average_precision_score
16 |
17 | class Metric(object):
18 |
19 | # synonyms
20 | IoU = "IoU"
21 | MaskIoU = "IoU"
22 |
23 | Precision = "Precision"
24 | Recall = "Recall"
25 | ClassAP = "ClassAP"
26 |
27 | def __init__(self):
28 | self.data = {}
29 | self.count = {}
30 | self.fn = {}
31 |
32 |         # initialising the metric functions
33 | self.fn[Metric.MaskIoU] = self.mask_iou_
34 | self.fn[Metric.Precision] = self.precision_
35 | self.fn[Metric.Recall] = self.recall_
36 | self.fn[Metric.ClassAP] = self.class_ap_
37 |
38 | def add_metric(self, m):
39 | assert m in self.fn, "Unknown metric with key {}".format(m)
40 |
41 | self.data[m] = 0.
42 | self.count[m] = 0.
43 |
44 | def metrics(self):
45 | return self.data.keys()
46 |
47 | def print_summary(self):
48 |
49 | keys_sorted = sorted(self.data.keys())
50 |
51 | for m in keys_sorted:
52 | print("{}: {:5.4f}".format(m, self.summarize(m)))
53 |
54 | def reset_stat(self, m=None):
55 |
56 | if m is None:
57 | # resetting everything
58 | for m in self.data:
59 | self.data[m] = 0.
60 | self.count[m] = 0.
61 | else:
62 | assert m in self.fn, "Unknown metric with key {}".format(m)
63 |
64 | self.data[m] = 0.
65 | self.count[m] = 0.
66 |
67 | def update_value(self, m, value, count=1.):
68 |
69 | self.data[m] += value
70 | self.count[m] += count
71 |
72 | def update(self, gt, pred):
73 |
74 | for m in self.data:
75 | self.data[m] += self.fn[m](gt, pred)
76 | self.count[m] += 1.
77 |
78 | def merge(self, metric):
79 |
80 | for m in metric.data:
81 | if not m in self.data:
82 | self.reset_stat(m)
83 |
84 | self.update_value(m, metric.data[m], metric.count[m])
85 |
86 | def merge_summary(self, metric):
87 |
88 | for m in metric.data:
89 | if not m in self.data:
90 | self.reset_stat(m)
91 |
92 | mean_value = metric.summarize(m)
93 | self.update_value(m, mean_value, 1.)
94 |
95 | def summarize(self, m):
96 | if not m in self.count or self.count[m] == 0.:
97 | return 0.
98 |
99 | return self.data[m] / self.count[m]
100 |
101 | @staticmethod
102 | def mask_iou_(a, b):
103 | # computing the mask IoU
104 | isc = (a * b).sum()
105 | unn = (a + b).sum()
106 | z = unn - isc
107 |
108 | if z == 0.:
109 | return 0.
110 |
111 | return isc / z
112 |
113 | @staticmethod
114 | def precision_(gt, p):
115 |         # computing the mask precision
116 | acc = (gt * p).sum()
117 | sss = p.sum()
118 |
119 | if sss == 0.:
120 | return 0.
121 |
122 | return acc / sss
123 |
124 | @staticmethod
125 | def recall_(gt, p):
126 |         # computing the mask recall
127 | acc = (gt * p).sum()
128 | sss = gt.sum()
129 |
130 | if sss == 0.:
131 | return 0.
132 |
133 | return acc / sss
134 |
135 | @staticmethod
136 | def class_ap_(gt, p):
137 |
138 |         # AP for each class
139 | ap = average_precision_score(gt, p, average=None)
140 |
141 | # return the average
142 | return np.mean(ap[gt.sum(0) > 0])
143 |
144 |
145 | def compute_jaccard(preds_masks_all, targets_masks_all, num_classes=21):
146 |
147 | tps = np.zeros((num_classes, ))
148 | fps = np.zeros((num_classes, ))
149 | fns = np.zeros((num_classes, ))
150 | counts = np.zeros((num_classes, ))
151 |
152 | for mask_pred, mask_gt in zip(preds_masks_all, targets_masks_all):
153 |
154 | bs, h, w = mask_pred.size()
155 | assert bs == mask_gt.size(0), "Batch size mismatch"
156 |         assert h == mask_gt.size(1), "Height mismatch"
157 |         assert w == mask_gt.size(2), "Width mismatch"
158 |
159 | mask_pred = mask_pred.view(bs, 1, -1)
160 | mask_gt = mask_gt.view(bs, 1, -1)
161 |
162 | # ignore ambiguous
163 | mask_pred[mask_gt == 255] = 255
164 |
165 | for label in range(num_classes):
166 | mask_pred_ = (mask_pred == label).float()
167 | mask_gt_ = (mask_gt == label).float()
168 |
169 | tps[label] += (mask_pred_ * mask_gt_).sum().item()
170 |             diff = mask_pred_ - mask_gt_
171 |             fps[label] += diff.clamp(min=0.).sum().item()
172 |             fns[label] += (-diff).clamp(min=0.).sum().item()
173 |
174 | jaccards = [None]*num_classes
175 | precision = [None]*num_classes
176 | recall = [None]*num_classes
177 | for i in range(num_classes):
178 | tp = tps[i]
179 | fn = fns[i]
180 | fp = fps[i]
181 | jaccards[i] = tp / max(1e-3, fn + fp + tp)
182 | precision[i] = tp / max(1e-3, tp + fp)
183 | recall[i] = tp / max(1e-3, tp + fn)
184 |
185 | return jaccards, precision, recall
186 |
187 |
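188 | # --- Usage sketch (illustrative; not part of the original file) ---
189 | # compute_jaccard takes lists of (bs, h, w) label maps; ambiguous
190 | # ground-truth pixels are marked with 255 and ignored.
191 | if __name__ == "__main__":
192 |     pred = torch.randint(0, 21, (2, 8, 8)).float()
193 |     gt = pred.clone()
194 |     gt[0, 0, 0] = 255.                  # one ambiguous pixel
195 |     ious, prec, rec = compute_jaccard([pred], [gt])
196 |     print("mean IoU:", np.mean(ious))   # averaged over all 21 classes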
--------------------------------------------------------------------------------
/utils/pallete.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | from PIL import Image
12 |
13 | def get_mask_pallete(npimg, dataset='detail'):
14 | """Get image color pallete for visualizing masks"""
15 | # recovery boundary
16 | if dataset == 'pascal_voc':
17 | npimg[npimg==21] = 255
18 | # put colormap
19 | out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
20 | if dataset == 'ade20k':
21 | out_img.putpalette(adepallete)
22 | elif dataset == 'cityscapes':
23 | out_img.putpalette(citypallete)
24 | elif dataset in ('detail', 'pascal_voc', 'pascal_aug'):
25 | out_img.putpalette(vocpallete)
26 | return out_img
27 |
28 | def _get_voc_pallete(num_cls):
29 | n = num_cls
30 | pallete = [0]*(n*3)
31 | for j in range(0,n):
32 | lab = j
33 | pallete[j*3+0] = 0
34 | pallete[j*3+1] = 0
35 | pallete[j*3+2] = 0
36 | i = 0
37 | while (lab > 0):
38 | pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i))
39 | pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i))
40 | pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i))
41 | i = i + 1
42 | lab >>= 3
43 | return pallete
44 |
45 | vocpallete = _get_voc_pallete(256)
46 |
47 | adepallete = [0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200,3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224,5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143,255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255,6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9,92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41,10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8,0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0,163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224,0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200,200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255,163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0,255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245,255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255,255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0,122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163,255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184,255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163,0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255,0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255,20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0,255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255,255,214,0,25,194,194,102,255,0,92,0,255]
48 |
49 | citypallete = [
50 | 128,64,128,244,35,232,70,70,70,102,102,156,190,153,153,153,153,153,250,170,30,220,220,0,107,142,35,152,251,152,70,130,180,220,20,60,255,0,0,0,0,142,0,0,70,0,60,100,0,80,100,0,0,230,119,11,32,128,192,0,0,64,128,128,64,128,0,192,128,128,192,128,64,64,0,192,64,0,64,192,0,192,192,0,64,64,128,192,64,128,64,192,128,192,192,128,0,0,64,128,0,64,0,128,64,128,128,64,0,0,192,128,0,192,0,128,192,128,128,192,64,0,64,192,0,64,64,128,64,192,128,64,64,0,192,192,0,192,64,128,192,192,128,192,0,64,64,128,64,64,0,192,64,128,192,64,0,64,192,128,64,192,0,192,192,128,192,192,64,64,64,192,64,64,64,192,64,192,192,64,64,64,192,192,64,192,64,192,192,192,192,192,32,0,0,160,0,0,32,128,0,160,128,0,32,0,128,160,0,128,32,128,128,160,128,128,96,0,0,224,0,0,96,128,0,224,128,0,96,0,128,224,0,128,96,128,128,224,128,128,32,64,0,160,64,0,32,192,0,160,192,0,32,64,128,160,64,128,32,192,128,160,192,128,96,64,0,224,64,0,96,192,0,224,192,0,96,64,128,224,64,128,96,192,128,224,192,128,32,0,64,160,0,64,32,128,64,160,128,64,32,0,192,160,0,192,32,128,192,160,128,192,96,0,64,224,0,64,96,128,64,224,128,64,96,0,192,224,0,192,96,128,192,224,128,192,32,64,64,160,64,64,32,192,64,160,192,64,32,64,192,160,64,192,32,192,192,160,192,192,96,64,64,224,64,64,96,192,64,224,192,64,96,64,192,224,64,192,96,192,192,224,192,192,0,32,0,128,32,0,0,160,0,128,160,0,0,32,128,128,32,128,0,160,128,128,160,128,64,32,0,192,32,0,64,160,0,192,160,0,64,32,128,192,32,128,64,160,128,192,160,128,0,96,0,128,96,0,0,224,0,128,224,0,0,96,128,128,96,128,0,224,128,128,224,128,64,96,0,192,96,0,64,224,0,192,224,0,64,96,128,192,96,128,64,224,128,192,224,128,0,32,64,128,32,64,0,160,64,128,160,64,0,32,192,128,32,192,0,160,192,128,160,192,64,32,64,192,32,64,64,160,64,192,160,64,64,32,192,192,32,192,64,160,192,192,160,192,0,96,64,128,96,64,0,224,64,128,224,64,0,96,192,128,96,192,0,224,192,128,224,192,64,96,64,192,96,64,64,224,64,192,224,64,64,96,192,192,96,192,64,224,192,192,224,192,32,32,0,160,32,0,32,160,0,160,160,0,32,32,128,160,32,128,32,160,128,160,160,128,96,32,0,224,32,0,96,160,0,224,160,0,96,32,128,224,32,128,96,160,128,224,160,128,32,96,0,160,96,0,32,224,0,160,224,0,32,96,128,160,96,128,32,224,128,160,224,128,96,96,0,224,96,0,96,224,0,224,224,0,96,96,128,224,96,128,96,224,128,224,224,128,32,32,64,160,32,64,32,160,64,160,160,64,32,32,192,160,32,192,32,160,192,160,160,192,96,32,64,224,32,64,96,160,64,224,160,64,96,32,192,224,32,192,96,160,192,224,160,192,32,96,64,160,96,64,32,224,64,160,224,64,32,96,192,160,96,192,32,224,192,160,224,192,96,96,64,224,96,64,96,224,64,224,224,64,96,96,192,224,96,192,96,224,192,0,0,0]
51 |
52 |
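53 | # --- Usage sketch (illustrative; not part of the original file) ---
54 | # Colorise a (H, W) label map with the VOC palette and save it as an
55 | # indexed PNG (the file name is an assumption for the demo):
56 | if __name__ == "__main__":
57 |     import numpy as np
58 |     mask = np.random.randint(0, 21, (60, 80))
59 |     out = get_mask_pallete(mask, dataset='pascal_voc')
60 |     out.save("demo_mask.png")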
--------------------------------------------------------------------------------
/utils/stat_manager.py:
--------------------------------------------------------------------------------
1 |
2 | class StatManager(object):
3 |
4 | def __init__(self):
5 | self.func_keys = {}
6 | self.vals = {}
7 | self.vals_count = {}
8 | self.formats = {}
9 |
10 | def reset(self):
11 | for k in self.vals:
12 | self.vals[k] = 0.0
13 | self.vals_count[k] = 0.0
14 |
15 | def add_val(self, key, form="{:4.3f}"):
16 | self.vals[key] = 0.0
17 | self.vals_count[key] = 0.0
18 | self.formats[key] = form
19 |
20 | def get_val(self, key):
21 | return self.vals[key], self.vals_count[key]
22 |
23 | def add_compute(self, key, func, form="{:4.3f}"):
24 | self.func_keys[key] = func
25 | self.add_val(key)
26 | self.formats[key] = form
27 |
28 | def update_stats(self, key, val, count = 1):
29 | if not key in self.vals:
30 | self.add_val(key)
31 |
32 | self.vals[key] += val
33 | self.vals_count[key] += count
34 |
35 | def compute_stats(self, a, b, size = 1):
36 |
37 |         for k, func in self.func_keys.items():
38 | self.vals[k] += func(a, b)
39 | self.vals_count[k] += size
40 |
41 | def has_vals(self, k):
42 | if not k in self.vals_count:
43 | return False
44 | return self.vals_count[k] > 0
45 |
46 | def summarize_key(self, k):
47 | if self.has_vals(k):
48 | return self.vals[k] / self.vals_count[k]
49 | else:
50 | return 0
51 |
52 | def summarize(self, epoch = 0, verbose = True):
53 |
54 | if verbose:
55 | out = "\tEpoch[{:03d}]".format(epoch)
56 | for k in self.vals:
57 | if self.has_vals(k):
58 | out += (" / {} " + self.formats[k]).format(k, self.summarize_key(k))
59 | print(out)
60 |
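61 | # --- Usage sketch (illustrative; not part of the original file) ---
62 | # Accumulate running averages of named scalars, as train.py does per epoch:
63 | if __name__ == "__main__":
64 |     stat = StatManager()
65 |     stat.add_val("loss")
66 |     for value in [1.0, 0.5, 0.0]:
67 |         stat.update_stats("loss", value)
68 |     print(stat.summarize_key("loss"))   # 0.5 -- mean of the three updates
69 |     stat.summarize(epoch=1)             # prints "\tEpoch[001] / loss 0.500"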
--------------------------------------------------------------------------------
/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | class Timer:
4 | def __init__(self, starting_msg = None):
5 | self.start = time.time()
6 | self.stage_start = self.start
7 |
8 | if starting_msg is not None:
9 | print(starting_msg, time.ctime(time.time()))
10 |
11 |
12 | def update_progress(self, progress):
13 | self.elapsed = time.time() - self.start
14 | self.est_total = self.elapsed / progress
15 | self.est_remaining = self.est_total - self.elapsed
16 | self.est_finish = int(self.start + self.est_total)
17 |
18 |
19 | def str_est_finish(self):
20 | return str(time.ctime(self.est_finish))
21 |
22 | def get_stage_elapsed(self):
23 | return time.time() - self.stage_start
24 |
25 | def reset_stage(self):
26 | self.stage_start = time.time()
27 |
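28 | # --- Usage sketch (illustrative; not part of the original file) ---
29 | # Track overall progress and per-stage wall time:
30 | if __name__ == "__main__":
31 |     t = Timer("Starting:")
32 |     time.sleep(0.2)
33 |     t.update_progress(0.5)                  # report being halfway done
34 |     print("estimated finish:", t.str_est_finish())
35 |     t.reset_stage()
36 |     time.sleep(0.1)
37 |     print("stage took {:.2f}s".format(t.get_stage_elapsed()))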
--------------------------------------------------------------------------------