├── .gitignore ├── LICENSE ├── README.md ├── datasets ├── datasets.py ├── grouped_batch_sampler.py ├── mnist.py └── pascal.py ├── images └── mist_pipeline.png ├── json ├── mnist.json └── pascal.json ├── mist_test.py ├── mist_train.py ├── models ├── classifier.py ├── detector.py ├── mist.py ├── net.py └── spatialtransform.py ├── system └── conda_mist.yaml └── utils ├── config_utils.py ├── loss_functions.py ├── summary.py ├── torch_utils.py ├── utils.py └── viz_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | .vscode/ 127 | 128 | dataset/ 129 | logs/ 130 | pretrained_models/ 131 | jobs/ 132 | pretrained/ 133 | val_results/ 134 | test_results/ 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MIST: Multiple Instance Spatial Transformer Network 2 | ### Baptiste Angles, Yuhe Jin, Simon Kornblith, Andrea Tagliasacchi, Kwang Moo Yi 3 | This repository contains training and inference code for [MIST: Multiple Instance Spatial Transformer Network](https://arxiv.org/abs/1811.10725). 4 | 5 | ![alt text](https://github.com/ubc-vision/mist/blob/main/images/mist_pipeline.png) 6 | ## Installation 7 | This code is implemented based on PyTorch. A conda environment is provided with all the dependencies: 8 | ``` 9 | conda env create -f system/conda_mist.yaml 10 | ``` 11 | ## Pretrained models and datasets 12 | Two pretrained models are provided for MNIST dataset and trimmed Pascal+COCO dataset respectively. 13 | Models download path: 14 | ``` 15 | mkdir pretrained_models 16 | wget https://www.cs.ubc.ca/research/kmyi_data/files/2021/mist/mnist_best_models -P ./pretrained_models/ 17 | wget https://www.cs.ubc.ca/research/kmyi_data/files/2021/mist/pascal_coco_best_models -P ./pretrained_models/ 18 | ``` 19 | Dataset download path: 20 | ``` 21 | mkdir dataset 22 | wget https://www.cs.ubc.ca/research/kmyi_data/files/2021/mist/mnist_hard.zip -P ./dataset/ 23 | wget https://www.cs.ubc.ca/research/kmyi_data/files/2021/mist/VOC_pascal_coco_v2.zip -P ./dataset/ 24 | unzip ./dataset/mnist_hard.zip -d ./dataset/ 25 | unzip ./dataset/VOC_pascal_coco_v2.zip -d ./dataset/ 26 | ``` 27 | ## Inference 28 | Following commands will run pretrained model on test set. Visualization can be found in './test_results' 29 | ``` 30 | python mist_test.py --path_json='json/pascal.json' 31 | python mist_test.py --path_json='json/mnist.json' 32 | ``` 33 | ## Citation 34 | ``` 35 | @inproceedings{angles2021mist, 36 | title={MIST: Multiple Instance Spatial Transformer Networks}, 37 | author={Baptiste Angles*, Yuhe Jin*, Simon Kornblith, Andrea Tagliasacchi, Kwang Moo Yi}, 38 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 39 | year={2021} 40 | } 41 | ``` 42 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import copy 3 | import torch 4 | import torch.utils.data 5 | from datasets.grouped_batch_sampler import GroupedBatchSampler 6 | 7 | from datasets.mnist import MNIST, MNISTMetaData 8 | from datasets.pascal import PascalVOC, PascalVOCMetaData 9 | def compute_aspect_ratios(dataset): 10 | aspect_ratios = [] 11 | for i in range(len(dataset)): 12 | img_info = dataset.get_img_info(i) 13 | aspect_ratio = float(img_info["height"]) / float(img_info["width"]) 14 | aspect_ratios.append(aspect_ratio) 15 | return aspect_ratios 16 | 17 | def _quantize(x, bins): 18 | bins = copy.copy(bins) 19 | bins = sorted(bins) 20 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) 21 | return quantized 22 | 23 | def get_dataset(config, mode='train'): 24 | if config.dataset.startswith('mnist'): 25 | dataset = MNIST(config, mode) 26 | loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True) 27 | meta_data = MNISTMetaData() 28 | elif config.dataset.startswith('VOC'): 29 | if mode=='valid': 30 | mode='val' 31 | # init dataset 32 | dataset = PascalVOC(config, mode) 33 | #init data sampler 34 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 35 | # init batch sampler 36 | aspect_ratios = compute_aspect_ratios(dataset) 37 | group_ids = _quantize(aspect_ratios, [1]) 38 | batch_sampler = GroupedBatchSampler(sampler, group_ids, config.batch_size, drop_uneven=False) 39 | # init batch collector 40 | collector = BatchCollector() 41 | # init data loader 42 | loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collector, num_workers=0, pin_memory=True) 43 | meta_data = PascalVOCMetaData(config,mode) 44 | else: 45 | raise Exception('invalid dataset') 46 | 47 | return loader, meta_data 48 | 49 | # modified based on 'maskrcnn-benchmark' 50 | # git@github.com:facebookresearch/maskrcnn-benchmark.git 51 | class BatchCollector(object): 52 | """ 53 | From a list of samples from the dataset, 54 | returns the batched images and targets. 55 | This should be passed to the DataLoader 56 | """ 57 | 58 | def __init__(self): 59 | pass 60 | 61 | def __call__(self, batch): 62 | # transposed_batch = list(zip(*batch)) 63 | # images = to_image_list(transposed_batch[0], self.size_divisible) 64 | # targets = transposed_batch[1] 65 | # img_ids = transposed_batch[2] 66 | max_h = max([sample[0].shape[1] for sample in batch]) 67 | max_w = max([sample[0].shape[2] for sample in batch]) 68 | image_tensor = torch.zeros([len(batch),batch[0][0].shape[0],max_h,max_w], device=torch.device('cpu')) 69 | keypoints_tensor = torch.zeros([len(batch),batch[0][1].shape[0],4], device=torch.device('cpu')) 70 | labels_tensor = torch.zeros([len(batch),batch[0][2].shape[0]], dtype=torch.long, device=torch.device('cpu')) 71 | 72 | for idx, sample in enumerate(batch): 73 | image, keypoints, labels = sample 74 | image_tensor[idx,:,:image.shape[1],:image.shape[2]] = image 75 | keypoints_tensor[idx,:,:] = keypoints 76 | labels_tensor[idx,:] = labels 77 | 78 | return image_tensor, keypoints_tensor, labels_tensor 79 | -------------------------------------------------------------------------------- /datasets/grouped_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import itertools 3 | 4 | import torch 5 | from torch.utils.data.sampler import BatchSampler 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | class GroupedBatchSampler(BatchSampler): 10 | """ 11 | Wraps another sampler to yield a mini-batch of indices. 12 | It enforces that elements from the same group should appear in groups of batch_size. 13 | It also tries to provide mini-batches which follows an ordering which is 14 | as close as possible to the ordering from the original sampler. 15 | 16 | Arguments: 17 | sampler (Sampler): Base sampler. 18 | batch_size (int): Size of mini-batch. 19 | drop_uneven (bool): If ``True``, the sampler will drop the batches whose 20 | size is less than ``batch_size`` 21 | 22 | """ 23 | 24 | def __init__(self, sampler, group_ids, batch_size, drop_uneven=False): 25 | if not isinstance(sampler, Sampler): 26 | raise ValueError( 27 | "sampler should be an instance of " 28 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 29 | ) 30 | self.sampler = sampler 31 | self.group_ids = torch.as_tensor(group_ids) 32 | assert self.group_ids.dim() == 1 33 | self.batch_size = batch_size 34 | self.drop_uneven = drop_uneven 35 | 36 | self.groups = torch.unique(self.group_ids).sort(0)[0] 37 | 38 | self._can_reuse_batches = False 39 | 40 | def _prepare_batches(self): 41 | dataset_size = len(self.group_ids) 42 | # get the sampled indices from the sampler 43 | sampled_ids = torch.as_tensor(list(self.sampler)) 44 | # potentially not all elements of the dataset were sampled 45 | # by the sampler (e.g., DistributedSampler). 46 | # construct a tensor which contains -1 if the element was 47 | # not sampled, and a non-negative number indicating the 48 | # order where the element was sampled. 49 | # for example. if sampled_ids = [3, 1] and dataset_size = 5, 50 | # the order is [-1, 1, -1, 0, -1] 51 | order = torch.full((dataset_size,), -1, dtype=torch.int64) 52 | order[sampled_ids] = torch.arange(len(sampled_ids)) 53 | 54 | # get a mask with the elements that were sampled 55 | mask = order >= 0 56 | 57 | # find the elements that belong to each individual cluster 58 | clusters = [(self.group_ids == i) & mask for i in self.groups] 59 | # get relative order of the elements inside each cluster 60 | # that follows the order from the sampler 61 | relative_order = [order[cluster] for cluster in clusters] 62 | # with the relative order, find the absolute order in the 63 | # sampled space 64 | permutation_ids = [s[s.sort()[1]] for s in relative_order] 65 | # permute each cluster so that they follow the order from 66 | # the sampler 67 | permuted_clusters = [sampled_ids[idx] for idx in permutation_ids] 68 | 69 | # splits each cluster in batch_size, and merge as a list of tensors 70 | splits = [c.split(self.batch_size) for c in permuted_clusters] 71 | merged = tuple(itertools.chain.from_iterable(splits)) 72 | 73 | # now each batch internally has the right order, but 74 | # they are grouped by clusters. Find the permutation between 75 | # different batches that brings them as close as possible to 76 | # the order that we have in the sampler. For that, we will consider the 77 | # ordering as coming from the first element of each batch, and sort 78 | # correspondingly 79 | first_element_of_batch = [t[0].item() for t in merged] 80 | # get and inverse mapping from sampled indices and the position where 81 | # they occur (as returned by the sampler) 82 | inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())} 83 | # from the first element in each batch, get a relative ordering 84 | first_index_of_batch = torch.as_tensor( 85 | [inv_sampled_ids_map[s] for s in first_element_of_batch] 86 | ) 87 | 88 | # permute the batches so that they approximately follow the order 89 | # from the sampler 90 | permutation_order = first_index_of_batch.sort(0)[1].tolist() 91 | # finally, permute the batches 92 | batches = [merged[i].tolist() for i in permutation_order] 93 | 94 | if self.drop_uneven: 95 | kept = [] 96 | for batch in batches: 97 | if len(batch) == self.batch_size: 98 | kept.append(batch) 99 | batches = kept 100 | return batches 101 | 102 | def __iter__(self): 103 | if self._can_reuse_batches: 104 | batches = self._batches 105 | self._can_reuse_batches = False 106 | else: 107 | batches = self._prepare_batches() 108 | self._batches = batches 109 | return iter(batches) 110 | 111 | def __len__(self): 112 | if not hasattr(self, "_batches"): 113 | self._batches = self._prepare_batches() 114 | self._can_reuse_batches = True 115 | return len(self._batches) 116 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data 4 | import numpy as np 5 | import skimage.io 6 | import skimage.transform 7 | 8 | MNIST_CLASSES = ( 9 | "0", 10 | "1", 11 | "2", 12 | "3", 13 | "4", 14 | "5", 15 | "6", 16 | "7", 17 | "8", 18 | "9" 19 | ) 20 | 21 | class MNISTMetaData(): 22 | def __init__(self): 23 | self.cls = MNIST_CLASSES 24 | def get_num_class(self): 25 | return len(self.cls) 26 | def get_class_name(self, class_id): 27 | return self.cls[class_id] 28 | 29 | class MNIST(torch.utils.data.Dataset): 30 | def __init__(self, config, mode): 31 | self.root_dir = config.dataset_dir+'/'+config.dataset+'/' 32 | self.image_paths = np.genfromtxt(self.root_dir + mode + '.txt', delimiter=',', dtype='str', encoding='utf-8') 33 | self.labels = np.genfromtxt(self.root_dir + mode +'_labels.txt', delimiter=',', dtype='int', encoding='utf-8') 34 | self.keypoints = np.load(self.root_dir + mode +'_keypoints.npy') 35 | self.num_kp = config.k 36 | self.image_size =config.image_size 37 | 38 | def __len__(self): 39 | return len(self.image_paths) 40 | 41 | def __getitem__(self, idx): 42 | # load image 43 | img_name = os.path.join(self.root_dir, self.image_paths[idx]) 44 | image = skimage.io.imread(img_name) 45 | image = skimage.transform.resize(image,(self.image_size,self.image_size)) 46 | image = torch.from_numpy(image).permute(2, 0, 1).float() 47 | image = torch.clamp(image, 0.0, 1.0) 48 | 49 | # load keypoints 50 | keypoints = torch.from_numpy(self.keypoints[idx].copy()) 51 | keypoints[:,2] = keypoints[:,2] * 2.0 52 | keypoints = torch.cat((keypoints,keypoints[:,[2]]), axis=-1) 53 | stride = self.image_size/image.shape[1] 54 | keypoints = keypoints*stride 55 | 56 | # load label 57 | labels = torch.from_numpy(self.labels[idx]) 58 | 59 | return image, keypoints, labels -------------------------------------------------------------------------------- /datasets/pascal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import numpy as np 4 | import os 5 | import skimage.io 6 | import skimage.transform 7 | import random 8 | from tqdm import tqdm 9 | import random 10 | import sys 11 | 12 | if sys.version_info[0] == 2: 13 | import xml.etree.cElementTree as ET 14 | else: 15 | import xml.etree.ElementTree as ET 16 | 17 | 18 | 19 | PASCAL_CLASSES = ( 20 | "__background__", 21 | "aeroplane", 22 | "bicycle", 23 | "bird", 24 | "boat", 25 | "bottle", 26 | "bus", 27 | "car", 28 | "cat", 29 | "chair", 30 | "cow", 31 | "diningtable", 32 | "dog", 33 | "horse", 34 | "motorbike", 35 | "person", 36 | "pottedplant", 37 | "sheep", 38 | "sofa", 39 | "train", 40 | "tvmonitor", 41 | ) 42 | 43 | 44 | class PascalVOCMetaData(): 45 | def __init__(self, config, mode): 46 | self.root_dir = config.dataset_dir+'/'+config.dataset+'/' 47 | self.image_set = mode 48 | self._imgsetpath = os.path.join(self.root_dir, "ImageSets", "Main", "%s.txt") 49 | 50 | with open(self._imgsetpath % self.image_set) as f: 51 | self.ids = f.readlines() 52 | self.ids = [x.strip("\n") for x in self.ids] 53 | 54 | self.cls = PASCAL_CLASSES 55 | def get_num_class(self): 56 | return len(self.cls) 57 | def get_class_name(self, class_id): 58 | return self.cls[class_id] 59 | def get_image_name(self, image_id): 60 | return self.ids[image_id] 61 | 62 | 63 | 64 | 65 | # modified based on 'maskrcnn-benchmark' 66 | # git@github.com:facebookresearch/maskrcnn-benchmark.git 67 | class PascalVOC(torch.utils.data.Dataset): 68 | def __init__(self, config, mode): 69 | self.max_objects = config.k 70 | self.root_dir = config.dataset_dir+'/'+config.dataset+'/' 71 | self.image_set = mode 72 | self.keep_difficult = False 73 | self.num_kp = config.k 74 | self._annopath = os.path.join(self.root_dir, "Annotations", "%s.xml") 75 | self._imgpath = os.path.join(self.root_dir, "JPEGImages", "%s.jpg") 76 | self._imgsetpath = os.path.join(self.root_dir, "ImageSets", "Main", "%s.txt") 77 | self.image_size = config.image_size 78 | with open(self._imgsetpath % self.image_set) as f: 79 | self.ids = f.readlines() 80 | self.ids = [x.strip("\n") for x in self.ids] 81 | self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} 82 | 83 | cls = PASCAL_CLASSES 84 | self.class_to_ind = dict(zip(cls, range(len(cls)))) 85 | self.categories = dict(zip(range(len(cls)), cls)) 86 | self.rand_horiz_flip = config.rand_horiz_flip 87 | self.rand_maskout = config.rand_maskout 88 | 89 | def __getitem__(self, index): 90 | img_id = self.ids[index] 91 | image = skimage.io.imread(self._imgpath%img_id) 92 | # resize image 93 | scale = min(image.shape[0], image.shape[1])/self.image_size 94 | size = [int(image.shape[0]/scale),int(image.shape[1]/scale)] 95 | image = skimage.transform.resize(image,size) 96 | 97 | anno = ET.parse(self._annopath % img_id).getroot() 98 | anno = self._preprocess_annotation(anno) 99 | bboxes = anno['boxes'] 100 | bboxes = bboxes/scale 101 | 102 | keypoints = torch.zeros([self.num_kp,4],device=torch.device('cpu')) 103 | 104 | image = torch.from_numpy(image).permute(2, 0, 1).float() 105 | 106 | keypoints[:bboxes.shape[0],[0]] = (bboxes[:,[0]]+bboxes[:,[2]])/2 107 | keypoints[:bboxes.shape[0],[1]] = (bboxes[:,[1]]+bboxes[:,[3]])/2 108 | keypoints[:bboxes.shape[0],[2]] = bboxes[:,[2]]-bboxes[:,[0]] 109 | keypoints[:bboxes.shape[0],[3]] = bboxes[:,[3]]-bboxes[:,[1]] 110 | 111 | labels = torch.zeros([self.num_kp],device=torch.device('cpu')) 112 | labels[:anno['labels'].shape[0]] = anno['labels'] 113 | 114 | image, keypoints, labels = self._transform(image,keypoints,labels) 115 | 116 | 117 | 118 | return image, keypoints, labels 119 | 120 | def _transform(self, image, keypoints, labels): 121 | 122 | if self.rand_horiz_flip and random.random()>0.5: 123 | image = image.flip(2) 124 | keypoints[:,0] = image.shape[2] - keypoints[:,0] 125 | if self.rand_maskout: 126 | _, H_image, W_image = image.shape 127 | S_image = min(H_image, W_image) 128 | H_mask = int((random.random()*0.45+0.15)*S_image) 129 | W_mask = int((random.random()*0.45+0.15)*S_image) 130 | y_mask = int(random.random()*(H_image-H_mask-2)) + int(H_mask/2) +1 131 | x_mask = int(random.random()*(W_image-W_mask-2)) + int(W_mask/2) +1 132 | avg_color = torch.mean(image[:,y_mask-int(H_mask/2):y_mask+int(H_mask/2),x_mask-int(W_mask/2):x_mask+int(W_mask/2)],(1,2),keepdim=True) 133 | image[:,y_mask-int(H_mask/2):y_mask+int(H_mask/2),x_mask-int(W_mask/2):x_mask+int(W_mask/2)] = avg_color 134 | 135 | return image, keypoints, labels 136 | 137 | 138 | def __len__(self): 139 | return len(self.ids) 140 | 141 | 142 | 143 | def _preprocess_annotation(self, target): 144 | boxes = [] 145 | gt_classes = [] 146 | difficult_boxes = [] 147 | TO_REMOVE = 1 148 | num_objects = 0 149 | for obj in target.iter("object"): 150 | if num_objects == self.max_objects: 151 | break 152 | difficult = int(obj.find("difficult").text) == 1 153 | if not self.keep_difficult and difficult: 154 | continue 155 | num_objects = num_objects + 1 156 | name = obj.find("name").text.lower().strip() 157 | bb = obj.find("bndbox") 158 | # Make pixel indexes 0-based 159 | # Refer to "https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py#L208-L211" 160 | box = [ 161 | bb.find("xmin").text, 162 | bb.find("ymin").text, 163 | bb.find("xmax").text, 164 | bb.find("ymax").text, 165 | ] 166 | bndbox = tuple( 167 | map(lambda x: x - TO_REMOVE, list(map(int, box))) 168 | ) 169 | 170 | boxes.append(bndbox) 171 | gt_classes.append(self.class_to_ind[name]) 172 | difficult_boxes.append(difficult) 173 | 174 | size = target.find("size") 175 | im_info = tuple(map(int, (size.find("height").text, size.find("width").text))) 176 | 177 | res = { 178 | "boxes": torch.tensor(boxes, dtype=torch.float32, device=torch.device('cpu')), 179 | "labels": torch.tensor(gt_classes, device=torch.device('cpu')), 180 | "difficult": torch.tensor(difficult_boxes, device=torch.device('cpu')), 181 | "im_info": im_info, 182 | } 183 | return res 184 | 185 | def get_img_info(self, index): 186 | img_id = self.ids[index] 187 | anno = ET.parse(self._annopath % img_id).getroot() 188 | size = anno.find("size") 189 | im_info = tuple(map(int, (size.find("height").text, size.find("width").text))) 190 | return {"height": im_info[0], "width": im_info[1]} 191 | 192 | 193 | -------------------------------------------------------------------------------- /images/mist_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/mist/ef8ab358e83dc306f356910578c5a7c1f8d193d8/images/mist_pipeline.png -------------------------------------------------------------------------------- /json/mnist.json: -------------------------------------------------------------------------------- 1 | { 2 | "anchor_size": 0.25, 3 | "batch_size": 32, 4 | "comment": "", 5 | "cooldown_period": -1, 6 | "dataset": "mnist_hard", 7 | "dataset_dir": "./dataset", 8 | "detector_backbone": "CustomResNet", 9 | "epochs": 10, 10 | "heatmap_reconstruct": "single_point", 11 | "image_size": 80, 12 | "k": 9, 13 | "k_iter": 2, 14 | "log_dir": "./logs", 15 | "loss_type": "MSE", 16 | "lr_detector": 0.0001, 17 | "lr_k": 10000.0, 18 | "lr_task": 0.0001, 19 | "model_dir": "./pretrained_models", 20 | "name": "mnist", 21 | "nms_kernal_size_ratio": 0.05, 22 | "num_classes": 10, 23 | "patch_from_featuremap": false, 24 | "patch_size": 32, 25 | "pretrained_resnet": true, 26 | "rand_horiz_flip": false, 27 | "rand_maskout": false, 28 | "resume": true, 29 | "run_val": true, 30 | "save_path": "./save_results", 31 | "save_period": 5, 32 | "save_weights": true, 33 | "set_seed": true, 34 | "sm_kernal_size_ratio": 0.2, 35 | "softmax_strength": 10, 36 | "spatial_softmax": true, 37 | "sub_pixel_kp": false, 38 | "summary_period": 1, 39 | "test_path": "./test_results", 40 | "val_path": "./val_results", 41 | "valid_period": 5 42 | } -------------------------------------------------------------------------------- /json/pascal.json: -------------------------------------------------------------------------------- 1 | { 2 | "anchor_size": 0.5, 3 | "batch_size": 16, 4 | "comment": "", 5 | "cooldown_period": -1, 6 | "dataset": "VOC_pascal_coco_v2", 7 | "dataset_dir": "./dataset", 8 | "detector_backbone": "ResNet34", 9 | "epochs": 100, 10 | "heatmap_reconstruct": "gaussian", 11 | "image_size": 224, 12 | "k": 2, 13 | "k_iter": 2, 14 | "log_dir": "./logs", 15 | "loss_type": "MSE", 16 | "lr_detector": 0.0001, 17 | "lr_k": 1000.0, 18 | "lr_task": 0.0001, 19 | "model_dir": "./pretrained_models", 20 | "name": "pascal_coco", 21 | "nms_kernal_size_ratio": 0.05, 22 | "num_classes": 21, 23 | "patch_from_featuremap": true, 24 | "patch_size": 32, 25 | "pretrained_resnet": true, 26 | "rand_horiz_flip": true, 27 | "rand_maskout": true, 28 | "resume": true, 29 | "run_val": true, 30 | "save_path": "./save_results", 31 | "save_period": 5, 32 | "save_weights": true, 33 | "set_seed": true, 34 | "sm_kernal_size_ratio": 0.2, 35 | "softmax_strength": 10, 36 | "spatial_softmax": false, 37 | "sub_pixel_kp": true, 38 | "summary_period": 1, 39 | "test_path": "./test_results", 40 | "val_path": "./val_results", 41 | "valid_period": 5 42 | } -------------------------------------------------------------------------------- /mist_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | 5 | from models.mist import MIST 6 | from datasets.datasets import get_dataset 7 | from utils.summary import CustomSummaryWriter 8 | from utils.config_utils import print_config, str2bool, get_mist_config 9 | from utils.viz_utils import save_bbox_images 10 | from utils.torch_utils import to_gpu, eval_accuracy, xywh_to_xyxy 11 | 12 | if __name__ == '__main__': 13 | # get config 14 | config = get_mist_config() 15 | print_config(config) 16 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 17 | 18 | # init data loader 19 | config.rand_horiz_flip = False 20 | config.rand_maskout = False 21 | config.batch_size = 1 22 | dataset_te, metadata_te = get_dataset(config, mode='test') 23 | 24 | # init network 25 | config.pretrained_resnet = False 26 | mist = MIST(config) 27 | 28 | # load pretrained weights 29 | best_model_path = os.path.join(config.model_dir,config.name + '_best_models') 30 | mist.load_state_dict(torch.load(best_model_path)) 31 | 32 | # set evaluation metric 33 | if config.dataset.startswith('mnist'): 34 | background_class=False 35 | val_metric = 'iou' 36 | else: 37 | background_class=True 38 | val_metric = 'center' 39 | 40 | # create viz folder 41 | test_results_path = os.path.join(config.test_path, config.name) 42 | if not os.path.isdir(test_results_path): 43 | os.makedirs(test_results_path) 44 | 45 | # inference 46 | output = {} 47 | output['tp_center'] = 0 48 | output['tp_iou'] = 0 49 | output['num_objects'] = 0 50 | output['num_detetcions'] = 0 51 | for i, data in tqdm(enumerate(dataset_te)): 52 | images, keypoints_gt, labels_gt = to_gpu(data) 53 | B = images.shape[0] 54 | # forward pass 55 | bboxs, labels, _ = mist.forward(images.clone()) 56 | # evaluation 57 | eval_test = eval_accuracy(bboxs, keypoints_gt, labels, labels_gt, 58 | config.num_classes, background_class) 59 | for j,(image, bbox, label) in enumerate(zip(images,xywh_to_xyxy(bboxs),labels)): 60 | save_bbox_images(image, bbox, 61 | [metadata_te.get_class_name(_label) for _label in label],str(i*B+j), 62 | test_results_path, background_class) 63 | # accumulate results 64 | output['tp_center'] += eval_test['tp_center'] 65 | output['tp_iou'] += eval_test['tp_iou'] 66 | output['num_objects'] += eval_test['num_objects'] 67 | output['num_detetcions'] += eval_test['num_detetcions'] 68 | 69 | # calculate f1 score 70 | if eval_test['num_detetcions'] == 0: 71 | precision = 0 72 | else: 73 | precision = output['tp_'+val_metric]/output['num_detetcions'] 74 | recall = output['tp_'+val_metric]/output['num_objects'] 75 | if (precision + recall) == 0: 76 | f1 = 0 77 | else: 78 | f1 = 2*precision*recall/(precision+recall) 79 | print('test set f1 {} score: {}'.format(val_metric, f1)) -------------------------------------------------------------------------------- /mist_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | import random 5 | import os 6 | import time 7 | import math 8 | from tqdm import tqdm 9 | 10 | from utils.summary import CustomSummaryWriter 11 | from datasets.datasets import get_dataset 12 | from models.mist import MIST 13 | from utils.config_utils import print_config, str2bool, get_mist_config, config_to_string 14 | from utils.torch_utils import to_gpu, eval_accuracy, inverse_heatmap_gaussian, inverse_heatmap, xywh_to_xyxy 15 | from utils.loss_functions import one_hot_classification_loss 16 | from utils.utils import tensorboard_scheduler 17 | from utils.viz_utils import save_bbox_images 18 | 19 | 20 | 21 | class MISTTrainer(): 22 | def __init__(self, config, mist): 23 | 24 | # init tensorboard scheduler 25 | self.scheduler = tensorboard_scheduler(config.summary_period,config. 26 | save_period,config.valid_period, 27 | config.cooldown_period) 28 | # copy config 29 | self.config = config 30 | 31 | # init iteration counter 32 | self.num_iter = 0 33 | self.best_val_f1 = 0 34 | 35 | # init summary writer 36 | self.summary = CustomSummaryWriter(config.log_dir + '/' + config.name) 37 | 38 | # create the solvers 39 | self.detector_solver = torch.optim.Adam(mist.detector.parameters(), 40 | lr=config.lr_detector) 41 | self.task_solver = torch.optim.Adam(mist.classifier.parameters(), 42 | lr=config.lr_task) 43 | 44 | # reume flag 45 | self.resumed = False 46 | 47 | # contains background class or not 48 | if config.dataset.startswith('mnist'): 49 | self.background_class=False 50 | self.val_metric = 'iou' 51 | else: 52 | self.background_class=True 53 | self.val_metric = 'center' 54 | 55 | # create validation path 56 | self.val_results_path = os.path.join(config.val_path, config.name) 57 | if not os.path.isdir(self.val_results_path): 58 | os.makedirs(self.val_results_path) 59 | if not os.path.isdir(self.config.model_dir): 60 | os.makedirs(self.config.model_dir) 61 | 62 | 63 | 64 | def resume(self, mist): 65 | try: 66 | print('Loading saved models ...') 67 | mist.load_state_dict(torch.load(os.path.join( 68 | self.config.model_dir,self.config.name+'_models'))) 69 | print('Loading saved solvers ...') 70 | self.load_solvers() 71 | print('Previous state resumed, continue training') 72 | self.resumed = True 73 | except: 74 | print('Did not find saved model, fresh start') 75 | self.resumed = False 76 | 77 | def write_meta_data(self): 78 | # add hyper parameters to summary 79 | self.summary.add_text('hyper paramter',config_to_string(self.config)) 80 | # add comment to summary 81 | self.summary.add_text('comment',self.config.comment) 82 | 83 | def save_solvers(self): 84 | torch.save({ 85 | 'detector_solver': self.detector_solver.state_dict(), 86 | 'task_solver': self.task_solver.state_dict(), 87 | 'iteration': self.num_iter, 88 | 'best_val_f1': self.best_val_f1 89 | }, os.path.join(self.config.model_dir, 90 | self.config.name + '_solvers')) 91 | 92 | def load_solvers(self): 93 | checkpoint = torch.load(os.path.join(self.config.model_dir, 94 | self.config.name + '_solvers')) 95 | if hasattr(self,'solver'): 96 | self.solver.load_state_dict(checkpoint['solver']) 97 | else: 98 | self.detector_solver.load_state_dict(checkpoint['detector_solver']) 99 | self.task_solver.load_state_dict(checkpoint['task_solver']) 100 | self.num_iter = checkpoint['iteration'] 101 | self.best_val_f1 = checkpoint['best_val_f1'] 102 | print('continue at {} iter, f1: {}'.format(self.num_iter, 103 | self.best_val_f1)) 104 | 105 | def train_kp(self, mist, featuremap, bbox, labels_gt): 106 | # only regress bbox location, fix width and height 107 | bbox_xy = bbox[:,:,:2].clone().detach().requires_grad_(True) 108 | bbox_wh = bbox[:,:,2:].clone().detach().requires_grad_(False) 109 | bbox = torch.cat([bbox_xy,bbox_wh],dim=-1) 110 | bbox_solver = torch.optim.SGD([bbox_xy], lr=self.config.lr_k) 111 | 112 | # training loop 113 | for k in range(self.config.k_iter): 114 | bbox_solver.zero_grad() 115 | _, logits, _ = mist.classifier.forward(featuremap, bbox) 116 | loss, _= one_hot_classification_loss(logits, labels_gt, 117 | self.config.num_classes, self.config.loss_type) 118 | loss.backward() 119 | bbox_solver.step() 120 | 121 | bbox = torch.cat([bbox_xy,bbox_wh],dim=-1) 122 | 123 | kp_diag = {} 124 | kp_diag['bbox_grads'] = bbox_xy.grad 125 | return bbox.detach(), kp_diag 126 | 127 | 128 | def train_task(self, mist,featuremap,bbox_opt, labels_gt): 129 | labels, logits, diag = mist.classifier.forward(featuremap, bbox_opt) 130 | # clear grad in solver 131 | self.task_solver.zero_grad() 132 | # calculate loss 133 | loss, loss_diag = one_hot_classification_loss(logits, labels_gt, 134 | self.config.num_classes, self.config.loss_type) 135 | # back propagate 136 | loss.backward() 137 | # update weights 138 | self.task_solver.step() 139 | task_diag = {} 140 | task_diag['loss_per_sample'] = loss_diag['loss_per_sample'] 141 | task_diag['loss_task'] = loss 142 | task_diag['classifier'] = diag 143 | task_diag['labels'] = labels 144 | return loss, task_diag 145 | 146 | def train_detector(self, heatmap, bbox, loss_per_sample): 147 | diag = {} 148 | 149 | B,C,H,W = heatmap.shape 150 | K = bbox.shape[1] 151 | 152 | if self.config.sub_pixel_kp: 153 | # get optimized offset temp gt 154 | offset_gt = bbox[:,:,:2] - torch.floor(bbox[:,:,:2]) 155 | # cast bbox location to int 156 | bbox_int = torch.floor(bbox) 157 | # clamp bbox location within image 158 | bbox_int[:,:,0] = torch.clamp(bbox_int[:,:,0], min=0 ,max=W-1) 159 | bbox_int[:,:,1] = torch.clamp(bbox_int[:,:,1], min=0 ,max=H-1) 160 | # extract offset at bbox locations from heatmap 161 | idx = torch.cat([torch.arange(B).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1,K,2,1).type(torch.long), 162 | torch.arange(1,3).reshape(1,1,2,1).repeat(B,K,1,1), 163 | bbox_int[:,:,[1,0]].type(torch.long).unsqueeze(-2).repeat(1,1,2,1)],dim=-1).permute(3,0,1,2) 164 | offset = heatmap[tuple(idx)] 165 | # offset loss 166 | loss_offset = 0.1*torch.nn.SmoothL1Loss()(offset,offset_gt) 167 | diag['offset_gt'] = offset_gt 168 | diag['offset'] = offset 169 | else: 170 | loss_offset = 0 171 | diag['loss_offset'] = loss_offset 172 | 173 | 174 | # reconstruct heatmap 175 | if self.config.heatmap_reconstruct=='gaussian': 176 | bbox = torch.cat([torch.floor(bbox[:,:,:2]),bbox[:,:,2:]],dim=-1) 177 | target_heatmap = inverse_heatmap_gaussian(bbox, heatmap[:,[0],:,:].shape) 178 | elif self.config.heatmap_reconstruct=='single_point': 179 | target_heatmap = inverse_heatmap(bbox[:,:,:2], heatmap[:,[0],:,:].shape) 180 | else: 181 | raise RunTimeError('Unknown method ({}) for heatmap reconsturction'.format(self.heatmap_reconstruct)) 182 | diag["target_heatmap"] = target_heatmap 183 | 184 | # reconstruction loss 185 | diff_per_sample = (heatmap[:,[0],:,:] - target_heatmap).pow(2).mean(dim=[1,2,3]) 186 | sample_weight = torch.exp(-5.0 * loss_per_sample.detach()) 187 | loss_heatmap = (sample_weight * diff_per_sample).mean() 188 | diag['loss_heatmap'] = loss_heatmap 189 | 190 | # sum of losses 191 | loss = loss_offset + loss_heatmap 192 | 193 | # back propogate 194 | self.detector_solver.zero_grad() 195 | loss.backward() 196 | self.detector_solver.step() 197 | 198 | return loss, diag 199 | 200 | 201 | def train(self, mist, dataset_tr, metadata_tr, dataset_va, metadata_va): 202 | # main training loop 203 | for epoch in tqdm(range(self.config.epochs)): 204 | for x in tqdm(dataset_tr, smoothing=0.1): 205 | # put data on gpu 206 | images, bbox_gt, labels_gt = to_gpu(x) 207 | 208 | # sanity check on data format 209 | if len(images.shape)!=4 or images.shape[1]!=3: 210 | raise RunTimeError('Images does not have dimension (B,3,H,W)') 211 | 212 | # forward path on detector 213 | bbox_dt, detector_diag = mist.detector.forward(images) 214 | labels_dt, logits, _ = mist.classifier.forward(detector_diag['featuremap'], bbox_dt) 215 | 216 | # optimize bbox 217 | bbox_opt, train_kp_diag = self.train_kp(mist, detector_diag['featuremap'], bbox_dt, labels_gt) 218 | 219 | # optimize the task network 220 | loss_task, train_task_diag= self.train_task(mist,detector_diag['featuremap'],bbox_opt, labels_gt) 221 | 222 | # optimize the detector 223 | loss_heatmap, train_detect_diag = self.train_detector( detector_diag['heatmap'], bbox_opt, train_task_diag['loss_per_sample']) 224 | 225 | eval_flag, save_flag, valid_flag = self.scheduler.schedule() 226 | 227 | 228 | if eval_flag: 229 | bbox_img = bbox_dt.detach() * detector_diag['stride'] 230 | labels_nms = labels_dt.detach() 231 | eval_batch = eval_accuracy (bbox_img, bbox_gt, labels_nms, labels_gt, self.config.num_classes, self.background_class) 232 | 233 | # loss 234 | self.summary.add_scalar('loss offset', train_detect_diag['loss_offset'], self.num_iter) 235 | self.summary.add_scalar('loss heatmap', train_detect_diag['loss_heatmap'], self.num_iter) 236 | self.summary.add_scalar('loss heatmap + offset', loss_heatmap, self.num_iter) 237 | self.summary.add_scalar('loss task', loss_task, self.num_iter) 238 | 239 | 240 | # input image 241 | self.summary.add_images('1 input', images, self.num_iter,resize=2) 242 | 243 | # input image with predicted bbox 244 | self.summary.add_images('2 boxes', images, self.num_iter, 245 | boxes_infer=xywh_to_xyxy(bbox_img), 246 | boxes_gt=xywh_to_xyxy(bbox_gt), 247 | labels=labels_nms, 248 | match=eval_batch['keypoint_match_detection'], 249 | resize=2) 250 | 251 | # keypoints displacement 252 | displacement = torch.cat([inverse_heatmap(bbox_dt.clone().detach(), [detector_diag['featuremap'].shape[0],1 ,detector_diag['featuremap'].shape[2], detector_diag['featuremap'].shape[3]]), 253 | inverse_heatmap(bbox_opt.clone().detach(), [detector_diag['featuremap'].shape[0],1 ,detector_diag['featuremap'].shape[2], detector_diag['featuremap'].shape[3]]), 254 | detector_diag['featuremap'].mean(dim=1, keepdim=True)], 255 | dim=1) 256 | self.summary.add_images('3 displacement', displacement, self.num_iter, resize=2) 257 | 258 | # patches for task network 259 | self.summary.add_images('4 patches', train_task_diag['classifier']['patches'], self.num_iter) 260 | 261 | # heatmap 262 | self.summary.add_images('5 heatmap', detector_diag['heatmap'][:,[0],:,:], self.num_iter) 263 | 264 | # target heatmap 265 | self.summary.add_images('6 target_heatmap', torch.clamp(train_detect_diag['target_heatmap'],0,1), self.num_iter, resize=2) 266 | 267 | # labels 268 | self.summary.add_histogram('labels', labels_nms.view(-1), self.num_iter) 269 | 270 | # offset 271 | if 'offset_gt' in train_detect_diag.keys() and 'offset' in train_detect_diag.keys(): 272 | self.summary.add_histogram('offset gt', train_detect_diag['offset_gt'].view(-1), self.num_iter) 273 | self.summary.add_histogram('offset', train_detect_diag['offset'].view(-1), self.num_iter) 274 | 275 | # accuracy 276 | self.summary.add_scalar('Batch f1 center', eval_batch['f1_center'], self.num_iter) 277 | self.summary.add_scalar('Batch f1 iou', eval_batch['f1_iou'], self.num_iter) 278 | self.summary.add_scalar('Batch precision center', eval_batch['precision_center'], self.num_iter) 279 | self.summary.add_scalar('Batch precision iou', eval_batch['precision_iou'], self.num_iter) 280 | self.summary.add_scalar('Batch recall center', eval_batch['recall_center'], self.num_iter) 281 | self.summary.add_scalar('Batch recall iou', eval_batch['recall_iou'], self.num_iter) 282 | self.summary.add_scalar('Batch AP detection', eval_batch['acc_det'], self.num_iter) 283 | self.summary.add_scalar('Batch AP classification', eval_batch['acc_class'], self.num_iter) 284 | 285 | 286 | self.summary.flush() 287 | 288 | if valid_flag and self.config.run_val: 289 | eval_val = self.validate(mist, dataset_va,metadata_va) 290 | if eval_val['num_detetcions'] == 0: 291 | precision = 0 292 | else: 293 | precision = eval_val['tp_'+self.val_metric]/eval_val['num_detetcions'] 294 | recall = eval_val['tp_'+self.val_metric]/eval_val['num_objects'] 295 | if (precision + recall) == 0: 296 | f1 = 0 297 | else: 298 | f1 = 2*precision*recall/(precision+recall) 299 | print('validation set f1 {} score: {}'.format(self.val_metric, f1)) 300 | if self.best_val_f1 < f1: 301 | mist.save_state_dict(os.path.join(self.config.model_dir, self.config.name + '_best_models')) 302 | self.best_val_f1 = f1 303 | print('savin best model f1 score: {}'.format(self.best_val_f1)) 304 | 305 | if save_flag: 306 | if self.config.save_weights: 307 | # save network weights 308 | print('saving wieghts ...') 309 | mist.save_state_dict(os.path.join(self.config.model_dir, self.config.name + '_models')) 310 | # save optimizer params 311 | self.save_solvers() 312 | 313 | self.num_iter += 1 314 | 315 | 316 | 317 | def validate(self, mist, dataset,metadata_va, mode='val'): 318 | output = {} 319 | output['tp_center'] = 0 320 | output['tp_iou'] = 0 321 | output['num_objects'] = 0 322 | output['num_detetcions'] = 0 323 | 324 | # only run validation on 50 mini batches to save time 325 | max_iter = 50 326 | print('Running on validation set up to 50 mini batches') 327 | for i, x in tqdm(enumerate(dataset), smoothing=0.1): 328 | images, keypoints_gt, labels_gt = to_gpu(x) 329 | B, _, _, _ = images.shape 330 | bboxs, labels, _ = mist.forward(images.clone()) 331 | eval_val = eval_accuracy(bboxs, keypoints_gt, labels, labels_gt, self.config.num_classes, self.background_class) 332 | for j,(image, bbox, label) in enumerate(zip(images,xywh_to_xyxy(bboxs),labels)): 333 | save_bbox_images(image,bbox,[metadata_va.get_class_name(_label) for _label in label], str(i*B+j),self.val_results_path,self.background_class) 334 | 335 | # accumulate results 336 | output['tp_center'] += eval_val['tp_center'] 337 | output['tp_iou'] += eval_val['tp_iou'] 338 | output['num_objects'] += eval_val['num_objects'] 339 | output['num_detetcions'] += eval_val['num_detetcions'] 340 | if i > max_iter: 341 | break 342 | return output 343 | 344 | 345 | 346 | 347 | def main(): 348 | 349 | config = get_mist_config() 350 | print_config(config) 351 | 352 | # set torch seed 353 | if config.set_seed: 354 | torch.manual_seed(0) 355 | torch.backends.cudnn.deterministic = True 356 | torch.backends.cudnn.benchmark = False 357 | 358 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 359 | torch.set_printoptions(profile="full") 360 | torch.set_printoptions(threshold=5000) 361 | torch.set_printoptions(precision=10) 362 | 363 | # init data loader 364 | dataset_tr, metadata_tr = get_dataset(config, mode='train') 365 | 366 | val_config = config 367 | val_config.rand_horiz_flip = False 368 | val_config.rand_maskout = False 369 | dataset_va, metadata_va = get_dataset(val_config, mode='valid') 370 | 371 | # init network 372 | mist = MIST(config) 373 | 374 | # init network trainer 375 | mist_trainer = MISTTrainer(config, mist) 376 | 377 | # resume model 378 | if config.resume: 379 | mist_trainer.resume(mist) 380 | 381 | # wirte meta data if first time run 382 | if not mist_trainer.resumed: 383 | mist_trainer.write_meta_data() 384 | 385 | # train model 386 | mist_trainer.train(mist, dataset_tr, metadata_tr, dataset_va, metadata_va) 387 | 388 | 389 | if __name__ == '__main__': 390 | main() 391 | 392 | -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.spatialtransform import spatial_transformer 3 | from models.net import ResNet 4 | 5 | 6 | def ClassifierNetWork(config,feat_ch): 7 | if config.detector_backbone == 'CustomResNet': 8 | return CustomClassifier(config,feat_ch) 9 | elif config.detector_backbone == 'ResNet34': 10 | return ResnetClassifier(config,feat_ch) 11 | else: 12 | raise Exception('invalid backbone') 13 | 14 | class ResnetClassifier (torch.nn.Module): 15 | 16 | def __init__(self, config,feat_ch): 17 | super(ResnetClassifier, self).__init__() 18 | 19 | 20 | self.psize = (config.patch_size, config.patch_size) 21 | self.cnn = ResNet(256, 512, 3, use_padding=True) 22 | self.fcn = torch.nn.Sequential( 23 | torch.nn.Linear(512,config.num_classes), 24 | torch.nn.Softmax(dim=-1)) 25 | self.global_avg_pool = torch.nn.AvgPool2d((self.psize[0],self.psize[1])) 26 | 27 | def forward(self, featuremap, bboxes, stride=None): 28 | diag = {} 29 | 30 | B, N, _ = bboxes.shape 31 | _, C, _, _ = featuremap.shape 32 | 33 | patches = spatial_transformer(featuremap, bboxes, self.psize).view(-1, C, *self.psize) 34 | patches = self.cnn(patches) 35 | features = self.global_avg_pool(patches).view(patches.shape[0],-1) 36 | logits = self.fcn(features) 37 | logits = logits.view(B, N,-1) 38 | labels = torch.argmax(logits,dim=2) 39 | 40 | # get labels 41 | labels = torch.argmax(logits,dim=2) 42 | diag['patches'] = patches 43 | 44 | return labels, logits, diag 45 | 46 | 47 | 48 | class CustomClassifier(torch.nn.Module): 49 | 50 | def __init__(self, config,feat_ch): 51 | super(CustomClassifier, self).__init__() 52 | self.encoder = Encoder(config) 53 | self.classifier = Classifier(config.num_classes, self.encoder.out_channels) 54 | self.psize = (config.patch_size, config.patch_size) 55 | 56 | def forward(self, featuremap, bboxes, stride=None): 57 | diag = {} 58 | 59 | B, N, _ = bboxes.shape 60 | _, C, _, _ = featuremap.shape 61 | 62 | # extract patch 63 | patches = spatial_transformer(featuremap, bboxes, self.psize).view(-1, C, *self.psize) 64 | 65 | # compute logits 66 | latent = self.encoder(patches) 67 | logits = self.classifier(latent) 68 | logits = logits.view(B,N,-1) 69 | 70 | # get labels 71 | labels = torch.argmax(logits,dim=2) 72 | 73 | diag['patches'] = patches 74 | 75 | return labels, logits, diag 76 | 77 | class Classifier(torch.nn.Module): 78 | def __init__(self, number_class,in_channels): 79 | super(Classifier, self).__init__() 80 | self.dense = torch.nn.Linear(in_channels, number_class) 81 | self.softmax = torch.nn.Softmax(dim=-1) 82 | 83 | def forward(self, x): 84 | x = self.dense(x) 85 | x = self.softmax(x) 86 | return x 87 | 88 | class Encoder(torch.nn.Module): 89 | def __init__(self, config): 90 | super(Encoder, self).__init__() 91 | 92 | num_channels = 8 93 | num_blocks = 3 94 | num_levels = 5 95 | 96 | c_out = 3 97 | self.layers = [] 98 | for i in range(num_levels): 99 | c_in = c_out 100 | c_out = num_channels * 2**i 101 | self.layers += [ResNet(in_channels=c_in, num_channels=c_out, num_blocks=num_blocks)] 102 | self.layers += [torch.nn.MaxPool2d(2, stride=2)] 103 | self.out_channels = c_out 104 | self.layers += [ResNet(in_channels=c_out, num_channels=c_out, num_blocks=num_blocks)] 105 | self.layers = torch.nn.Sequential(*self.layers) 106 | 107 | def forward(self, x): 108 | x = self.layers(x) 109 | x = x.view(x.shape[0], x.shape[1]) 110 | return x 111 | 112 | class ClassifierGAP (torch.nn.Module): 113 | 114 | def __init__(self, config,feat_ch): 115 | super(ClassifierGAP, self).__init__() 116 | 117 | self.psize = (config.patch_size, config.patch_size) 118 | self.cnn = ResNet(feat_ch, 512, 3, use_padding=True) 119 | self.fcn = torch.nn.Sequential( 120 | torch.nn.Linear(512,config.num_classes), 121 | torch.nn.Softmax(dim=-1)) 122 | self.global_avg_pool = torch.nn.AvgPool2d((self.psize[0],self.psize[1])) 123 | 124 | def forward(self, featuremap, bboxes, stride=None): 125 | diag = {} 126 | 127 | B, N, _ = bboxes.shape 128 | _, C, _, _ = featuremap.shape 129 | 130 | # extract patch 131 | patches = spatial_transformer(featuremap, bboxes, self.psize).view(-1, C, *self.psize) 132 | 133 | # compute logits 134 | patches = self.cnn(patches) 135 | features = self.global_avg_pool(patches).view(patches.shape[0],-1) 136 | logits = self.fcn(features) 137 | 138 | # get labels 139 | labels = torch.argmax(logits,dim=2) 140 | diag['patches'] = patches 141 | diag['features'] = features 142 | 143 | return labels, logits, diag 144 | 145 | 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /models/detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.models 4 | from itertools import product 5 | from models.net import ResNet 6 | 7 | class Detector(torch.nn.Module): 8 | def __init__(self, config): 9 | super(Detector, self).__init__() 10 | self.config = config 11 | if config.detector_backbone == 'ResNet34': 12 | print('init detector with ResNet') 13 | self.backbone = ResNet34Detector(config.pretrained_resnet) 14 | elif config.detector_backbone == 'CustomResNet': 15 | print('init detector with CustomResNet') 16 | self.backbone = CustomResNetDetector() 17 | else: 18 | raise NotImplementedError('{} backbone has not been implemented'.format(config.detector_backbone)) 19 | 20 | scale = float(config.image_size)/self.backbone.get_stride() 21 | self.anchor_size = self.config.anchor_size*scale 22 | 23 | 24 | def forward(self, images): 25 | diag = {} 26 | 27 | # Generate featuremap 28 | featuremap, heatmap = self.backbone(images,self.anchor_size) 29 | 30 | B,C,H,W = heatmap.shape 31 | 32 | # spatial softmax 33 | if self.config.spatial_softmax: 34 | sm_kernel_size = int(self.config.sm_kernal_size_ratio*(H+W)/4)*2+1 35 | heatmap[:,[0],:,:] = self.spatial_softmax(heatmap[:,[0],:,:], kernel_size=15, strength=10) 36 | 37 | # NMS 38 | nms_kernel_size = int(self.config.nms_kernal_size_ratio*(H+W)/4)*2+1 39 | heatmap_nms, _ = self.non_maximum_suppression(heatmap[:,[0],:,:], 5) 40 | # Choose top k 41 | _, indices = torch.topk(heatmap_nms.view(heatmap_nms.shape[0], -1), self.config.k) 42 | # Convert indices to x,y 43 | kp_int = torch.stack([(indices % W).float(),(indices // W).float()],dim=2) 44 | # Extract bbox from heatmap 45 | idx = torch.stack([torch.arange(images.shape[0]).reshape(-1,1,1).repeat(1,self.config.k,heatmap.shape[1]-1), 46 | torch.arange(1,heatmap.shape[1]).reshape(1,1,-1).repeat(images.shape[0],self.config.k,1), 47 | indices.unsqueeze(-1).repeat(1,1,heatmap.shape[1]-1)],dim=0) 48 | bbox = heatmap.view(heatmap.shape[0],heatmap.shape[1],-1)[tuple(idx)] 49 | 50 | # add x,y to offset 51 | bbox[:,:,:2] = bbox[:,:,:2] + kp_int 52 | 53 | # store results to dict 54 | diag['heatmap'] = heatmap 55 | diag['stride'] = self.backbone.get_stride() 56 | if self.config.patch_from_featuremap: 57 | diag['featuremap'] = featuremap.clone().detach().requires_grad_(False) 58 | else: 59 | diag['featuremap'] = images 60 | return bbox, diag 61 | 62 | def non_maximum_suppression(self, heatmap, size): 63 | # eps = 1e-4 64 | # heatmap = heatmap / eps 65 | max_logits = torch.nn.functional.max_pool2d(heatmap, kernel_size=size, stride=1, padding=size//2) 66 | mask = torch.ge(heatmap, max_logits) 67 | return heatmap * mask.float(), mask 68 | 69 | def spatial_softmax(self, heatmap, kernel_size, strength): 70 | 71 | # heatmap [N, S, H, W] 72 | out_shape = heatmap.shape 73 | heatmap = heatmap.view(-1, 1, out_shape[-2], out_shape[-1]) 74 | pad = kernel_size // 2 75 | # max_logits = torch.nn.functional.max_pool2d(heatmap, kernel_size=kernel_size, stride=1) 76 | # max_logits = torch.nn.functional.pad(max_logits, pad=(pad, pad, pad, pad), mode='replicate') 77 | max_logits = torch.max(heatmap, 2, True)[0] 78 | max_logits = torch.max(max_logits, 3, True)[0] 79 | 80 | ex = torch.exp(strength * (heatmap - max_logits)) 81 | # ex = torch.exp(strength * (heatmap)) 82 | sum_ex = torch.nn.functional.avg_pool2d(ex, kernel_size=kernel_size, stride=1, count_include_pad=False) * kernel_size**2 83 | sum_ex = torch.nn.functional.pad(sum_ex, pad=(pad, pad, pad, pad), mode='replicate') 84 | probs = ex / (sum_ex + 1e-6) 85 | # probs = heatmap - max_logits 86 | probs = probs.view(*out_shape) 87 | return probs 88 | 89 | def get_featuremap_channels(self): 90 | if self.config.patch_from_featuremap: 91 | return self.backbone.get_featuremap_channels() 92 | else: 93 | return 3 94 | 95 | # Resnet detector with 1by1 conv 96 | class ResNet34Detector(torch.nn.Module): 97 | def __init__(self, pretrained_resnet): 98 | super(ResNet34Detector, self).__init__() 99 | 100 | # not using the last maxpool layer 101 | self.backbone = torch.nn.Sequential(*list(self.create_resnet34(pretrained_resnet).children())[:7]) 102 | 103 | for layer in range(len(self.backbone)): 104 | for p in self.backbone[layer].parameters(): p.requires_grad = False 105 | 106 | self.detector_head = torch.nn.Sequential( 107 | torch.nn.Conv2d(256, 32, 3, 1, padding=1), 108 | torch.nn.ReLU(), 109 | torch.nn.Conv2d(32, 3, 1, 1, 0), 110 | ) 111 | 112 | def create_resnet34(self, pretrained_resnet): 113 | if pretrained_resnet: 114 | model = torchvision.models.resnet34(pretrained=True, progress=True) 115 | else: 116 | model = torchvision.models.resnet34(pretrained=False, progress=True) 117 | return model 118 | 119 | def get_featuremap_channels(self): 120 | return 256 121 | 122 | def get_stride(self): 123 | return 16 124 | 125 | def forward(self, images, anchor_size): 126 | 127 | featuremap = self.backbone(images) 128 | heatmap = self.detector_head(featuremap) 129 | 130 | B,_,W,H = heatmap.shape 131 | 132 | heatmap_w_h = torch.ones([B,2,W,H])*anchor_size 133 | heatmap_r = heatmap[:,[0],:,:] 134 | heatmap_x_y = torch.nn.functional.relu(heatmap[:,[1,2],:,:]) 135 | heatmap = torch.cat([heatmap_r,heatmap_x_y,heatmap_w_h],dim=1) 136 | 137 | return featuremap, heatmap 138 | 139 | 140 | # Resnet detector with 1by1 conv 141 | class CustomResNetDetector(torch.nn.Module): 142 | def __init__(self, num_blocks=4, num_channels=32): 143 | super(CustomResNetDetector, self).__init__() 144 | self.pad = torch.nn.ReflectionPad2d(num_blocks * 2) 145 | self.norm = torch.nn.BatchNorm2d(3) 146 | self.resnet = ResNet(3, num_channels, num_blocks, use_padding=False) 147 | self.last = torch.nn.Conv2d(num_channels, 1, kernel_size=1, stride=1, padding=0) 148 | 149 | def forward(self, x, anchor_size): 150 | x = self.pad(x) 151 | x = self.norm(x) 152 | f_x = self.resnet(x) 153 | h_x = self.last(f_x) 154 | N,_,W,H = h_x.shape 155 | h_x = torch.cat([h_x,torch.zeros(N,2,W,H),torch.ones(N,2,W,H)*anchor_size],dim=1) 156 | return f_x, h_x 157 | 158 | def get_stride(self): 159 | return 1 160 | 161 | def get_featuremap_channels(self): 162 | return 32 -------------------------------------------------------------------------------- /models/mist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.detector import Detector 3 | from models.classifier import ClassifierNetWork 4 | from utils.torch_utils import xywh_to_xyxy, xyxy_to_xywh 5 | 6 | class MIST(): 7 | def __init__(self, config): 8 | # init detector 9 | self.detector = Detector(config).cuda() 10 | # init classifier 11 | feat_ch = self.detector.get_featuremap_channels() 12 | self.classifier = ClassifierNetWork(config,feat_ch).cuda() 13 | 14 | def save_state_dict(self, path): 15 | torch.save({'detector': self.detector.state_dict(), 16 | 'classifier': self.classifier.state_dict()}, path) 17 | 18 | def load_state_dict(self, model_dict): 19 | self.detector.load_state_dict(model_dict['detector']) 20 | self.classifier.load_state_dict(model_dict['classifier']) 21 | 22 | def forward(self, image): 23 | # run detection network 24 | bbox_dt, detector_diag = self.detector.forward(image) 25 | # run task network 26 | labels, logits, _ = self.classifier.forward(detector_diag['featuremap'], bbox_dt.detach()) 27 | # convert bbox to image scale 28 | bbox_img = bbox_dt * detector_diag['stride'] 29 | 30 | return bbox_img, labels, logits 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /models/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class ResNet(torch.nn.Module): 4 | def __init__(self, in_channels, num_channels, num_blocks, use_padding=True): 5 | super(ResNet, self).__init__() 6 | self.first = torch.nn.Conv2d(in_channels, num_channels, 1) 7 | self.blocks = [] 8 | for _ in range(num_blocks): 9 | self.blocks += [ResNetBlock(num_channels, use_padding)] 10 | self.blocks = torch.nn.Sequential(*self.blocks) 11 | 12 | def forward(self, x): 13 | x = self.first(x) 14 | x = self.blocks(x) 15 | return x 16 | 17 | class ResNetBlock(torch.nn.Module): 18 | def __init__(self, channels=32, use_padding=True): 19 | super(ResNetBlock, self).__init__() 20 | # self.norm = torch.nn.BatchNorm2d(channels) 21 | # self.norm = torch.nn.GroupNorm(1, channels) # group norm 22 | self.norm = torch.nn.GroupNorm(1, channels) # layer norm 23 | self.use_padding = use_padding 24 | 25 | pad = 1 if self.use_padding else 0 26 | self.block = torch.nn.Sequential( 27 | torch.nn.Conv2d(channels, channels, kernel_size=3, padding=pad), 28 | self.norm, 29 | torch.nn.ReLU(), 30 | torch.nn.Conv2d(channels, channels, kernel_size=3, padding=pad), 31 | self.norm 32 | ) 33 | 34 | def forward(self, x): 35 | x_in = x 36 | if not self.use_padding: 37 | x_in = x_in[:, :, 2:-2, 2:-2] 38 | return torch.nn.functional.relu(x_in + self.block(x)) 39 | -------------------------------------------------------------------------------- /models/spatialtransform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def spatial_transformer(images, boxes, out_shape): 4 | # images: [B, C, H, W] 5 | # boxes: [B, N, (x,y,w,h)] 6 | # out_shape: (h, w) 7 | # torch.cuda.synchronize() 8 | # start = torch.cuda.Event(enable_timing=True) 9 | # end = torch.cuda.Event(enable_timing=True) 10 | # start.record() 11 | 12 | B, C, H, W = images.shape 13 | N = boxes.shape[1] 14 | 15 | grid = meshgrid(out_shape).to(images.device).contiguous() 16 | grid = grid - 0.5 17 | grid = grid.repeat([B, N, 1, 1]) # shape: (B, N, 3, h*w) 18 | bx = boxes[:,:,0] 19 | by = boxes[:,:,1] 20 | bw = boxes[:,:,2] 21 | bh = boxes[:,:,3] 22 | 23 | x = (grid[:,:,:,0] * bw.unsqueeze(2) + bx.unsqueeze(2)).view(-1) 24 | y = (grid[:,:,:,1] * bh.unsqueeze(2) + by.unsqueeze(2)).view(-1) 25 | 26 | 27 | # x = (x / (W-1) - 0.5) * 2 28 | # y = (y / (H-1) - 0.5) * 2 29 | # wgrid = torch.stack([x, y], dim=-1).view(B*N, *out_shape, 2) 30 | # wimages = torch.repeat_interleave(images, N, dim=0) 31 | # output = grid_sample(wimages, wgrid, mode='bilinear') 32 | # output = output.view(B, N, C, *out_shape) 33 | # return output 34 | 35 | 36 | x0 = x.long() 37 | y0 = y.long() 38 | x1 = x0 + 1 39 | y1 = y0 + 1 40 | 41 | # clamp 42 | x = torch.clamp(x, 0, W - 1) 43 | y = torch.clamp(y, 0, H - 1) 44 | x0 = torch.clamp(x0, 0, W - 1) 45 | y0 = torch.clamp(y0, 0, H - 1) 46 | x1 = torch.clamp(x1, 0, W - 1) 47 | y1 = torch.clamp(y1, 0, H - 1) 48 | 49 | # convert to linear indices 50 | batch_inds = torch.arange(B, device=images.device) 51 | 52 | batch_inds = torch.repeat_interleave(batch_inds, N) 53 | base = torch.repeat_interleave(batch_inds, out_shape[0]*out_shape[1], 0) * H * W 54 | 55 | idx_a = base + y0 * W + x0 56 | idx_b = base + y1 * W + x0 57 | idx_c = base + y0 * W + x1 58 | idx_d = base + y1 * W + x1 59 | 60 | # gather pixel values 61 | images = images.permute(0, 2, 3, 1).contiguous().view(-1, C) 62 | 63 | Ia = images[idx_a, :] 64 | Ib = images[idx_b, :] 65 | Ic = images[idx_c, :] 66 | Id = images[idx_d, :] 67 | 68 | # bilinear interpolation 69 | wa = ((x1.float() - x) * (y1.float() - y)).unsqueeze(1) 70 | wb = ((x1.float() - x) * (y - y0.float())).unsqueeze(1) 71 | wc = ((x - x0.float()) * (y1.float() - y)).unsqueeze(1) 72 | wd = ((x - x0.float()) * (y - y0.float())).unsqueeze(1) 73 | 74 | output = wa * Ia + wb * Ib + wc * Ic + wd * Id 75 | output = output.view(B, N, out_shape[0], out_shape[1], C).permute(0, 1, 4, 2, 3) 76 | return output 77 | 78 | def patches_to_image(patches, keypoints, out_shape): 79 | # patches [B, N, C, H, W] 80 | # boxes: [B, N, 4] 81 | # out_shape: (h, w) 82 | # returns -> [B, C, H, W] 83 | pass 84 | 85 | def meshgrid(out_shape): 86 | y, x = torch.meshgrid([torch.linspace(0, 1, steps=out_shape[-2]), torch.linspace(0, 1, steps=out_shape[-1])]) 87 | x, y = x.flatten(), y.flatten() 88 | grid = torch.stack([x, y, torch.ones_like(x)], dim=1) 89 | return grid 90 | 91 | 92 | 93 | 94 | if __name__ == '__main__': 95 | from scipy.misc import face 96 | import matplotlib.pyplot as plt 97 | from matplotlib.image import imread 98 | import numpy as np 99 | 100 | np_image = imread('dog.jpg') 101 | images = torch.from_numpy(np_image)[None, ...].type(torch.FloatTensor) 102 | images = images.permute(0, 3, 1, 2) 103 | 104 | print(images.shape) 105 | boxes = torch.tensor([[[-100, -100, 1000, 700], [512, 340, 100, 100]]], dtype=torch.float32) 106 | 107 | patches = spatial_transformer(images, boxes, out_shape=[300, 300]) 108 | patches = patches.permute(0, 1, 3, 4, 2) 109 | 110 | np_image = patches[0,1].numpy().astype(np.uint8) 111 | plt.imshow(np_image) 112 | plt.show() 113 | 114 | -------------------------------------------------------------------------------- /system/conda_mist.yaml: -------------------------------------------------------------------------------- 1 | name: mist 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - absl-py=0.11.0=py36h5fab9bb_0 10 | - aiohttp=3.7.3=py36h1d69622_0 11 | - async-timeout=3.0.1=py_1000 12 | - attrs=20.3.0=pyhd3deb0d_0 13 | - backcall=0.2.0=py_0 14 | - blas=1.0=mkl 15 | - blinker=1.4=py_1 16 | - blosc=1.20.1=hd408876_0 17 | - brotli=1.0.9=he6710b0_2 18 | - brotlipy=0.7.0=py36he6145b8_1001 19 | - bzip2=1.0.8=h7b6447c_0 20 | - c-ares=1.17.1=h36c2ea0_0 21 | - ca-certificates=2020.10.14=0 22 | - cachetools=2.0.0=py36_0 23 | - cairo=1.14.12=h8948797_3 24 | - certifi=2020.6.20=py36_0 25 | - cffi=1.14.4=py36h261ae71_0 26 | - chardet=3.0.4=py36h9880bd3_1008 27 | - charls=2.1.0=he6710b0_2 28 | - click=7.1.2=pyh9f0ad1d_0 29 | - cloudpickle=1.6.0=py_0 30 | - cryptography=3.3.1=py36h3c74f83_0 31 | - cudatoolkit=11.0.221=h6bb024c_0 32 | - cycler=0.10.0=py36_0 33 | - cytoolz=0.11.0=py36h7b6447c_0 34 | - dask-core=2020.12.0=pyhd3eb1b0_0 35 | - dataclasses=0.7=py36_0 36 | - dbus=1.13.18=hb2f20db_0 37 | - decorator=4.4.2=py_0 38 | - expat=2.2.10=he6710b0_2 39 | - ffmpeg=4.0=hcdf2ecd_0 40 | - fontconfig=2.13.0=h9420a91_0 41 | - freeglut=3.2.1=h58526e2_0 42 | - freetype=2.10.4=h5ab3b9f_0 43 | - giflib=5.1.4=h14c3975_1 44 | - glib=2.66.1=h92f7085_0 45 | - google-auth=1.24.0=pyhd3deb0d_0 46 | - google-auth-oauthlib=0.4.1=py_2 47 | - graphite2=1.3.13=h58526e2_1001 48 | - grpcio=1.33.2=py36he0f7d3b_2 49 | - gst-plugins-base=1.14.0=h8213a91_2 50 | - gstreamer=1.14.0=h28cd5cc_2 51 | - h5py=2.8.0=py36h989c5e5_3 52 | - harfbuzz=1.8.8=hffaf4a1_0 53 | - hdf5=1.10.2=hc401514_3 54 | - icu=58.2=he6710b0_3 55 | - idna=2.10=pyh9f0ad1d_0 56 | - idna_ssl=1.1.0=py36h9f0ad1d_1001 57 | - imagecodecs=2020.5.30=py36hfa7d478_2 58 | - imageio=2.9.0=py_0 59 | - importlib-metadata=3.4.0=py36h5fab9bb_0 60 | - intel-openmp=2020.2=254 61 | - ipython=7.16.1=py36h5ca1d4c_0 62 | - ipython_genutils=0.2.0=py36_0 63 | - jasper=2.0.14=h07fcdf6_1 64 | - jedi=0.17.2=py36_0 65 | - joblib=0.17.0=py_0 66 | - jpeg=9b=h024ee3a_2 67 | - jxrlib=1.1=h7b6447c_2 68 | - kiwisolver=1.3.0=py36h2531618_0 69 | - lcms2=2.11=h396b838_0 70 | - ld_impl_linux-64=2.33.1=h53a641e_7 71 | - libaec=1.0.4=he6710b0_1 72 | - libedit=3.1.20191231=h14c3975_1 73 | - libffi=3.3=he6710b0_2 74 | - libgcc-ng=9.1.0=hdf63c60_0 75 | - libgfortran=3.0.0=1 76 | - libgfortran-ng=7.3.0=hdf63c60_0 77 | - libglu=9.0.0=he1b5a44_1001 78 | - libopencv=3.4.2=hb342d67_1 79 | - libopus=1.3.1=h7b6447c_0 80 | - libpng=1.6.37=hbc83047_0 81 | - libprotobuf=3.13.0.1=hd408876_0 82 | - libstdcxx-ng=9.1.0=hdf63c60_0 83 | - libtiff=4.1.0=h2733197_1 84 | - libuuid=1.0.3=h1bed415_2 85 | - libuv=1.40.0=h7b6447c_0 86 | - libvpx=1.7.0=h439df22_0 87 | - libwebp=1.0.1=h8e7db2f_0 88 | - libxcb=1.14=h7b6447c_0 89 | - libxml2=2.9.10=hb55368b_3 90 | - libzopfli=1.0.3=he6710b0_0 91 | - lz4-c=1.9.2=heb0550a_3 92 | - markdown=3.3.3=pyh9f0ad1d_0 93 | - matplotlib=3.3.3=py36h5fab9bb_0 94 | - matplotlib-base=3.3.3=py36he12231b_0 95 | - mkl=2020.2=256 96 | - mkl-service=2.3.0=py36he8ac12f_0 97 | - mkl_fft=1.2.0=py36h23d657b_0 98 | - mkl_random=1.1.1=py36h0573a6f_0 99 | - multidict=5.1.0=py36h27cfd23_2 100 | - ncurses=6.2=he6710b0_1 101 | - networkx=2.5=py_0 102 | - ninja=1.10.2=py36hff7bd54_0 103 | - numpy=1.19.2=py36h54aff64_0 104 | - numpy-base=1.19.2=py36hfa32c7d_0 105 | - oauthlib=3.0.1=py_0 106 | - olefile=0.46=py36_0 107 | - opencv=3.4.2=py36h6fd60c2_1 108 | - openjpeg=2.3.0=h05c96fa_1 109 | - openssl=1.1.1k=h27cfd23_0 110 | - parso=0.7.0=py_0 111 | - pcre=8.44=he6710b0_0 112 | - pexpect=4.8.0=py36_0 113 | - pickleshare=0.7.5=py36_0 114 | - pillow=8.0.0=py36h9a89aac_0 115 | - pip=20.3.3=py36h06a4308_0 116 | - pixman=0.40.0=h36c2ea0_0 117 | - prompt-toolkit=3.0.8=py_0 118 | - protobuf=3.13.0.1=py36he6710b0_1 119 | - ptyprocess=0.6.0=py36_0 120 | - py-opencv=3.4.2=py36hb342d67_1 121 | - pyasn1=0.4.8=py_0 122 | - pyasn1-modules=0.2.7=py_0 123 | - pycparser=2.20=pyh9f0ad1d_2 124 | - pygments=2.7.1=py_0 125 | - pyjwt=2.0.1=pyhd8ed1ab_0 126 | - pyopenssl=20.0.1=pyhd8ed1ab_0 127 | - pyparsing=2.4.7=py_0 128 | - pyqt=5.9.2=py36h05f1152_2 129 | - pysocks=1.7.1=py36h5fab9bb_3 130 | - python=3.6.12=hcff3b4d_2 131 | - python-dateutil=2.8.1=py_0 132 | - python_abi=3.6=1_cp36m 133 | - pytorch=1.7.1=py3.6_cuda11.0.221_cudnn8.0.5_0 134 | - pywavelets=1.1.1=py36h7b6447c_2 135 | - pyyaml=5.3.1=py36h7b6447c_1 136 | - qt=5.9.7=h5867ecd_1 137 | - readline=8.0=h7b6447c_0 138 | - requests=2.25.1=pyhd3deb0d_0 139 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 140 | - rsa=3.4.2=py_1 141 | - scikit-image=0.17.2=py36hdf5156a_0 142 | - scikit-learn=0.23.2=py36h0573a6f_0 143 | - scipy=1.5.2=py36h0b6359f_0 144 | - setuptools=51.1.2=py36h06a4308_4 145 | - sip=4.19.8=py36hf484d3e_0 146 | - six=1.15.0=py36h06a4308_0 147 | - snappy=1.1.8=he6710b0_0 148 | - sqlite=3.33.0=h62c20be_0 149 | - tensorboard=2.3.0=pyh4dce500_0 150 | - tensorboard-plugin-wit=1.7.0=pyh9f0ad1d_0 151 | - tensorboardx=2.1=py_0 152 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 153 | - tifffile=2021.1.14=pyhd3eb1b0_1 154 | - tk=8.6.10=hbc83047_0 155 | - toolz=0.11.1=py_0 156 | - torchvision=0.8.2=py36_cu110 157 | - tornado=6.1=py36h27cfd23_0 158 | - tqdm=4.56.0=pyhd8ed1ab_0 159 | - traitlets=4.3.3=py36_0 160 | - typing-extensions=3.7.4.3=0 161 | - typing_extensions=3.7.4.3=py_0 162 | - urllib3=1.26.2=pyhd8ed1ab_0 163 | - wcwidth=0.2.5=py_0 164 | - werkzeug=1.0.1=pyh9f0ad1d_0 165 | - wheel=0.36.2=pyhd3eb1b0_0 166 | - xorg-fixesproto=5.0=h14c3975_1002 167 | - xorg-inputproto=2.3.2=h14c3975_1002 168 | - xorg-kbproto=1.0.7=h14c3975_1002 169 | - xorg-libx11=1.6.12=h516909a_0 170 | - xorg-libxau=1.0.9=h14c3975_0 171 | - xorg-libxext=1.3.4=h516909a_0 172 | - xorg-libxfixes=5.0.3=h516909a_1004 173 | - xorg-libxi=1.7.10=h516909a_0 174 | - xorg-xextproto=7.3.0=h14c3975_1002 175 | - xorg-xproto=7.0.31=h14c3975_1007 176 | - xz=5.2.5=h7b6447c_0 177 | - yaml=0.2.5=h7b6447c_0 178 | - yarl=1.6.3=py36h1d69622_0 179 | - zipp=3.4.0=py_0 180 | - zlib=1.2.11=h7b6447c_3 181 | - zstd=1.4.5=h9ceee32_0 182 | prefix: /scratch/miniconda3/envs/mist 183 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | # Partly from https://raw.githubusercontent.com/vcg-uvic/lf-net-release/master/common/argparse_utils.py 2 | 3 | import json 4 | import argparse 5 | import sys 6 | 7 | def str2bool(v): 8 | return v.lower() in ('true', '1', 'yes', 'y', 't') 9 | 10 | def get_mist_config(): 11 | parser = argparse.ArgumentParser() 12 | ## --- General Settings 13 | # Json Path 14 | parser.add_argument('--path_json', type=str, default='') 15 | # Use Seed or Not 16 | parser.add_argument('--set_seed', type=str2bool, default=True) 17 | # Run name and comment 18 | parser.add_argument('--name', type=str, default='test') 19 | parser.add_argument('--comment', type=str, default='') 20 | # Dataset base path 21 | parser.add_argument('--dataset_dir', type=str, default='/home/yjin/datasets') 22 | # Model base path 23 | parser.add_argument('--model_dir', type=str, default='./pretrained_models') 24 | # Choose dataset 25 | parser.add_argument('--dataset', type=str, default='mnist_hard') 26 | # Number of Epoch 27 | parser.add_argument('--epochs', type=int, default=2000) 28 | # Batch Size 29 | parser.add_argument('--batch_size', type=int, default=32) 30 | # Save Weight 31 | parser.add_argument('--save_weights',type=str2bool, default=False) 32 | # Resume 33 | parser.add_argument('--resume',type=str2bool, default=False) 34 | # Validation 35 | parser.add_argument('--run_val',type=str2bool, default=True) 36 | parser.add_argument('--val_path', type=str, default='./val_results') 37 | # Test 38 | parser.add_argument('--test_path', type=str, default='./test_results') 39 | # Data argumentation (only works on PASCAL dataset) 40 | parser.add_argument('--rand_horiz_flip', type=str2bool, default=False) 41 | parser.add_argument('--rand_maskout', type=str2bool, default=False) 42 | # Pretrained model path 43 | parser.add_argument('--pretrained_resnet',type=str2bool,default= True) 44 | # Tensorboard 45 | parser.add_argument('--summary_period', type=float, default=1) 46 | parser.add_argument('--save_period', type=float, default=5) 47 | parser.add_argument('--valid_period', type=float, default=5) 48 | parser.add_argument('--cooldown_period', type=float, default=-1) 49 | parser.add_argument('--log_dir', type=str, default='logs/') 50 | 51 | ## --- Training Settings 52 | # Detector Learning Rate 53 | parser.add_argument('--lr_detector', type=float, default=1e-4) 54 | # Key Points Learning Rate and Iteration 55 | parser.add_argument('--k_iter', type=int, default=2) 56 | parser.add_argument('--lr_k', type=float, default=1e4) 57 | # Classifier Learning Rate 58 | parser.add_argument('--lr_task', type=float, default=1e-4) 59 | # Reconstruction method (gaussian, single_point) 60 | parser.add_argument('--heatmap_reconstruct',type=str,default='single_point') 61 | ## Task network loss function 62 | parser.add_argument('--loss_type', type=str, default='MSE') 63 | 64 | ## --- Network Settings 65 | # Num of classes 66 | parser.add_argument('--num_classes', type=int, default=10) 67 | # Image size 68 | parser.add_argument('--image_size', type=int, default=80) 69 | 70 | ## --- Detector Settings 71 | # Detector Backbone VGG16 or ResNet 72 | parser.add_argument('--detector_backbone',type=str,default='CustomResNet') 73 | # Softmax kernal size to heatmap ratio 74 | parser.add_argument('--sm_kernal_size_ratio', type=float, default=0.2) 75 | # NMS kernal size to heatmap ratio 76 | parser.add_argument('--nms_kernal_size_ratio', type=float, default=0.05) 77 | # Number of Key Points 78 | parser.add_argument('--k', type=int, default=9) 79 | # Spatial Softmax 80 | parser.add_argument('--spatial_softmax', type=str2bool, default=True) 81 | parser.add_argument('--softmax_strength', type=float, default=10) 82 | # Sub pixel accuracy 83 | parser.add_argument('--sub_pixel_kp', type=str2bool, default=False) 84 | ## Bbox size 85 | parser.add_argument('--anchor_size', type=float, default=0.25) 86 | 87 | ## --- Classifier Settings 88 | # Patch extraction 89 | parser.add_argument('--patch_from_featuremap', type=str2bool, default=False) 90 | # Patch Size 91 | parser.add_argument('--patch_size', type=int, default=32) 92 | 93 | config = parser.parse_args() 94 | 95 | # overwrite with configs in json 96 | if config.path_json != '': 97 | with open(config.path_json) as f: 98 | params = json.load(f) 99 | for key,value in params.items(): 100 | setattr(config,key,value) 101 | 102 | return config 103 | 104 | def print_config(config): 105 | print('---------------------- CONFIG ----------------------') 106 | print() 107 | args = list(vars(config)) 108 | args.sort() 109 | for arg in args: 110 | print(arg.rjust(25,' ') + ' ' + str(getattr(config, arg))) 111 | print() 112 | print('----------------------------------------------------') 113 | 114 | def config_to_string(config): 115 | string = '\n\n' 116 | string += 'python ' + ' '.join(sys.argv) 117 | string += '\n\n' 118 | # string += '---------------------- CONFIG ----------------------\n' 119 | args = list(vars(config)) 120 | args.sort() 121 | for arg in args: 122 | string += arg.rjust(25,' ') + ' ' + str(getattr(config, arg)) + '\n\n' 123 | # string += '----------------------------------------------------\n' 124 | return string 125 | -------------------------------------------------------------------------------- /utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def one_hot_classification_loss(logits, labels_gt, num_classes, loss_type='MSE', power=2): 4 | diag = {} 5 | logits = torch.mean(logits, dim=1) 6 | logits_gt = torch.nn.functional.one_hot(labels_gt, num_classes).float() 7 | logits_gt = torch.mean(logits_gt, dim=1) 8 | loss = torch.nn.MSELoss(reduction='none')(logits, logits_gt) 9 | loss = loss.sum(dim=1) /2 10 | diag['loss_per_sample'] = loss.detach() 11 | return loss.mean(), diag 12 | 13 | -------------------------------------------------------------------------------- /utils/summary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | from tensorboardX import SummaryWriter 5 | from PIL import Image, ImageDraw, ImageFont 6 | import matplotlib.pyplot as plt 7 | import matplotlib 8 | from io import BytesIO 9 | 10 | def figure_to_numpy(figure, close=True): 11 | buf = BytesIO() 12 | figure.savefig(buf, format='png', bbox_inches='tight', transparent=False, dpi=300) 13 | if close: 14 | plt.close(figure) 15 | buf.seek(0) 16 | arr = matplotlib.image.imread(buf, format='png')[:,:,:3] 17 | arr = np.moveaxis(arr, source=2, destination=0) 18 | return arr 19 | 20 | class CustomSummaryWriter(SummaryWriter): 21 | # def __init__(self): 22 | # super(CustomSummaryWriter, self).__init__( 23 | 24 | def add_images_heatmap(self, name, images, heatmap, iteration): 25 | heatmap = self.draw_heatmap_on_images(images, heatmap) 26 | self.add_images(name, heatmap, iteration) 27 | 28 | def add_images(self, name, images, iteration, boxes_infer=None, boxes_gt= None, labels=None, resize=None, match=None): 29 | # images [B, C, H, W] 30 | max_images = min(images.shape[0],20) 31 | if len(images.shape) == 3: 32 | images = images.unsqueeze(0) 33 | if len(images.shape) == 4: 34 | images = images[0:max_images, ...] 35 | elif len(images.shape) == 5: 36 | images = images[0:max_images, ...] 37 | images = images.view(-1, images.shape[-3], images.shape[-2], images.shape[-1]) 38 | else: 39 | raise Exception('images.shape() {}'.format(images.shape)) 40 | 41 | if resize is not None: 42 | w = int(images.shape[-1] * resize) 43 | h = int(images.shape[-2] * resize) 44 | images = torch.nn.functional.interpolate(images, (h,w), mode='nearest') 45 | 46 | if (images.shape[1]!=3): 47 | images = torch.mean(images,dim=1).unsqueeze(1).repeat(1,3,1,1) 48 | images = ((images - images.min()) / (images.max() - images.min()) * 255.0).byte() 49 | # images = (images * 255.0).byte() 50 | if boxes_infer is not None: 51 | if resize is not None: 52 | boxes_infer = boxes_infer * resize 53 | images = self.draw_boxes_on_images(images, boxes_infer, labels, match) 54 | 55 | if boxes_gt is not None: 56 | if resize is not None: 57 | boxes_gt = boxes_gt * resize 58 | images = self.draw_boxes_on_images(images, boxes_gt) 59 | 60 | image = torchvision.utils.make_grid(images, nrow=max_images, padding=1, pad_value=255) 61 | self.add_image(name, image, iteration) 62 | 63 | 64 | def draw_boxes_on_images(self, images, boxes, labels=None, match=None): 65 | fnt = ImageFont.load_default() 66 | image_np = np.zeros(images.shape, dtype=np.uint8) 67 | for i in range(images.shape[0]): 68 | img = images[i,...].cpu().permute(1, 2, 0).numpy() 69 | H,W,_ = img.shape 70 | img = Image.fromarray(img, 'RGB') 71 | draw = ImageDraw.Draw(img) 72 | for j in range(boxes.shape[1]): 73 | kp = boxes[i,j,:].tolist() 74 | if match is not None: 75 | color = (0,255,0) if int(round(match[i,j,0].item())) else (255,0,0) 76 | draw.rectangle(kp, outline=color, fill=None) 77 | draw.text((kp[2]-20, kp[3]-20), '{}'.format(int(match[i,j,1].item()*100)), fill=color, font=fnt) 78 | if labels is not None: 79 | color = (0,255,0) if int(round(match[i,j,2].item())) else (255,0,0) 80 | draw.text((kp[0]+2, kp[1]), str(labels[i][j].item()), fill=color, font=fnt) 81 | draw.text((W-15*(j+1), H-15), str(labels[i][j].item()), fill=(100,100,0), font=fnt) 82 | else: 83 | color = (0,0, 255) 84 | draw.rectangle(kp, outline=color, fill=None) 85 | 86 | 87 | img = np.asarray(img) 88 | image_np[i,...] = np.transpose(img, (2, 0, 1)) 89 | return torch.from_numpy(image_np) 90 | 91 | def draw_heatmap_on_images(self, images, heatmap): 92 | plt.switch_backend('agg') 93 | # plt.set_cmap('jet') 94 | plt.set_cmap('jet') 95 | hmin=heatmap.min() 96 | hmax=heatmap.max() 97 | fig, axes = plt.subplots(1, images.shape[0], figsize=(16,2)) 98 | if images.shape[0] == 1: 99 | axes = np.array([axes]) 100 | for i in range(images.shape[0]): 101 | I = images[i,...].cpu().permute(1, 2, 0).numpy() 102 | H = heatmap[i,...].detach().cpu().numpy() 103 | ax = axes[i] 104 | ax.imshow(I) 105 | im = ax.imshow(H, alpha=1.0, vmin=hmin, vmax=hmax) 106 | 107 | for ax in axes: 108 | ax.set_xticks([]) 109 | ax.set_yticks([]) 110 | 111 | fig.subplots_adjust(wspace=0, hspace=0) 112 | plt.colorbar(im, ax=axes.ravel().tolist()) 113 | output = figure_to_numpy(fig) 114 | 115 | return torch.from_numpy(output) 116 | -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def to_gpu(x): 5 | if isinstance(x, torch.Tensor): 6 | return x.cuda() 7 | return tuple([t.cuda() for t in x]) 8 | 9 | def eval_accuracy (bboxes, bboxes_gt, labels, labels_gt, num_class, background_class=True): 10 | correct_detection = torch.zeros(labels.shape) 11 | # K [B, N, (x,y,w,h)] 12 | # labels [B, N] 13 | output = {} 14 | 15 | # total number of bbox in a batch 16 | N = (bboxes.shape[0] * bboxes.shape[1]) 17 | 18 | # min IOU used for evaluation 19 | minIoU = 0.5 20 | 21 | # compute IoU between each keypoint and keypoint_gt 22 | bbox_min = (bboxes[:,:,0:2] - bboxes[:,:,2:4] / 2).unsqueeze(2) # [B, nk, 1, 2] 23 | bbox_max = (bboxes[:,:,0:2] + bboxes[:,:,2:4] / 2).unsqueeze(2) 24 | bbox_gt_min = (bboxes_gt[:,:,0:2] - bboxes_gt[:,:,2:4] / 2).unsqueeze(1) # [B, 1, nkg, 2] 25 | bbox_gt_max = (bboxes_gt[:,:,0:2] + bboxes_gt[:,:,2:4] / 2).unsqueeze(1) 26 | 27 | botleft = torch.max(bbox_min, bbox_gt_min) 28 | topright = torch.min(bbox_max, bbox_gt_max) 29 | 30 | inter = torch.prod(torch.nn.functional.relu(topright - botleft), dim=3) 31 | area_bbox = torch.prod(bbox_max - bbox_min, dim=3) 32 | area_bbox_gt = torch.prod(bbox_gt_max - bbox_gt_min, dim=3) 33 | union = area_bbox + area_bbox_gt - inter 34 | iou = inter / union # [B, k, kg, 1] 35 | iou[iou != iou] = 0 36 | 37 | # set iou of background class to 0 38 | if background_class: 39 | iou = (labels_gt!=0).unsqueeze(1).type(torch.float32)*(labels!=0).unsqueeze(-1).type(torch.float32) *iou 40 | 41 | # total number of objects in batch 42 | if background_class: 43 | num_objects = (labels_gt!=0).sum().item() 44 | num_detetcions = (labels!=0).sum().item() 45 | else: 46 | num_objects = N 47 | num_detetcions = N 48 | 49 | # generate label for visualization 50 | match_det = (torch.max(iou, dim=2)[0] > minIoU) 51 | selected_gt = torch.gather(labels_gt, dim=1, index=torch.max(iou, dim=2)[1]) 52 | match_class = torch.eq(labels, selected_gt) 53 | output['keypoint_match_detection'] = torch.stack([match_det.float(), torch.max(iou, dim=2)[0],match_class.float()], dim=2) 54 | 55 | # compute detection accuracy 56 | acc_iou = ((torch.max(iou, dim=1)[0] > minIoU)).sum().float() / num_objects 57 | output['acc_det'] = acc_iou 58 | 59 | # prepare iou matrix for computing detection and classification accuracy 60 | labels_match = torch.eq(labels.unsqueeze(2), labels_gt.unsqueeze(1)) # [B, k, kg] 61 | iou = iou * labels_match.float() 62 | 63 | 64 | # prepare distance matrix for computing detection and classification accuracy (dist>1 outside gt bbox) 65 | dist = torch.max(torch.abs(bboxes[:,:,0:2].unsqueeze(2) - bboxes_gt[:,:,0:2].unsqueeze(1))/(bboxes_gt[:,:,2:4].unsqueeze(1)/2),dim=-1)[0] 66 | 67 | # replace invalid value in dist matrix to 2 so it won't be chosen 68 | dist[dist!=dist] =2 69 | dist[dist ==float('inf')] = 2 70 | 71 | # if label do not match, replace to 2 72 | dist[~labels_match] = 2 73 | 74 | # replace dist to background class to 2 75 | if background_class: 76 | dist[(labels_gt==0).unsqueeze(1).repeat([1,labels.shape[1],1])] = 2 77 | 78 | 79 | # compute precision and recall 80 | tp_iou = 0 81 | tp_center = 0 82 | for b in range(bboxes_gt.shape[0]): 83 | for k in range(bboxes_gt.shape[1]): 84 | val, idx = torch.max(iou[b,:,k], dim=0) 85 | if val >= minIoU: 86 | iou[b, idx, :] = 0.0 87 | tp_iou += 1 88 | 89 | val, idx = torch.min(dist[b,:,k], dim=0) 90 | if val <= 1: 91 | dist[b, idx, :] = 2 92 | tp_center += 1 93 | correct_detection[b,idx] = 1 94 | if num_detetcions ==0: 95 | precision_iou = 0 96 | precision_center= 0 97 | else: 98 | precision_iou = tp_iou / num_detetcions 99 | precision_center= tp_center / num_detetcions 100 | recall_iou = tp_iou/ num_objects 101 | recall_center = tp_center/ num_objects 102 | if recall_iou == 0 and precision_iou == 0: 103 | f1_iou = 0 104 | else: 105 | f1_iou = 2 * precision_iou * recall_iou / (precision_iou+recall_iou) 106 | if recall_center ==0 and precision_iou ==0: 107 | f1_center = 0 108 | else: 109 | f1_center = 2 * precision_center * recall_center / (precision_center+recall_center) 110 | 111 | 112 | output['f1_iou'] = f1_iou 113 | output['f1_center'] = f1_center 114 | output['precision_iou'] = precision_iou 115 | output['precision_center'] = precision_center 116 | output['recall_iou'] = recall_iou 117 | output['recall_center'] = recall_center 118 | 119 | output['tp_center'] = tp_center 120 | output['tp_iou'] = tp_iou 121 | output['num_objects'] = num_objects 122 | output['num_detetcions'] = num_detetcions 123 | 124 | output['correct_detection'] = correct_detection 125 | 126 | # compute pure classification accuracy 127 | dt_1h = torch.nn.functional.one_hot(labels, num_class).sum(dim=1)[:,1:] 128 | gt_1h = torch.nn.functional.one_hot(labels_gt, num_class).sum(dim=1)[:,1:] 129 | acc_class_all_detect = 1.0 - torch.relu(gt_1h-dt_1h).sum().float()/num_objects 130 | 131 | output['acc_class'] = acc_class_all_detect 132 | 133 | return output 134 | 135 | 136 | def gaussian(size, std=0.5): 137 | y, x = torch.meshgrid([torch.linspace(0, 1, steps=size[0]), torch.linspace(0, 1, steps=size[1])]) 138 | x = 2 * (x - 0.5) 139 | y = 2 * (y - 0.5) 140 | g = (x * x + y * y) / (2 * std * std) 141 | g = torch.exp(-g) 142 | g = g / (std * math.sqrt(2 * math.pi)) 143 | return g 144 | 145 | def gaussian2(size, center=None, std=0.5): 146 | if center is None: 147 | center = torch.tensor([[0.5, 0.5]]) 148 | 149 | y, x = torch.meshgrid([torch.linspace(0, 1, steps=size[0]), torch.linspace(0, 1, steps=size[1])]) 150 | # print(x.unsqueeze(0).shape, .shape) 151 | x = 2 * (x.unsqueeze(0) - center[:,0,None,None]) 152 | y = 2 * (y.unsqueeze(0) - center[:,1,None,None]) 153 | g = (x * x + y * y) / (2 * std * std) 154 | g = torch.exp(-g) 155 | return g 156 | 157 | 158 | def circle_mask(size, center=None, radius=0.5): 159 | if center is None: 160 | center = torch.tensor([[0.5, 0.5]]) 161 | 162 | y, x = torch.meshgrid([torch.linspace(0, 1, steps=size[0]), torch.linspace(0, 1, steps=size[1])]) 163 | # print(x.unsqueeze(0).shape, .shape) 164 | x = 2 * (x.unsqueeze(0) - center[:,0,None,None]) 165 | y = 2 * (y.unsqueeze(0) - center[:,1,None,None]) 166 | d = (x * x + y * y) < (radius * radius) 167 | return d.float() 168 | 169 | def half_mask(shape): 170 | angle = torch.rand((shape[0], 1, 1)) * math.pi * 2 171 | y, x = torch.meshgrid([torch.linspace(-1, 1, steps=shape[-2]), torch.linspace(-1, 1, steps=shape[-1])]) 172 | x = x.unsqueeze(0) 173 | y = y.unsqueeze(0) 174 | nx = torch.cos(angle) 175 | d = x * torch.cos(angle) + y * torch.sin(angle) 176 | mask = (d > 0).float() 177 | return mask 178 | 179 | def square_mask(shape, size = 0.5): 180 | y, x = torch.meshgrid([torch.linspace(-1, 1, steps=shape[-2]), torch.linspace(-1, 1, steps=shape[-1])]) 181 | d = torch.max(torch.abs(x), torch.abs(y)) 182 | mask = (d < size).float() 183 | return mask 184 | 185 | def calc_center_of_mass(heatmap,kernel_size): 186 | heatmap_exp = torch.exp(heatmap) 187 | heatmap_unf = torch.nn.functional.unfold(heatmap_exp, (kernel_size, kernel_size),padding = kernel_size//2).transpose(1,2) 188 | w_x = (torch.arange(kernel_size)-kernel_size//2).unsqueeze(0).expand(kernel_size,-1).reshape(-1,1).float() 189 | w_y = (torch.arange(kernel_size)-kernel_size//2).unsqueeze(1).expand(-1,kernel_size).reshape(-1,1).float() 190 | w_s = torch.ones(kernel_size,kernel_size).reshape(-1,1).float() 191 | heatmap_unf_x = heatmap_unf.matmul(w_x) 192 | heatmap_unf_y = heatmap_unf.matmul(w_y) 193 | heatmap_unf_s = heatmap_unf.matmul(w_s) 194 | offset_unf = torch.cat([heatmap_unf_x/heatmap_unf_s, heatmap_unf_y/heatmap_unf_s],dim=-1).transpose(1, 2) 195 | offset = torch.nn.functional.fold(offset_unf, (heatmap.shape[2], heatmap.shape[3]), (1, 1)) 196 | grid_x = torch.arange(heatmap.shape[3]).unsqueeze(0).expand(heatmap.shape[2],-1).float() 197 | grid_y = torch.arange(heatmap.shape[2]).unsqueeze(1).expand(-1,heatmap.shape[3]).float() 198 | grid_xy = torch.cat([grid_x.unsqueeze(0),grid_y.unsqueeze(0)],dim=0) 199 | center = grid_xy+offset 200 | return center 201 | def inverse_heatmap(keypoints, out_shape): 202 | heatmap = torch.zeros(out_shape) 203 | batch = torch.arange(keypoints.shape[0]).repeat(keypoints.shape[1], 1).permute(1, 0) 204 | x = keypoints[:, :, 0] 205 | y = keypoints[:, :, 1] 206 | x = torch.clamp((x + 0.5).long(), 0, out_shape[-1] - 1) 207 | y = torch.clamp((y + 0.5).long(), 0, out_shape[-2] - 1) 208 | heatmap[batch, 0, y, x] = 1.0 209 | return heatmap.detach() 210 | 211 | def inverse_heatmap_gaussian(bboxes, out_shape, var_scale=0.125): 212 | # out_shape (B, 1, H, W) 213 | B, _, H, W = out_shape 214 | bboxes = torch.clamp(bboxes, 0.000001) 215 | index_x = torch.arange(W).repeat(H) 216 | index_y = torch.arange(H).unsqueeze(-1).repeat(1,W).reshape(-1) 217 | index = torch.cat((index_x.unsqueeze(-1),index_y.unsqueeze(-1)),dim=-1).float() 218 | exp_term = torch.matmul(torch.pow(((bboxes[:,:,:2].unsqueeze(-2)-index)/(bboxes[:,:,2:]*var_scale).unsqueeze(-2)),2),torch.tensor([[0.5],[0.5]])).squeeze() 219 | norm = torch.exp(-exp_term)#/(bboxes[:,:,[2]]*var_scale*bboxes[:,:,[3]]*var_scale)/2/math.pi 220 | heatmap = torch.sum(norm,dim=1).reshape(out_shape) 221 | 222 | return heatmap.detach() 223 | 224 | def construct_dist_mat(kp_1, kp_2): 225 | # distance square matrix between two sets of points 226 | # kp_1, kp_2 [B,N,2] 227 | xy_1_sq_sum_vec = torch.matmul(kp_1**2,torch.ones(2,1)) 228 | xy_2_sq_sum_vec = torch.matmul(kp_2**2,torch.ones(2,1)) 229 | # row: kp_1 column: kp_2 230 | xy_12_sq_sum_mat = xy_1_sq_sum_vec + xy_2_sq_sum_vec.transpose(-1,-2) 231 | xy_mat = torch.matmul(kp_1, kp_2.transpose(-1,-2)) 232 | dist_mat = xy_12_sq_sum_mat - 2*xy_mat 233 | dist_mat = torch.max(dist_mat,torch.zeros_like(dist_mat)) 234 | return dist_mat 235 | 236 | def xys_to_xywh(boxes): 237 | return torch.cat([boxes, boxes[...,2,None]], dim=-1) 238 | 239 | def xyxy_to_xywh(boxes): 240 | wh = (boxes[:,:,2:4] - boxes[:,:,0:2]) 241 | center = (boxes[:,:,0:2] +wh/ 2) 242 | return torch.cat([center, wh], dim=2) 243 | 244 | def xywh_to_xyxy(boxes): 245 | K_min = (boxes[:,:,0:2] - boxes[:,:,2:4] / 2) 246 | K_max = (boxes[:,:,0:2] + boxes[:,:,2:4] / 2) 247 | return torch.cat([K_min, K_max], dim=2) 248 | 249 | def xys_to_xyxy(boxes): 250 | return xywh_to_xyxy(xys_to_xywh(boxes)) 251 | 252 | def scale_keypoints(boxes, scale=1.0): 253 | # boxes [B, N, 3/4] (xys or xywh) 254 | return torch.cat((boxes[:,:,:2], boxes[:,:,2:] * scale), dim=2) 255 | 256 | if __name__ == '__main__': 257 | eps = 1e-5 258 | # test center of mass 259 | print("Testing clac_center_of_mass ...") 260 | heatmap = heatmap = torch.randn(32, 1, 5, 5) 261 | kernel_size = 3 262 | center = calc_center_of_mass(heatmap, kernel_size) 263 | # check shape 264 | if heatmap.shape[0] != center.shape[0] or heatmap.shape[2] != center.shape[2] or heatmap.shape[3] != center.shape[3]: 265 | raise Exception("output shape of calc_center_of_mass is different from input shape") 266 | # check calculation 267 | heatmap_exp = torch.exp(heatmap) 268 | c_x = 1+(-torch.sum(heatmap_exp[0,0,0:3,0])+torch.sum(heatmap_exp[0,0,0:3,2]))/torch.sum(heatmap_exp[0,0,0:3,0:3]) 269 | c_y = 1+(-torch.sum(heatmap_exp[0,0,0,0:3])+torch.sum(heatmap_exp[0,0,2,0:3]))/torch.sum(heatmap_exp[0,0,0:3,0:3]) 270 | if torch.abs(c_x - center[0,0,1,1])>eps or torch.abs(c_y - center[0,1,1,1])>eps: 271 | raise Exception("calc_center_of_mass output wrong result") 272 | print("Pass") 273 | 274 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | def get_time_stamp(): 4 | dateTimeObj = datetime.now() 5 | time_stamp = '{:02d}'.format(dateTimeObj.year%100) + \ 6 | '{:02d}'.format(dateTimeObj.month%100) + \ 7 | '{:02d}'.format(dateTimeObj.day%100) + \ 8 | '-' + \ 9 | '{:02d}'.format(dateTimeObj.hour%100) + \ 10 | '{:02d}'.format(dateTimeObj.minute%100) 11 | return time_stamp 12 | 13 | class tensorboard_scheduler(): 14 | def __init__(self, eval_interval, save_interval, valid_interval,stop_time=-1): 15 | self.eval_interval = eval_interval 16 | self.save_interval = save_interval 17 | self.valid_interval = valid_interval 18 | self.start_time = datetime.now() 19 | self.eval_counter = 0 20 | self.save_counter = 0 21 | self.valid_counter = 0 22 | self.stop_time = stop_time 23 | def schedule(self): 24 | 25 | delta_secs = (datetime.now() - self.start_time).total_seconds() 26 | delta_mins = delta_secs/60 27 | 28 | if self.stop_time!=-1 and delta_mins>self.stop_time: 29 | return False, False, False 30 | 31 | if delta_mins > self.eval_counter*self.eval_interval: 32 | eval_flag = True 33 | self.eval_counter = int(delta_mins/self.eval_interval) + 1 34 | else: 35 | eval_flag =False 36 | 37 | if delta_mins > self.save_counter*self.save_interval: 38 | save_flag = True 39 | self.save_counter = int(delta_mins/self.save_interval) + 1 40 | else: 41 | save_flag = False 42 | 43 | if delta_mins > self.valid_counter*self.valid_interval: 44 | valid_flag = True 45 | self.valid_counter = int(delta_mins/self.valid_interval) + 1 46 | else: 47 | valid_flag = False 48 | 49 | return eval_flag, save_flag, valid_flag 50 | def get_delta_time(self): 51 | delta_secs = (datetime.now() - self.start_time).total_seconds() 52 | return delta_secs 53 | 54 | 55 | if __name__ == '__main__': 56 | # test scheduler 57 | import time 58 | import random 59 | scheduler = tensorboard_scheduler(2/60, 10/60, 20/60, 2) 60 | 61 | step = 0.4 62 | cur_time = 0 63 | for i in range(200): 64 | eval_flag, save_flag, valid_flag = scheduler.schedule() 65 | if eval_flag: 66 | print('eval at time {}'.format(scheduler.get_delta_time())) 67 | time.sleep(step/4) 68 | 69 | if save_flag: 70 | print('save at time {}'.format(scheduler.get_delta_time())) 71 | time.sleep(step/2) 72 | 73 | if valid_flag: 74 | print('valid at time {}'.format(scheduler.get_delta_time())) 75 | time.sleep(step*10) 76 | time.sleep(step+step*2*random.random()) 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /utils/viz_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image, ImageDraw, ImageFont 4 | 5 | def save_bbox_images(image, bbox, label, name, path, background_class): 6 | image = image.cpu().permute(1, 2, 0).numpy() 7 | image = draw_boxes_on_images(image, bbox, label, background_class) 8 | image.save(os.path.join(path,'{}.jpg'.format(name)),quality=95) 9 | 10 | def draw_boxes_on_images(image, bbox, label, background_class): 11 | scale = 224/max(image.shape) 12 | H = int(scale*image.shape[0]) 13 | W = int(scale*image.shape[1]) 14 | bbox = bbox*scale 15 | fnt = ImageFont.load_default() 16 | image = Image.fromarray(np.uint8(image*255), 'RGB') 17 | image = image.resize((W, H)) 18 | draw = ImageDraw.Draw(image) 19 | for i in range(bbox.shape[0]): 20 | if background_class and label[i]=='__background__': 21 | continue 22 | kp = bbox[i,:].tolist() 23 | color = (0,0, 255) 24 | draw.rectangle(kp, outline=color, fill=None) 25 | color = (0,255,0) 26 | if kp[0]<0: 27 | x_0 = kp[2] - len(label[i])*8 28 | else: 29 | x_0 = kp[0] + 1 30 | if kp[1]<0: 31 | x_1 = kp[3] - 12 32 | else: 33 | x_1 = kp[1] + 1 34 | draw.text((x_0, x_1), label[i], width=12, fill=color, font=fnt) 35 | return image --------------------------------------------------------------------------------