├── LICENSE
├── NOTICE
├── README.md
├── dataset
│   ├── NovelObjects__test.json
│   ├── NovelObjects__train.json
│   ├── NovelSpaces__test.json
│   └── NovelSpaces__train.json
├── inference.png
├── source
│   ├── config.py
│   ├── eval.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── clustering_losses.py
│   │   ├── fgbg_losses.py
│   │   ├── loss_utils.py
│   │   ├── losses.py
│   │   └── rpn_losses.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbones.py
│   │   ├── clustering_models.py
│   │   ├── fgbg_model.py
│   │   ├── model.py
│   │   ├── model_utils.py
│   │   ├── poke_rcnn.py
│   │   └── rpn_models.py
│   ├── pipeline
│   │   ├── actor.py
│   │   ├── evaluator.py
│   │   ├── tester.py
│   │   └── trainer.py
│   ├── replay_memory
│   │   ├── __init__.py
│   │   ├── replay_memory.py
│   │   ├── replay_pil.py
│   │   └── replay_tensor.py
│   ├── requirements.txt
│   ├── startx.py
│   ├── tools
│   │   ├── coco_tools.py
│   │   ├── data_utils.py
│   │   ├── dist_training_tools.py
│   │   └── logger.py
│   └── train.py
├── trained_model_novel_objects
│   └── clustering_model_weights_900.pth
├── trained_model_novel_spaces
│   └── clustering_model_weights_900.pth
└── training_video.gif
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof.
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | We use code from the COCO API project 2 | 3 | ======================================================================= 4 | License for cocoapi 5 | ======================================================================= 6 | 7 | Copyright (c) 2014, Piotr Dollar and Tsung-Yi Lin 8 | All rights reserved. 
9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are met: 12 | 13 | 1. Redistributions of source code must retain the above copyright notice, this 14 | list of conditions and the following disclaimer. 15 | 2. Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 26 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | The views and conclusions contained in the software and documentation are those 31 | of the authors and should not be interpreted as representing official policies, 32 | either expressed or implied, of the FreeBSD Project. 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Learning About Objects by Learning to Interact with Them](https://arxiv.org/pdf/2006.09306.pdf) 2 | 3 | By Martin Lohmann, Jordi Salvador, Aniruddha Kembhavi, and Roozbeh Mottaghi 4 | 5 | We present a computational framework to discover objects and learn their physical properties following the paradigm of 6 | Learning from Interaction. Our agent, when placed within the AI2-THOR environment, interacts with its world by 7 | applying forces, and uses the resulting raw visual changes to learn instance segmentation and relative mass estimation 8 | of interactable objects, without access to ground truth labels or external guidance. Our agent learns efficiently and 9 | effectively, not just for objects it has interacted with before, but also for novel instances from seen categories 10 | as well as novel object categories. 11 | 12 | ![](inference.png) 13 | 14 | ## Citing 15 | ``` 16 | @inproceedings{ml2020learnfromint, 17 | author = {Lohmann, Martin and Salvador, Jordi and Kembhavi, Aniruddha and Mottaghi, Roozbeh}, 18 | title = {Learning About Objects by Learning to Interact with Them}, 19 | booktitle = {NeurIPS}, 20 | year = {2020} 21 | } 22 | ``` 23 | 24 | ## Setup 25 | 26 | #### Requirements 27 | 28 | This code has been developed and tested on Ubuntu 16.04.4 LTS. We assume xserver-xorg is installed, and that 29 | CUDA drivers are available, along with at least two compute devices (with 12 GB of memory each) for training or one device 30 | for evaluation. We use `python3.6`. 31 | 32 | #### Structure 33 | 34 | The following subfolders are available: 35 | - `dataset`, containing training and test datasets for both `NovelSpaces` and 36 | `NovelObjects` scenarios.
37 | - `source`, containing the training and eval scripts, as well as all the supporting classes organized in several subfolders and a 38 | `requirements.txt` file. 39 | - `trained_model_novel_spaces` and `trained_model_novel_objects`, containing the trained model weights used for the 40 | results reported in the manuscript for the corresponding datasets. 41 | 42 | #### Environment 43 | 44 | We recommend creating a virtual environment with `python3.6` to run the code. For example, from the top 45 | level folder, we can run 46 | - `python3.6 -mvenv learnfromint` or 47 | - `virtualenv --python=python3.6 learnfromint`. 48 | 49 | Then, we can activate it by `source learnfromint/bin/activate`. 50 | 51 | In order to install all python requirements, we can: 52 | - `cd source`, 53 | - `cat requirements.txt | xargs -n 1 -L 1 pip3 install`, and 54 | - `python3.6 -c "import ai2thor.controller; ai2thor.controller.Controller(download_only=True)"`, which will download 55 | the required binaries for AI2-THOR. 56 | 57 | ## Running the code 58 | 59 | #### Training 60 | 61 | If xorg is installed but not already running, we provide a utility script that must be run as root: 62 | - `sudo python3.6 startx.py &> ~/logxserver &` 63 | 64 | Then, we can run the training script from the `source` folder: 65 | - `python3.6 train.py [output_folder] ../dataset 0` for `NovelObjects`, or 66 | - `python3.6 train.py [output_folder] ../dataset 1` for `NovelSpaces`. 67 | 68 | Note that, depending on the compute capabilities of the machine, training can take on the order of 2 days to 69 | complete. 70 | 71 | #### Evaluation 72 | 73 | Again, make sure xorg is running or run `sudo python3.6 startx.py &> ~/logxserver &`. 74 | 75 | Then, we can, for example, run evaluation of the pretrained models from the `source` folder: 76 | - `python3.6 eval.py ../trained_model_novel_objects ../dataset 0 &> ../log_eval0 &` for `NovelObjects`, or 77 | - `python3.6 eval.py ../trained_model_novel_spaces ../dataset 1 &> ../log_eval1 &` for `NovelSpaces` 78 | 79 | and track the results by e.g. 80 | - `tail -f ../log_eval0` 81 | 82 | In order to access a summary of the results, once the evaluation is complete, we can just run 83 | - `cat ../log_eval0 | grep RESULTS` 84 | 85 | ## Evaluation 86 | 87 | #### About stochasticity 88 | Even though our model does not require interaction at test time, to minimize storage space and data downloads, we 89 | provide our evaluation dataset in this release as AI2-THOR controller states. Some minor stochasticity is 90 | involved when the controller renders these states into model inputs (images) and ground truth labels. For this reason, 91 | the evaluation metrics for a model checkpoint can fluctuate slightly.
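One way to account for this variance is to run the evaluation a few times and average the reported metrics. The snippet below is a minimal sketch (not part of the repository): it assumes each run has produced a `*_results.json` file as written by `eval.py` into the model folder's `inference` subdirectory, and the `average_results` helper and glob pattern are illustrative names.

```python
# Hypothetical helper: average metrics over several independent eval runs to
# smooth out the rendering stochasticity described above.
import glob
import json
from collections import defaultdict


def average_results(pattern="../trained_model_novel_objects/inference/*_results.json"):
    sums, counts = defaultdict(float), defaultdict(int)
    for path in glob.glob(pattern):
        with open(path) as f:
            for metric, value in json.load(f).items():
                sums[metric] += value
                counts[metric] += 1
    # Mean of every metric across all matched result files
    return {metric: sums[metric] / counts[metric] for metric in sums}


if __name__ == "__main__":
    for metric, value in sorted(average_results().items()):
        print(f"{metric}: {value:.2f}")
```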
92 | 93 | #### Results 94 | The results obtained by the `eval.py` script should fluctuate around the following values: 95 | 96 | | Dataset | BBox AP50 | BBox AP | Segm AP50 | Segm AP | Mass+BBox AP50 | Mass mean per-class accuracy | 97 | |:------------|:---------:|:-------:|:---------:|:-------:|:--------------:|:----------------------------:| 98 | | NovelObjects| 24.19 | 11.65 | 22.00 | 10.24 | 11.85 | 50.79 | 99 | | NovelSpaces | 27.59 | 13.44 | 25.01 | 11.00 | 11.01 | 55.86 | 100 | 101 | #### Qualitative example 102 | Here is an illustration of the different ingredients of our training process: 103 | 104 | ![](training_video.gif) 105 | -------------------------------------------------------------------------------- /inference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/learning_from_interaction/a266bc16d682832aa854348fa557a30d86b84674/inference.png -------------------------------------------------------------------------------- /source/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class GlobalConfig: 6 | def __init__(self): 7 | self.resolution = 300 8 | self.grid_size = 100 9 | self.stride = self.resolution // self.grid_size 10 | self.depth = True 11 | self.respawn_until_object = False 12 | self.superpixels = True 13 | self.use_of = False 14 | self.model_gpu = 0 15 | self.actor_gpu = 1 16 | self.of_gpu = 7 17 | self.max_pokes = 32 # This limit is required to support one hot encoding of poking masks 18 | self.val_scenes = 1 19 | self.distributed = False 20 | self.correct_depth = True 21 | 22 | 23 | global_config = GlobalConfig() 24 | 25 | 26 | class ActorConfig: 27 | def __init__(self): 28 | # configs for self-supervision module 29 | self.video_mode = global_config.use_of 30 | self.raw_feedback = False 31 | self.superpixel_postprocessed_feedback = True 32 | self.superpixel_postprocessing_threshold = .25 33 | self.connectedness_postprocessed_feedback = False 34 | self.fatten = False 35 | self.pixel_change_threshold = .01 36 | self.hsv = True 37 | self.check_change_kernel = np.array([[np.exp(-np.sqrt((x - 2) ** 2 + (y - 2) ** 2) / 1) 38 | for x in range(5)] for y in range(5)]) 39 | 40 | # configs for videoPCA 41 | self.pca = False and self.video_mode 42 | self.num_pca_components = 4 43 | self.smooth_mask = False and self.pca 44 | kernel = torch.tensor([[np.exp(-np.sqrt((x - 7) ** 2 + (y - 7) ** 2) / 2 / 5 ** 2) 45 | for x in range(15)] 46 | for y in range(15)]).unsqueeze(0).unsqueeze(1) 47 | self.pca_smoothing_kernel = kernel / kernel.sum() 48 | centering_kernel = torch.tensor([[np.exp(-np.sqrt((x - 20) ** 2 + (y - 20) ** 2) / 2 / 5 ** 2) 49 | for x in range(41)] 50 | for y in range(41)]).unsqueeze(0).unsqueeze(1) 51 | self.pca_centering_kernel = centering_kernel 52 | self.soft_mask_threshold = .005 53 | self.colres1 = 100 # 100 54 | self.colres2 = 25 # 25 55 | self.colres3 = 10 # 10 56 | self.num_color_bins = np.arange(self.colres1 * self.colres2 * self.colres3) + 1 57 | self.hyst_thresholds = (.8, .5) 58 | 59 | # configs for interaction 60 | self.instance_only = False 61 | self.force = 250 62 | self.force_buckets = [5, 30, 200] 63 | self.scaleable = True 64 | self.handDistance = 1.5 65 | self.visibilityDistance = 10 66 | self.max_poke_attempts = 3 67 | self.max_poke_keep = 1 68 | self.remove_after_poke = False 69 | 70 | # The following attributes filter the objects counted in the ground truth 71 | 
self.mass_buckets = [.5, 2.] 72 | self.mass_threshold = 150 73 | self.max_pixel_threshold = 300 ** 2 74 | self.min_pixel_threshold = 10 75 | self.data_files = ['unary_dataset__detectron2__60_30_30__train.json', 76 | 'unary_dataset__detectron2__60_30_30__valid.json'] 77 | self.use_dataset = True 78 | 79 | 80 | actor_config = ActorConfig() 81 | 82 | 83 | class BackboneConfig: 84 | def __init__(self): 85 | self.small = False 86 | 87 | 88 | class ModelConfigFgBg(BackboneConfig): 89 | def __init__(self): 90 | super(ModelConfigFgBg, self).__init__() 91 | self.uncertainty = False 92 | self.superpixel = False 93 | self.fatten = True 94 | 95 | 96 | class ClusteringModelConfig(BackboneConfig): 97 | def __init__(self): 98 | super(ClusteringModelConfig, self).__init__() 99 | self.backbone = 'unet' # unet or r50fpn 100 | self.out_dim = 16 101 | self.max_masks = global_config.max_pokes 102 | self.overlapping_objects = False 103 | self.filter = False 104 | self.uncertainty = int(False) 105 | self.distance_function = 'L2' # L2 or Cosine 106 | self.threshold = 1 # 1./.9 for L2/Cosine 107 | self.margin_threshold = (1, 1) 108 | self.reset_value = 10000 # 10000 / 0 for L2 / Cosine 109 | self.use_coordinate_embeddings = True 110 | self.freeze = False 111 | 112 | 113 | class ROIModuleConfig: 114 | def __init__(self): 115 | self.boxes = np.array([[0, 0, 9, 9], 116 | [0, 10, 9, 9], 117 | [10, 0, 9, 9], 118 | [10, 10, 9, 9], 119 | [4, 4, 9, 9], 120 | [-4, -4, 9, 9], 121 | [-4, 4, 9, 9], 122 | [4, -4, 9, 9], 123 | [0, 0, 19, 19], 124 | [0, 10, 19, 19], 125 | [10, 0, 19, 19], 126 | [10, 10, 19, 19], 127 | [0, 4, 19, 9], 128 | [4, 0, 9, 19], 129 | [-9, -4, 19, 9], 130 | [-4, -9, 9, 19], 131 | [0, 0, 39, 39]]) # offset_x, offset_y, delta_x, delta_y 132 | ''' 133 | base stride is 60 pixels = 20 grid cells 134 | 5 30x30 boxes / cell inside the base grid 135 | 3 30x30 boxes / cell half way between neighbouring grid cells 136 | 4 60x60 boxes / cell (1 centered, 3 half way) 137 | 2 30x60 box 138 | 2 60x30 box 139 | 1 120x120 box 140 | 141 | IoU thresholds for small cells should be tighter than for large ones 142 | 143 | ulc = upper left corners are at 0, 20, 40, 60, 80 144 | views will be [ulc + offset, ulc + offset + delta + 1] 145 | ''' 146 | self.positive_thresholds = [.35] * 8 + [.25] * 4 + [.3] * 4 + [.2] 147 | self.negative_thresholds = [.2] * 8 + [.15] * 4 + [.2] * 4 + [.1] 148 | self.num_anchors = len(self.boxes) 149 | self.num_rois = 16 150 | self.coarse_grid_size = 5 151 | self.poking_filter_threshold = 2.5 152 | self.nms_threshold = .4 153 | 154 | 155 | class RPNModelConfig(BackboneConfig): 156 | def __init__(self): 157 | super(RPNModelConfig, self).__init__() 158 | self.roi_config = ROIModuleConfig() 159 | self.teacher_forcing = False 160 | self.num_anchors = self.roi_config.num_anchors 161 | self.nms = True 162 | self.regression = False 163 | self.uncertainty = int(False) 164 | 165 | 166 | class MemoryConfigPIL: 167 | def __init__(self): 168 | self.capacity = 20000 169 | self.prioritized_replay = True 170 | self.bias_correct = False 171 | self.warm_start_memory = None 172 | self.flip_prob = .5 173 | self.jitter_prob = .8 174 | self.jitter = .3 175 | self.initial_priority = .5 176 | 177 | 178 | class MemoryConfigTensor: 179 | def __init__(self): 180 | self.capacity = 20000 181 | self.warm_start_memory = None 182 | self.num_workers = 0 183 | self.sizes = [(3 + global_config.depth, global_config.resolution, global_config.resolution), 184 | (global_config.max_pokes, global_config.grid_size, 
global_config.grid_size), 185 | (global_config.grid_size, global_config.grid_size), 186 | (global_config.grid_size, global_config.grid_size)] + \ 187 | ([(global_config.grid_size, global_config.grid_size)] if global_config.superpixels else []) 188 | self.dtypes = [torch.float32, # image 189 | torch.bool, # obj_masks 190 | torch.float32, # foreground 191 | torch.float32] + ([torch.int32] if global_config.superpixels else []) # background 192 | 193 | 194 | class ObjectnessLossConfig: 195 | def __init__(self): 196 | self.filter = False # has to match the corresponding attribute in model config file 197 | self.filter_threshold = -.3 198 | self.prioritized_replay = True 199 | self.foreground_threshold = 1.5 200 | self.objectness_weight = 1 201 | self.smoothness_weight = 0 202 | self.kernel = actor_config.check_change_kernel 203 | self.kernel_size = (self.kernel.shape[0] - 1) // 2 204 | self.check_change_kernel = actor_config.check_change_kernel 205 | self.superpixel_for_action_feedback = False 206 | self.robustify = None 207 | self.point_feedback_for_action = False 208 | self.localize_object_around_poking_point = True 209 | self.prioritize_default = .5 210 | self.prioritize_function = lambda score: (score - .5) ** 2 + .02 211 | 212 | 213 | class ObjectnessClusteringLossConfig(ObjectnessLossConfig): 214 | def __init__(self): 215 | super(ObjectnessClusteringLossConfig, self).__init__() 216 | self.threshold = 1 # Should match the threshold in model config file 217 | self.center_foreground = False 218 | self.scaleable = True 219 | 220 | 221 | class MaskAndMassLossConfig(ObjectnessClusteringLossConfig): 222 | def __init__(self): 223 | super(MaskAndMassLossConfig, self).__init__() 224 | self.mass_loss_weight = .1 225 | self.instance_only = False 226 | 227 | 228 | class ObjectnessRPNLossConfig(ObjectnessLossConfig): 229 | def __init__(self): 230 | super(ObjectnessRPNLossConfig, self).__init__() 231 | self.filter = False # No filter implemented yet for this model 232 | self.roi_config = ROIModuleConfig() 233 | self.regression = False # has to match the entry in RPNModelConfig 234 | self.deltas = [(0, 0, 0, 0), 235 | (0, 0, 1, 0), 236 | (0, 0, -1, 0), 237 | (0, 0, 0, 1), 238 | (0, 0, 0, -1), 239 | (1, 0, 1, 0), 240 | (-1, 0, -1, 0), 241 | (0, 1, 0, 1), 242 | (0, -1, 0, -1)] 243 | self.regression_weight = 1 244 | 245 | 246 | class FgBgLossConfig: 247 | def __init__(self): 248 | self.prioritized_replay = False 249 | self.restrict_positives = False 250 | self.restrict_negatives = True and not self.restrict_positives 251 | self.kernel = actor_config.check_change_kernel 252 | self.kernel_size = (self.kernel.shape[0] - 1) // 2 253 | self.foreground_threshold = 1.5 254 | 255 | 256 | class TrainerConfig: 257 | def __init__(self): 258 | self.log_path = None 259 | self.checkpoint_path = None 260 | self.save_frequency = 100 261 | self.ground_truth = 0 # 0 = self, 1 = poke, 2 = mask, 3 = poke+mask, 4 = visualize, 5 = generate test set 262 | self.num_actors = 35 263 | self.episodes = 900 264 | self.new_datapoints_per_episode = 70 265 | self.batch_size = 64 266 | self.lr_schedule = lambda episode, episodes: 5e-4 267 | self.weight_decay = 1e-4 268 | self.update_schedule = lambda episode, episodes: int(15 + 30 * episode / episodes) 269 | self.poking_schedule = lambda episode, episodes: 20 270 | self.prefill_memory = 3000 271 | self.eval_during_train = False 272 | self.unfreeze = -1 273 | 274 | 275 | class TestingConfig: 276 | def __init__(self): 277 | self.num_actors = 25 278 | self.bs = 50 279 | self.colors = [(0, 0, 
200), (0, 255, 255), (255, 0, 255), (255, 255, 0), 280 | (120, 255, 0), (255, 120, 0), (0, 255, 120), (0, 120, 255), (120, 0, 255), (120, 255, 120), 281 | (60, 177, 0), (177, 60, 0), (0, 177, 60), (0, 60, 177), (60, 0, 177), (60, 177, 60)] 282 | -------------------------------------------------------------------------------- /source/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | import torch 6 | 7 | from models.clustering_models import ClusteringModel 8 | from pipeline.evaluator import Evaluator 9 | from replay_memory.replay_pil import ReplayPILDataset 10 | from config import MemoryConfigPIL, TestingConfig, ClusteringModelConfig 11 | from config import global_config 12 | from tools.logger import init_logging, LOGGER 13 | from tools.coco_tools import save_coco_dataset 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser( 18 | description="self-supervised-objects eval", 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 20 | ) 21 | parser.add_argument( 22 | "model_folder", 23 | type=str, 24 | help="required trained model folder name", 25 | ) 26 | parser.add_argument( 27 | "dataset_folder", 28 | type=str, 29 | help="required dataset folder name", 30 | ) 31 | parser.add_argument( 32 | "dataset", 33 | type=int, 34 | help="required dataset type should be 0 (NovelObjects) or 1 (NovelSpaces) and match the one used for training", 35 | ) 36 | parser.add_argument( 37 | "-c", 38 | "--checkpoint", 39 | required=False, 40 | default=900, 41 | type=int, 42 | help="checkpoint to evaluate", 43 | ) 44 | parser.add_argument( 45 | "-g", 46 | "--model_gpu", 47 | required=False, 48 | default=0, 49 | type=int, 50 | help="gpu id to run model", 51 | ) 52 | parser.add_argument( 53 | "-l", 54 | "--loaders_gpu", 55 | required=False, 56 | default=0, 57 | type=int, 58 | help="gpu id to run thor data loaders", 59 | ) 60 | parser.add_argument( 61 | "-i", 62 | "--interaction_threshold", 63 | required=False, 64 | default=-100.0, 65 | type=float, 66 | help="interaction logits threshold", 67 | ) 68 | parser.add_argument( 69 | "-p", 70 | "--checkpoint_prefix", 71 | required=False, 72 | default="clustering_model_weights_", 73 | type=str, 74 | help="prefix for checkpoints in output folder", 75 | ) 76 | parser.add_argument( 77 | "-d", 78 | "--det_file", 79 | required=False, 80 | default=None, 81 | type=str, 82 | help="precomputed detections result", 83 | ) 84 | args = parser.parse_args() 85 | 86 | return args 87 | 88 | 89 | if __name__ == '__main__': 90 | args = get_args() 91 | 92 | init_logging() 93 | LOGGER.info("Running eval with args {}".format(args)) 94 | 95 | output_folder = os.path.normpath(args.model_folder) 96 | dataset_folder = os.path.normpath(args.dataset_folder) 97 | dataset = args.dataset 98 | 99 | assert os.path.isdir(output_folder), 'Output folder does not exist' 100 | assert os.path.isdir(dataset_folder), 'Dataset folder does not exist' 101 | assert dataset in [0, 1], 'Dataset argument should be either 0 (NovelObjects) or 1 (NovelSpaces)' 102 | 103 | results_folder = os.path.join(output_folder, "inference") 104 | os.makedirs(results_folder, exist_ok=True) 105 | LOGGER.info("Writing output to {}".format(results_folder)) 106 | 107 | data_file = ['NovelObjects__test.json', 'NovelSpaces__test.json'][dataset] 108 | data_path = os.path.join(dataset_folder, data_file) 109 | coco_gt_path = save_coco_dataset(data_path, results_folder) 110 | 111 | if args.det_file is None: 112 | 
global_config.model_gpu = args.model_gpu 113 | global_config.actor_gpu = args.loaders_gpu 114 | model = ClusteringModel(ClusteringModelConfig()).cuda(global_config.model_gpu) 115 | 116 | cp_name = os.path.join(output_folder, "{}{}.pth".format(args.checkpoint_prefix, args.checkpoint)) 117 | LOGGER.info("Loading checkpoint {}".format(cp_name)) 118 | model.load_state_dict(torch.load(cp_name, map_location="cpu")) 119 | 120 | id = "{}__cp{}".format(os.path.basename(output_folder), args.checkpoint) 121 | 122 | eval = Evaluator(model, ReplayPILDataset(MemoryConfigPIL()), loss_function=None, tester_config=TestingConfig()) 123 | det_file = eval.inference(data_path, results_folder, args.interaction_threshold, id, interactable_classes=[0, 1, 2]) 124 | else: 125 | det_file = args.det_file 126 | LOGGER.info("Using precomputed detections in {}".format(det_file)) 127 | 128 | results = {} 129 | for anno_type in ['bbox', 'segm', 'mass']: 130 | results.update(Evaluator.evaluate(coco_gt_path, det_file, annotation_type=anno_type)) 131 | 132 | results_file = det_file.replace("_inf.json", "_results.json") 133 | with open(results_file, "w") as f: 134 | json.dump(results, f, indent=4, sort_keys=True) 135 | LOGGER.info("Full results saved in {}".format(results_file)) 136 | 137 | LOGGER.info("Eval done") 138 | -------------------------------------------------------------------------------- /source/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/learning_from_interaction/a266bc16d682832aa854348fa557a30d86b84674/source/losses/__init__.py -------------------------------------------------------------------------------- /source/losses/clustering_losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from config import ObjectnessClusteringLossConfig, MaskAndMassLossConfig, global_config 5 | from losses.loss_utils import focal_loss_function, embedding_loss_function, MassLoss 6 | from losses.losses import LossFunction, ObjectnessLoss 7 | 8 | 9 | class MaskAndMassLoss(LossFunction): 10 | """ 11 | This loss assumes that the model predicts: pixel feature embeddings, pixel objectness logits, pixel force logits 12 | 13 | The loss encourages the model to achieve the following 14 | - Pixel objectness logits are large for pixels that are likely to move when poked. 15 | - Force logits are non-random at locations of high objectness. At these locations, the largest of the logits 16 | corresponds to a force that moves the object, but so that no lower force will move the object 17 | - Pixel feature embeddings are close to each other for pixels that are likely to move together after a poke 18 | 19 | Note: Terminology "Force" vs "Mass" is somewhat ambiguous in this class. Both refer to the same. 
20 | """ 21 | 22 | def __init__(self, loss_config: MaskAndMassLossConfig): 23 | super(MaskAndMassLoss, self).__init__() 24 | self.loss_summary_length = 2 + global_config.superpixels + loss_config.filter 25 | self.prioritized_replay = loss_config.prioritized_replay 26 | self.config = loss_config 27 | self.focal_loss = focal_loss_function(loss_config.objectness_weight) 28 | self.embedding_loss = embedding_loss_function(loss_config.threshold) 29 | self.mass_loss = MassLoss.apply 30 | self.logsm = torch.nn.LogSoftmax(dim=0) 31 | self.distance_function = lambda x, y: ((x - y) ** 2).sum(dim=0) 32 | 33 | def __call__(self, model_predictions, targets, weights, superpixels=None): 34 | objectnesss, masses, embeddings = model_predictions[:3] 35 | device = embeddings[0].device 36 | 37 | object_masks, foreground_masks, background_masks, mass_masks = targets 38 | 39 | embedding_losses, objectness_losses, mass_losses = [], [], [] 40 | 41 | for weight, embedding, objectness, mass, object_mask, foreground, background, mass_mask in zip( 42 | weights, embeddings, objectnesss, masses, object_masks, foreground_masks, background_masks, mass_masks): 43 | 44 | embedding_loss = self.compute_embedding_loss(embedding, object_mask, weight, device) 45 | if not self.config.instance_only: 46 | if self.config.scaleable: 47 | mass_loss = self.mass_loss(mass, mass_mask, weight) 48 | else: 49 | mass_loss = self.mass_loss_nonscaleable(mass, mass_mask, weight) 50 | mass_losses.append(mass_loss * self.config.mass_loss_weight) 51 | 52 | objectness_loss = self.compute_objectness_loss(objectness, foreground, background, weight) 53 | 54 | embedding_losses.append(embedding_loss) 55 | objectness_losses.append(objectness_loss) 56 | 57 | losses = [embedding_losses, objectness_losses] + ([mass_losses] if not self.config.instance_only else []) 58 | losses = self.stack_losses(losses, device) 59 | 60 | if self.prioritized_replay: 61 | priorities = self.compute_priorities([embedding_losses]) 62 | return losses, priorities 63 | 64 | return losses 65 | 66 | def compute_objectness_loss(self, objectness, foreground, background, weight): 67 | b = objectness.shape[0] > 1 68 | objectness, uncertainty = objectness[0], objectness[1] if b else None 69 | objectness_loss = self.focal_loss(objectness, foreground, background, weight) 70 | if b: 71 | uncertainty_foreground = foreground * (objectness < 0) + background * (objectness > 0) 72 | uncertainty_background = foreground * (objectness >= 0) + background * (objectness <= 0) 73 | uncertainty_loss = self.focal_loss(uncertainty, uncertainty_foreground, uncertainty_background, weight) 74 | return objectness_loss + uncertainty_loss 75 | return objectness_loss 76 | 77 | def compute_embedding_loss(self, embedding, object_mask, weight, device): 78 | max_objects = object_mask.shape[0] 79 | embedding_loss = torch.tensor(0., dtype=torch.float32).to(device) 80 | 81 | objs = [object_mask[i] for i in range(max_objects) if object_mask[i].sum() > 0] 82 | for obj in objs: 83 | center = embedding[:, obj].mean(1).unsqueeze(1).unsqueeze(2) 84 | embedding_loss = embedding_loss + self.embedding_loss(self.distance_function(embedding, center), 85 | obj, weight) 86 | return embedding_loss / max(len(objs), 1) 87 | 88 | def mass_loss_nonscaleable(self, mass, mass_mask, weight): 89 | mass = self.logsm(mass) 90 | classes = mass_mask.argmax(dim=0) 91 | pointwise_ce = sum(mass[i] * (classes == i) for i in range(3)) 92 | loss = (pointwise_ce * mass_mask.sum(dim=0)).sum(dim=(0, 1)) 93 | return - loss * weight 94 | 95 | def 
process_feedback(self, actions: list, feedback: list, superpixels=None): 96 | targets = [] 97 | num_successes = 0 98 | if not superpixels: 99 | superpixels = [None] * len(actions) 100 | for act, fb, superpixel in zip(actions, feedback, superpixels): 101 | target, new_successes = self.process_single_feedback(act, fb, superpixel) 102 | targets.append(target) 103 | num_successes += new_successes 104 | return targets, num_successes 105 | 106 | def process_single_feedback(self, actions, feedbacks, superpixel): 107 | object_masks = np.zeros((global_config.max_pokes, 108 | global_config.grid_size, global_config.grid_size), dtype=bool) 109 | foreground_mask = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 110 | mass_mask = np.zeros((3, global_config.grid_size, global_config.grid_size), dtype=np.float32) 111 | background_mask = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 112 | successes = 0 113 | 114 | for i, (action, feedback) in enumerate(zip(actions, feedbacks)): 115 | if self.config.instance_only: 116 | mask, mass_fb = feedback, None 117 | else: 118 | mask, mass_fb = feedback[0], feedback[1] 119 | action, mass = action['point'], action['force'] if not self.config.instance_only else None 120 | weights = self.get_neighbourhood(action, superpixel) 121 | score = self.get_score(mask, action) 122 | if score > self.config.foreground_threshold: 123 | object_masks[i, mask] = True 124 | foreground_mask += weights 125 | if not self.config.instance_only: 126 | if self.config.scaleable: 127 | mass_mask += self.mass_feedback_vector(mass, mass_fb) * weights[None, ...] 128 | else: 129 | mass_mask += self.mass_feedback_vector_nonscaleable(mass, mass_fb) * weights[None, ...] 130 | successes += 1 131 | else: 132 | background_mask += weights 133 | return (object_masks, foreground_mask, background_mask, mass_mask), successes 134 | 135 | @staticmethod 136 | def mass_feedback_vector(mass, feedback): 137 | vec = np.zeros(3, dtype=np.float32) 138 | if feedback == 2: 139 | return vec[:, None, None] 140 | if feedback == 0: 141 | vec[mass] = 1 142 | elif feedback == -1: 143 | for i in range(0, mass): 144 | vec[i] = 1 145 | elif feedback == 1: 146 | for i in range(mass + 1, 3): 147 | vec[i] = 1 148 | vec = vec - vec.mean() 149 | vec = vec / np.abs(vec).sum() 150 | return vec[:, None, None] 151 | 152 | @staticmethod 153 | def mass_feedback_vector_nonscaleable(mass, feedback): 154 | vec = np.zeros(3, dtype=np.float32) 155 | if feedback < 2: 156 | vec[feedback] = 1 157 | return vec[:, None, None] 158 | 159 | def get_score(self, mask, action): 160 | if not self.config.localize_object_around_poking_point: 161 | return mask.sum() 162 | x, y = action 163 | dx1 = min(x, self.config.kernel_size) 164 | dx2 = min(global_config.grid_size - 1 - x, self.config.kernel_size) + 1 165 | dy1 = min(y, self.config.kernel_size) 166 | dy2 = min(global_config.grid_size - 1 - y, self.config.kernel_size) + 1 167 | x1, x2, y1, y2 = x - dx1, x + dx2, y - dy1, y + dy2 168 | return (mask[x1:x2, y1:y2] * self.config.check_change_kernel[self.config.kernel_size - dx1: 169 | self.config.kernel_size + dx2, 170 | self.config.kernel_size - dy1: 171 | self.config.kernel_size + dy2]).sum() 172 | 173 | def get_neighbourhood(self, action, superpixel): 174 | x, y = action 175 | if self.config.point_feedback_for_action: 176 | weights = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 177 | weights[x, y] = 1 178 | elif superpixel is not None and 
self.config.superpixel_for_action_feedback: 179 | weights = (superpixel == superpixel[x, y]).astype(np.float32) 180 | else: 181 | weights = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 182 | dx1 = min(x, self.config.kernel_size) 183 | dx2 = min(global_config.grid_size - 1 - x, self.config.kernel_size) + 1 184 | dy1 = min(y, self.config.kernel_size) 185 | dy2 = min(global_config.grid_size - 1 - y, self.config.kernel_size) + 1 186 | x1, x2, y1, y2 = x - dx1, x + dx2, y - dy1, y + dy2 187 | weights[x1:x2, y1:y2] = self.config.kernel[self.config.kernel_size - dx1: 188 | self.config.kernel_size + dx2, 189 | self.config.kernel_size - dy1: 190 | self.config.kernel_size + dy2] 191 | return weights 192 | 193 | @staticmethod 194 | def stack_losses(losses, device): 195 | losses = [torch.stack(l) if len(l) > 0 else torch.tensor(0.).to(device) 196 | for l in losses] 197 | losses = [l.sum() / (l > 0).sum().clamp(min=1) for l in losses] 198 | return losses 199 | 200 | def compute_priorities(self, losses: list): 201 | priorities = [] 202 | for loss in zip(*losses): 203 | score = min(iou.item() if iou > 0 else .5 for iou in loss) 204 | priorities.append((score - .5) ** 2 + .02) 205 | return priorities 206 | 207 | 208 | class ObjectnessClusteringLoss(ObjectnessLoss): 209 | """ 210 | This loss is for instance segmentation only. 211 | It assumes that the model predicts: pixel feature embeddings, pixel objectness logits, 212 | (and optionally a single filter logit for the entire image). 213 | 214 | The losses for objectness and filter logit are described in the parent class. 215 | - Pixel feature embeddings are close to each other for pixels that are likely to move together after a poke 216 | """ 217 | 218 | def __init__(self, loss_config: ObjectnessClusteringLossConfig): 219 | super(ObjectnessClusteringLoss, self).__init__(loss_config) 220 | self.loss_summary_length = 2 + global_config.superpixels + loss_config.filter 221 | 222 | self.embedding_loss = embedding_loss_function(loss_config.threshold) 223 | self.distance_function = lambda x, y: ((x - y) ** 2).sum(dim=0) 224 | 225 | def __call__(self, model_predictions, targets, weights, superpixels=None): 226 | objectnesss, embeddings = model_predictions[:2] 227 | device = embeddings[0].device 228 | if self.config.filter: 229 | filter_logits = model_predictions[-1] 230 | else: 231 | filter_logits = [None] * embeddings.shape[0] 232 | 233 | object_masks, foreground_masks, background_masks = targets 234 | 235 | filter_and_smoothness_losses = \ 236 | self.compute_filter_and_smoothness_loss(filter_logits, model_predictions[:2], object_masks, 237 | weights, superpixels) 238 | 239 | embedding_losses, objectness_losses = [], [] 240 | 241 | for weight, embedding, objectness, filter_logit, object_mask, foreground, background in zip( 242 | weights, embeddings, objectnesss, filter_logits, object_masks, foreground_masks, background_masks): 243 | 244 | if self.config.filter and filter_logit < self.config.filter_threshold: 245 | continue 246 | 247 | embedding_loss = self.compute_embedding_loss(embedding, object_mask, weight, device) 248 | 249 | if self.config.center_foreground: 250 | foreground = self.center_foreground(foreground, embedding) 251 | objectness_loss = self.compute_objectness_loss(objectness, foreground, background, weight) 252 | 253 | embedding_losses.append(embedding_loss) 254 | objectness_losses.append(objectness_loss) 255 | 256 | losses = self.stack_losses([embedding_losses, objectness_losses], device) 257 | 258 | losses += 
filter_and_smoothness_losses 259 | 260 | if self.prioritized_replay: 261 | priorities = self.compute_priorities([embedding_losses]) 262 | return losses, priorities 263 | 264 | return losses 265 | 266 | def compute_embedding_loss(self, embedding, object_mask, weight, device): 267 | max_objects = object_mask.shape[0] 268 | embedding_loss = torch.tensor(0., dtype=torch.float32).to(device) 269 | 270 | objs = [object_mask[i] for i in range(max_objects) if object_mask[i].sum() > 0] 271 | for obj in objs: 272 | if self.config.robustify is not None: 273 | obj = self.robustify(embedding, obj) 274 | center = embedding[:, obj].mean(1).unsqueeze(1).unsqueeze(2) 275 | embedding_loss = embedding_loss + self.embedding_loss(self.distance_function(embedding, center), 276 | obj, weight) 277 | return embedding_loss / max(len(objs), 1) 278 | 279 | @staticmethod 280 | def center_foreground(foreground, embedding): 281 | if torch.any(foreground > .999): 282 | with torch.no_grad(): 283 | emb = embedding.transpose(0, 1).transpose(1, 2) 284 | feats = emb[torch.where(foreground > .999)] 285 | masks = ((embedding.unsqueeze(0) - feats.unsqueeze(2).unsqueeze(3)) ** 2).sum( 286 | dim=1) < 1 287 | meanfeats = torch.stack([emb[mask].mean(0) for mask in masks]) 288 | meanmasks = ((embedding.unsqueeze(0) - meanfeats.unsqueeze(2).unsqueeze(3)) ** 2).sum(dim=1) < 1 289 | ious = (masks * meanmasks).sum(dim=(1, 2)).float() / (masks | meanmasks).sum(dim=(1, 2)) > .8 290 | foreground[torch.where(foreground > .999)] += .2 * (2 * ious - 1).float() 291 | return foreground.clone() 292 | 293 | def robustify(self, embedding, obj): 294 | with torch.no_grad(): 295 | center = embedding[:, obj].mean(1).unsqueeze(1).unsqueeze(2) 296 | distances = self.distance_function(embedding, center) 297 | throw_out = (distances > 1 - self.config.robustify[0]) * obj 298 | put_in = (distances < 1 + self.config.robustify[1]) * (~ obj) 299 | return (obj * (~ throw_out)) | put_in 300 | 301 | 302 | class ObjectnessClusteringLossGT(ObjectnessClusteringLoss): 303 | """ 304 | This loss is an alternative for clustering based instance segmentation models trained fully supervised. 
305 | """ 306 | 307 | def __init__(self, loss_config: ObjectnessClusteringLossConfig): 308 | super(ObjectnessClusteringLossGT, self).__init__(loss_config) 309 | 310 | def __call__(self, model_predictions, targets, weights, superpixels=None): 311 | seed_scores, embeddings = model_predictions[:2] 312 | device = embeddings[0].device 313 | 314 | object_masks, _, _ = targets 315 | 316 | embedding_losses, seed_losses = [], [] 317 | 318 | for weight, embedding, seed_score, object_mask in zip( 319 | weights, embeddings, seed_scores, object_masks): 320 | embedding_loss = self.compute_embedding_loss(embedding, object_mask, weight, device) 321 | 322 | seed_loss = self.compute_seed_loss(seed_score, embedding, object_mask, weight) 323 | 324 | embedding_losses.append(embedding_loss) 325 | seed_losses.append(seed_loss) 326 | 327 | losses = self.stack_losses([embedding_losses, seed_losses], device) 328 | 329 | if self.prioritized_replay: 330 | priorities = self.compute_priorities([embedding_losses]) 331 | return losses, priorities 332 | 333 | return losses 334 | 335 | def compute_seed_loss(self, seed_score, embedding, object_mask, weight): 336 | foreground, background = self.make_seed_targets(embedding, object_mask, seed_score) 337 | return self.focal_loss(seed_score, foreground, background, weight) 338 | 339 | def make_seed_targets(self, embedding, masks, seed_score): 340 | with torch.no_grad(): 341 | foreground, background = torch.zeros_like(seed_score), torch.zeros_like(seed_score) 342 | dim = len(foreground.view(-1)) 343 | masks = [mask for mask in masks if mask.sum() > 0] 344 | for mask in masks: 345 | center = embedding[:, mask].mean(1).unsqueeze(1).unsqueeze(2) 346 | distances = self.distance_function(embedding, center) 347 | dmin = distances.view(-1).kthvalue(6)[0] 348 | dmax = distances.view(-1).kthvalue(dim - 5)[0] 349 | foreground += distances < dmin 350 | background += distances > dmax 351 | return foreground, background 352 | 353 | 354 | class PooledMaskAndMassLoss(LossFunction): 355 | """ 356 | This loss is the equivalent of MaskAndMassLoss for models that do not produce pixel-wise force logits, 357 | but instance-wise force logits 358 | """ 359 | 360 | def __init__(self, loss_config: MaskAndMassLossConfig): 361 | super(PooledMaskAndMassLoss, self).__init__() 362 | self.loss_summary_length = 2 + global_config.superpixels + loss_config.filter 363 | self.prioritized_replay = loss_config.prioritized_replay 364 | self.config = loss_config 365 | self.focal_loss = focal_loss_function(loss_config.objectness_weight) 366 | self.embedding_loss = embedding_loss_function(loss_config.threshold) 367 | self.mass_loss = MassLoss.apply 368 | self.logsm = torch.nn.LogSoftmax(dim=1) 369 | self.distance_function = lambda x, y: ((x - y) ** 2).sum(dim=0) 370 | 371 | def mass_loss_nonscaleable(self, mass_logit, mass_target, weight): 372 | mass = self.logsm(mass_logit) 373 | classes = mass_target.argmax(dim=1) 374 | pointwise_ce = sum(mass[:, i] * (classes == i) for i in range(3)) 375 | loss = (pointwise_ce * mass_target.sum(dim=1)).sum() 376 | return - loss * weight 377 | 378 | def __call__(self, model_predictions, targets, weights, superpixels=None): 379 | objectnesss, embeddings, mass_logits, _ = model_predictions 380 | device = embeddings[0].device 381 | 382 | object_masks, foreground_masks, background_masks, mass_targets = targets 383 | 384 | embedding_losses, objectness_losses, mass_losses = [], [], [] 385 | 386 | for weight, embedding, objectness, mass_logit, object_mask, foreground, background, mass_target 
in zip( 387 | weights, embeddings, objectnesss, mass_logits, object_masks, foreground_masks, background_masks, 388 | mass_targets): 389 | 390 | embedding_loss = self.compute_embedding_loss(embedding, object_mask, weight, device) 391 | if self.config.scaleable: 392 | mass_loss = self.mass_loss(mass_logit, mass_target, weight) 393 | else: 394 | mass_loss = self.mass_loss_nonscaleable(mass_logit, mass_target, weight) 395 | objectness_loss = self.compute_objectness_loss(objectness, foreground, background, weight) 396 | 397 | embedding_losses.append(embedding_loss) 398 | mass_losses.append(mass_loss) 399 | objectness_losses.append(objectness_loss) 400 | 401 | losses = self.stack_losses([embedding_losses, objectness_losses, mass_losses], device) 402 | 403 | if self.prioritized_replay: 404 | priorities = self.compute_priorities([embedding_losses]) 405 | return losses, priorities 406 | 407 | return losses 408 | 409 | def compute_embedding_loss(self, embedding, object_mask, weight, device): 410 | max_objects = object_mask.shape[0] 411 | embedding_loss = torch.tensor(0., dtype=torch.float32).to(device) 412 | 413 | objs = [object_mask[i] for i in range(max_objects) if object_mask[i].sum() > 0] 414 | for obj in objs: 415 | center = embedding[:, obj].mean(1).unsqueeze(1).unsqueeze(2) 416 | embedding_loss = embedding_loss + self.embedding_loss(self.distance_function(embedding, center), 417 | obj, weight) 418 | return embedding_loss / max(len(objs), 1) 419 | 420 | def process_feedback(self, actions: list, feedback: list, superpixels=None): 421 | targets = [] 422 | num_successes = 0 423 | if not superpixels: 424 | superpixels = [None] * len(actions) 425 | for act, fb, superpixel in zip(actions, feedback, superpixels): 426 | target, new_successes = self.process_single_feedback(act, fb, superpixel) 427 | targets.append(target) 428 | num_successes += new_successes 429 | return targets, num_successes 430 | 431 | def process_single_feedback(self, actions, feedbacks, superpixel): 432 | object_masks = np.zeros((global_config.max_pokes, 433 | global_config.grid_size, global_config.grid_size), dtype=bool) 434 | foreground_mask = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 435 | mass_targets = np.zeros((global_config.max_pokes, 3), dtype=np.float32) 436 | background_mask = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 437 | successes = 0 438 | 439 | for i, (action_and_mass, mask_and_mass) in enumerate(zip(actions, feedbacks)): 440 | mask, mass_fb = mask_and_mass 441 | action, mass = action_and_mass['point'], action_and_mass['force'] 442 | weights = self.get_neighbourhood(action, superpixel) 443 | score = self.get_score(mask, action) 444 | if score > self.config.foreground_threshold: 445 | object_masks[i, mask] = True 446 | foreground_mask += weights 447 | if self.config.scaleable: 448 | mass_targets[i] = self.mass_feedback_vector(mass, mass_fb) 449 | else: 450 | mass_targets[i] = self.mass_feedback_vector_nonscaleable(mass, mass_fb) 451 | successes += 1 452 | else: 453 | background_mask += weights 454 | return (object_masks, foreground_mask, background_mask, mass_targets), successes 455 | 456 | @staticmethod 457 | def mass_feedback_vector(mass, feedback): 458 | vec = np.zeros(3, dtype=np.float32) 459 | if feedback == 2: 460 | return vec 461 | if feedback == 0: 462 | vec[mass] = 1 463 | elif feedback == -1: 464 | for i in range(0, mass): 465 | vec[i] = 1 466 | elif feedback == 1: 467 | for i in range(mass + 1, 3): 468 | vec[i] = 1 469 | vec = vec - 
vec.mean() 470 | vec = vec / np.abs(vec).sum() 471 | return vec 472 | 473 | @staticmethod 474 | def mass_feedback_vector_nonscaleable(mass, feedback): 475 | vec = np.zeros(3, dtype=np.float32) 476 | if feedback < 2: 477 | vec[feedback] = 1 478 | return vec 479 | 480 | def get_score(self, mask, action): 481 | if not self.config.localize_object_around_poking_point: 482 | return mask.sum() 483 | x, y = action 484 | dx1 = min(x, self.config.kernel_size) 485 | dx2 = min(global_config.grid_size - 1 - x, self.config.kernel_size) + 1 486 | dy1 = min(y, self.config.kernel_size) 487 | dy2 = min(global_config.grid_size - 1 - y, self.config.kernel_size) + 1 488 | x1, x2, y1, y2 = x - dx1, x + dx2, y - dy1, y + dy2 489 | return (mask[x1:x2, y1:y2] * self.config.check_change_kernel[self.config.kernel_size - dx1: 490 | self.config.kernel_size + dx2, 491 | self.config.kernel_size - dy1: 492 | self.config.kernel_size + dy2]).sum() 493 | 494 | def get_neighbourhood(self, action, superpixel): 495 | x, y = action 496 | if self.config.point_feedback_for_action: 497 | weights = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 498 | weights[x, y] = 1 499 | elif superpixel is not None and self.config.superpixel_for_action_feedback: 500 | weights = (superpixel == superpixel[x, y]).astype(np.float32) 501 | else: 502 | weights = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 503 | dx1 = min(x, self.config.kernel_size) 504 | dx2 = min(global_config.grid_size - 1 - x, self.config.kernel_size) + 1 505 | dy1 = min(y, self.config.kernel_size) 506 | dy2 = min(global_config.grid_size - 1 - y, self.config.kernel_size) + 1 507 | x1, x2, y1, y2 = x - dx1, x + dx2, y - dy1, y + dy2 508 | weights[x1:x2, y1:y2] = self.config.kernel[self.config.kernel_size - dx1: 509 | self.config.kernel_size + dx2, 510 | self.config.kernel_size - dy1: 511 | self.config.kernel_size + dy2] 512 | return weights 513 | 514 | def compute_objectness_loss(self, objectness, foreground, background, weight): 515 | b = objectness.shape[0] > 1 516 | objectness, uncertainty = objectness[0], objectness[1] if b else None 517 | objectness_loss = self.focal_loss(objectness, foreground, background, weight) 518 | if b: 519 | uncertainty_foreground = foreground * (objectness < 0) + background * (objectness > 0) 520 | uncertainty_background = foreground * (objectness >= 0) + background * (objectness <= 0) 521 | uncertainty_loss = self.focal_loss(uncertainty, uncertainty_foreground, uncertainty_background, weight) 522 | return objectness_loss + uncertainty_loss 523 | return objectness_loss 524 | 525 | @staticmethod 526 | def stack_losses(losses, device): 527 | losses = [torch.stack(l) if len(l) > 0 else torch.tensor(0.).to(device) for l in losses] 528 | losses = [l.sum() / (l > 0).sum().clamp(min=1) for l in losses] 529 | return losses 530 | 531 | def compute_priorities(self, losses: list): 532 | priorities = [] 533 | for loss in zip(*losses): 534 | score = min(iou.item() if iou > 0 else .5 for iou in loss) 535 | priorities.append((score - .5) ** 2 + .02) 536 | return priorities 537 | -------------------------------------------------------------------------------- /source/losses/fgbg_losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from config import FgBgLossConfig, global_config 5 | from losses.loss_utils import focal_loss_function 6 | from losses.losses import LossFunction 7 | 8 | 9 | class 
FgBgLossFunction(LossFunction): 10 | """ 11 | This is a simple loss for foreground-background segmentation models with some adjustability for unlabeled pixels. 12 | """ 13 | def __init__(self, loss_config: FgBgLossConfig): 14 | super(FgBgLossFunction, self).__init__() 15 | self.loss_summary_length = 1 16 | self.prioritized_replay = loss_config.prioritized_replay 17 | self.config = loss_config 18 | self.focal_loss = focal_loss_function(1) 19 | 20 | def __call__(self, model_predictions: tuple, targets: tuple, weights, *superpixels): 21 | _, foreground_masks, background_masks = targets 22 | objectness_losses = [] 23 | for weight, objectness, foreground, background in zip( 24 | weights, model_predictions[0], foreground_masks, background_masks): 25 | objectness_loss = self.compute_objectness_loss(objectness, foreground, background, weight) 26 | objectness_losses.append(objectness_loss) 27 | losses = [torch.stack(objectness_losses)] 28 | 29 | if self.prioritized_replay: 30 | priorities = self.compute_priorities(losses) 31 | return [l.sum() / (l > 0).sum().clamp(min=1) for l in losses], priorities 32 | 33 | return [l.sum() / (l > 0).sum().clamp(min=1) for l in losses] 34 | 35 | def process_feedback(self, actions: list, feedback: list, superpixels=None): 36 | targets = [] 37 | num_successes = 0 38 | for act, fb in zip(actions, feedback): 39 | target, new_successes = self.process_single_feedback(act, fb) 40 | targets.append(target) 41 | num_successes += new_successes 42 | return targets, num_successes 43 | 44 | def process_single_feedback(self, actions, feedbacks): 45 | foreground_mask = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 46 | background_mask = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 47 | poking_mask = np.zeros((global_config.max_pokes, global_config.grid_size, global_config.grid_size), 48 | dtype=np.bool) 49 | successes = 0 50 | 51 | for i, (action, mask, pm) in enumerate(zip(actions, feedbacks, poking_mask)): 52 | weights = self.get_neighbourhood(action['point']) 53 | if mask.sum() > self.config.foreground_threshold: 54 | if self.config.restrict_positives: 55 | foreground_mask += weights 56 | pm[:] = mask > 0 57 | else: 58 | foreground_mask = (foreground_mask + mask) > 0 59 | successes += 1 60 | elif self.config.restrict_negatives: 61 | background_mask += weights 62 | if not self.config.restrict_negatives: 63 | background_mask = ~ foreground_mask 64 | return (poking_mask, foreground_mask, background_mask), successes 65 | 66 | def get_neighbourhood(self, action): 67 | x, y = action 68 | weights = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 69 | dx1 = min(x, self.config.kernel_size) 70 | dx2 = min(global_config.grid_size - 1 - x, self.config.kernel_size) + 1 71 | dy1 = min(y, self.config.kernel_size) 72 | dy2 = min(global_config.grid_size - 1 - y, self.config.kernel_size) + 1 73 | x1, x2, y1, y2 = x - dx1, x + dx2, y - dy1, y + dy2 74 | weights[x1:x2, y1:y2] = self.config.kernel[self.config.kernel_size - dx1: 75 | self.config.kernel_size + dx2, 76 | self.config.kernel_size - dy1: 77 | self.config.kernel_size + dy2] 78 | return weights 79 | 80 | def compute_objectness_loss(self, objectness, foreground, background, weight): 81 | b = objectness.shape[0] > 1 82 | objectness, uncertainty = objectness[0], objectness[1] if b else None 83 | objectness_loss = self.focal_loss(objectness, foreground, background, weight) 84 | if b: 85 | uncertainty_foreground = foreground * (objectness < 0) + 
background * (objectness > 0) 86 | uncertainty_background = foreground * (objectness >= 0) + background * (objectness <= 0) 87 | uncertainty_loss = self.focal_loss(uncertainty, uncertainty_foreground, uncertainty_background, weight) 88 | return objectness_loss + uncertainty_loss 89 | return objectness_loss 90 | 91 | def compute_priorities(self, losses: list): 92 | raise NotImplementedError 93 | 94 | 95 | class SoftMaskLossFunction(LossFunction): 96 | """ 97 | This is an L2 loss for fitting soft fg-bg targets. It is used for the videoPCA baseline. 98 | """ 99 | def __init__(self): 100 | super(SoftMaskLossFunction, self).__init__() 101 | self.loss_summary_length = 2 102 | 103 | def __call__(self, model_predictions: tuple, targets: tuple, weights, *superpixels): 104 | soft_masks = targets[0] 105 | objectness_losses = [] 106 | for weight, objectness, soft_mask in zip(weights, model_predictions[0].sigmoid(), soft_masks): 107 | loss = weight * ((objectness - soft_mask)**2).sum() 108 | objectness_losses.append(loss) 109 | losses = [torch.stack(objectness_losses)] 110 | 111 | return [l.sum() / (l > 0).sum().clamp(min=1) for l in losses] 112 | 113 | def process_feedback(self, actions: list, feedback: list, superpixels=None): 114 | targets = [] 115 | num_successes = 0 116 | for act, fb in zip(actions, feedback): 117 | target, new_successes = self.process_single_feedback(act, fb) 118 | targets.append(target) 119 | num_successes += new_successes 120 | return targets, num_successes 121 | 122 | @staticmethod 123 | def process_single_feedback(actions, feedbacks): 124 | soft_mask = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 125 | successes = 0 126 | 127 | for i, (action, mask) in enumerate(zip(actions, feedbacks)): 128 | soft_mask += mask 129 | successes += mask.sum() > 0 130 | return (soft_mask, ), successes 131 | 132 | def compute_priorities(self, losses: list): 133 | raise NotImplementedError 134 | -------------------------------------------------------------------------------- /source/losses/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def focal_loss_function(hyperparameter: float): 6 | """ 7 | :param hyperparameter: Scales the size of the gradient in backward 8 | :return: an autograd.Function ready to apply 9 | 10 | This autograd.Function computes the IoU in the forward pass, and has gradients with similar properties as the 11 | focal loss in the backward pass. 12 | """ 13 | 14 | class FocalTypeLoss(torch.autograd.Function): 15 | @staticmethod 16 | def forward(ctx, logits, positives, negatives, weight): 17 | """ 18 | :param logits: Any shape, dtype float 19 | :param positives: Same shape as logits, dtype float 20 | :param negatives: Same shape as logits, dtype float 21 | :return: scalar, an IoU type quantity 22 | 23 | Note that the output is only for collecting statistic. It is not differentiable, and the backward pass is 24 | independent of it. 
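Example (a minimal sketch; the shapes and the hyperparameter value here are arbitrary, not the ones
used in the training pipeline):

    loss_fn = focal_loss_function(1.0)
    logits = torch.randn(50, 50, requires_grad=True)
    positives = (torch.rand(50, 50) > 0.9).float()   # pixels labelled as moving when poked
    negatives = 1. - positives                       # here every pixel is labelled; in training many are 0 in both
    iou = loss_fn(logits, positives, negatives, torch.tensor(1.))
    iou.backward()   # logits.grad holds the focal-style gradient even though iou itself is only a statistic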
25 | """ 26 | predicted_positives = logits > 0 27 | positives_binary = positives > 0 28 | ctx.save_for_backward(logits, predicted_positives.float(), positives, negatives, weight) 29 | intersection = (predicted_positives * positives_binary).sum().float() 30 | union = (predicted_positives | positives_binary).sum().float().clamp(min=1) 31 | return intersection / union 32 | 33 | @staticmethod 34 | def backward(ctx, dummy_gradient): 35 | """ 36 | :param dummy_gradient: Argument is not used 37 | :return: Gradient for the logits in forward. 38 | """ 39 | logits, predicted_positives, positives, negatives, weight = ctx.saved_tensors 40 | 41 | predicted_negatives = 1 - predicted_positives 42 | logits_exp = logits.exp() 43 | target_pos_factor = 1 / (1 + logits_exp) * (-logits.clamp(min=0) ** 2 / 2).exp() 44 | target_neg_factor = logits_exp / (1 + logits_exp) * (-logits.clamp(max=0) ** 2 / 2).exp() 45 | 46 | # The four 1's here are in principle all hyperparameters, but we've found them hard to choose. 47 | gradient = (predicted_positives * positives * 1 * target_pos_factor 48 | + predicted_negatives * positives * 1 * target_pos_factor 49 | - predicted_positives * negatives * 1 * target_neg_factor 50 | - predicted_negatives * negatives * 1 * target_neg_factor) 51 | return -gradient * hyperparameter * weight, None, None, None 52 | 53 | return FocalTypeLoss.apply 54 | 55 | 56 | def weighted_kl_loss(hyperparameter: float): 57 | """ 58 | :param hyperparameter: Scales the size of the gradient in backward 59 | :return: an autograd.Function ready to apply 60 | 61 | This autograd.Function computes the IoU in the forward pass, and has gradients with similar properties as the 62 | focal loss in the backward pass. 63 | """ 64 | 65 | class WeightedKLLoss(torch.autograd.Function): 66 | @staticmethod 67 | def forward(ctx, logits, soft_mask): 68 | """ 69 | :param logits: Any shape, dtype float 70 | :param positives: Same shape as logits, dtype float 71 | :param negatives: Same shape as logits, dtype float 72 | :return: scalar, an IoU type quantity 73 | 74 | Note that the output is only for collecting statistic. It is not differentiable, and the backward pass is 75 | independent of it. 76 | """ 77 | negatives = soft_mask < 1e-6 78 | 79 | logits_exp = logits.exp() 80 | target_pos_factor = 1 / (1 + logits_exp) 81 | sigmoid = logits_exp * target_pos_factor 82 | 83 | loss = - (soft_mask * torch.log(sigmoid) + negatives * torch.log(target_pos_factor) * hyperparameter) 84 | 85 | ctx.save_for_backward(soft_mask, negatives, target_pos_factor, sigmoid) 86 | return loss.mean() 87 | 88 | @staticmethod 89 | def backward(ctx, dummy_gradient): 90 | """ 91 | :param dummy_gradient: Argument is not used 92 | :return: Gradient for the logits in forward. 93 | """ 94 | soft_mask, negatives, target_pos_factor, target_neg_factor = ctx.saved_tensors 95 | 96 | gradient = soft_mask * target_pos_factor - negatives * target_neg_factor * hyperparameter 97 | return - gradient, None, None, None 98 | 99 | return WeightedKLLoss.apply 100 | 101 | 102 | def weighted_focal_loss(hyperparameter: float): 103 | """ 104 | :param hyperparameter: Scales the size of the gradient in backward 105 | :return: an autograd.Function ready to apply 106 | 107 | This autograd.Function computes the IoU in the forward pass, and has gradients with similar properties as the 108 | focal loss in the backward pass. 
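In this variant the forward pass actually returns the mean of a per-pixel cross-entropy between the
sigmoid of the logits and the soft mask rather than an IoU (the hyperparameter scales the pixels whose
soft mask is near zero); the focal-style damping factors exp(-x**2 / 2) enter only in the backward pass.

Example (shapes and hyperparameter are illustrative):

    loss_fn = weighted_focal_loss(0.5)
    logits = torch.randn(50, 50)
    soft_mask = torch.rand(50, 50)        # soft foreground targets in [0, 1]
    loss = loss_fn(logits, soft_mask)     # scalar tensor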
109 | """ 110 | 111 | class WeightedKLLoss(torch.autograd.Function): 112 | @staticmethod 113 | def forward(ctx, logits, soft_mask): 114 | """ 115 | :param logits: Any shape, dtype float 116 | :param positives: Same shape as logits, dtype float 117 | :param negatives: Same shape as logits, dtype float 118 | :return: scalar, an IoU type quantity 119 | 120 | Note that the output is only for collecting statistic. It is not differentiable, and the backward pass is 121 | independent of it. 122 | """ 123 | negatives = soft_mask < 1e-6 124 | 125 | logits_exp = logits.exp() 126 | target_pos_factor = 1 / (1 + logits_exp) 127 | sigmoid = logits_exp * target_pos_factor 128 | 129 | loss = - (soft_mask * torch.log(sigmoid) + negatives * torch.log(target_pos_factor) * hyperparameter) 130 | 131 | ctx.save_for_backward(soft_mask, negatives, target_pos_factor, sigmoid, logits) 132 | return loss.mean() 133 | 134 | @staticmethod 135 | def backward(ctx, dummy_gradient): 136 | """ 137 | :param dummy_gradient: Argument is not used 138 | :return: Gradient for the logits in forward. 139 | """ 140 | soft_mask, negatives, target_pos_factor, target_neg_factor, logits = ctx.saved_tensors 141 | 142 | target_pos_factor = target_pos_factor * (-logits.clamp(min=0) ** 2 / 2).exp() 143 | target_neg_factor = target_neg_factor * (-logits.clamp(max=0) ** 2 / 2).exp() 144 | 145 | gradient = soft_mask * target_pos_factor - negatives * target_neg_factor * hyperparameter 146 | return - gradient, None, None, None 147 | 148 | return WeightedKLLoss.apply 149 | 150 | 151 | def real_focal_loss_function(hyperparameter: float): 152 | """ 153 | This function is similar to focal_loss_function, but on the backward pass returns the true focal loss, with 154 | hyperparameter gamma=2 (default of original paper). 155 | It is slightly less aggressive in suppressing gradients from confident predictions than our version of the 156 | focal loss, and performed slightly worse in preliminary experiments. 
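Concretely, for a positive pixel with p = sigmoid(x) the gamma = 2 focal loss is FL(x) = -(1 - p)**2 * log(p),
and

    -dFL/dx = (2 * exp(x) * log(1 + exp(-x)) + 1) / (1 + exp(x))**3,

which is the target_pos_factor used in backward below (up to the hyperparameter and weight factors);
target_neg_factor is the mirror image (x -> -x) for pixels labelled as negatives.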
157 | """ 158 | 159 | class RealFocalLoss(torch.autograd.Function): 160 | @staticmethod 161 | def forward(ctx, logits, positives, negatives, weight): 162 | predicted_positives = logits > 0 163 | positives_binary = positives > 0 164 | ctx.save_for_backward(logits, predicted_positives.float(), positives, negatives, weight) 165 | intersection = (predicted_positives * positives_binary).sum().float() 166 | union = (predicted_positives | positives_binary).sum().float().clamp(min=1) 167 | return intersection / union 168 | 169 | @staticmethod 170 | def backward(ctx, dummy_gradient): 171 | logits, predicted_positives, positives, negatives, weight = ctx.saved_tensors 172 | 173 | predicted_negatives = 1 - predicted_positives 174 | logits_exp = logits.exp() 175 | target_pos_factor = (2 * logits_exp * (1 + 1 / logits_exp).log() + 1) / (1 + logits_exp) ** 3 176 | target_neg_factor = (2 / logits_exp * (1 + logits_exp).log() + 1) / (1 + 1 / logits_exp) ** 3 177 | 178 | gradient = (predicted_positives * positives * 1 * target_pos_factor 179 | + predicted_negatives * positives * 1 * target_pos_factor 180 | - predicted_positives * negatives * 1 * target_neg_factor 181 | - predicted_negatives * negatives * 1 * target_neg_factor) 182 | return -gradient * hyperparameter * weight, None, None, None 183 | 184 | return RealFocalLoss.apply 185 | 186 | 187 | def embedding_loss_function(threshold: float): 188 | """ 189 | :param threshold: Distance threshold for computing masks 190 | :return: an autograd.Function that can be used as loss for distances 191 | 192 | This loss on distances pushes positives to distances closer than the threshold, and negatives to distances 193 | larger than the threshold, and behaves otherwise analogous to a focal loss. 194 | """ 195 | 196 | class EmbeddingLoss(torch.autograd.Function): 197 | @staticmethod 198 | def forward(ctx, distances, mask, weight): 199 | """ 200 | :param distances: Any shape, dtype float 201 | :param mask: Same shape as distances, dtype bool 202 | :return: Scalar, and IoU 203 | 204 | Note that the output is only for collecting statistic. It is not differentiable, and the backward pass is 205 | independent of it. 206 | """ 207 | predicted_mask = distances < threshold 208 | distances = distances / threshold 209 | ctx.save_for_backward(distances, mask, weight) 210 | intersection = (predicted_mask * mask).sum().float() 211 | union = (predicted_mask | mask).sum().float().clamp(min=1) 212 | return intersection / union 213 | 214 | @staticmethod 215 | def backward(ctx, dummy_gradient): 216 | """ 217 | :param dummy_gradient: Argument is not used 218 | :return: Gradient for the distances in forward. 219 | """ 220 | distances, mask, weight = ctx.saved_tensors 221 | area = mask.sum().float() / 200 222 | 223 | # '1.5' is chosen here because 1.5x/(1+x) ~= e^(-(x/1.5)**4) at x = 1, i.e. gradients balance at x=1. 224 | positive_factor = 1.5 * distances / (1 + distances) 225 | negative_factor = (-(distances / 1.5) ** 4).exp() 226 | gradient = ~mask * negative_factor - mask * positive_factor 227 | 228 | return -gradient / area * weight, None, None 229 | 230 | return EmbeddingLoss.apply 231 | 232 | 233 | class MassLoss(torch.autograd.Function): 234 | """ 235 | The backward pass of this loss function is essentially the same as FocalLoss. In the forward pass, we compute the 236 | covariance instead of IoU. 
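Example (a single poke with the three force classes; the soft target is the mean-centred, L1-normalised
vector produced by mass_feedback_vector(0, 1) in clustering_losses.py):

    mass_loss = MassLoss.apply
    logits = torch.randn(3, requires_grad=True)
    soft_target = torch.tensor([-.5, .25, .25])
    cov = mass_loss(logits, soft_target, torch.tensor(1.))
    cov.backward()   # descent raises logits 1 and 2 and lowers logit 0, damped by the focal-style factors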
237 | """ 238 | @staticmethod 239 | def forward(ctx, logits, soft_targets, weight): 240 | ctx.save_for_backward(logits, soft_targets, weight) 241 | return ((logits - logits.mean(dim=0).unsqueeze(0)) * soft_targets).mean() 242 | 243 | @staticmethod 244 | def backward(ctx, dummy_grad): 245 | logits, soft_targets, weight = ctx.saved_tensors 246 | logits_exp = logits.exp() 247 | target_pos_factor = 1 / (1 + logits_exp) * (-logits.clamp(min=0) ** 2 / 2).exp() 248 | target_neg_factor = logits_exp / (1 + logits_exp) * (-logits.clamp(max=0) ** 2 / 2).exp() 249 | grad = soft_targets * (target_pos_factor * (soft_targets > 0) + 250 | target_neg_factor * (soft_targets < 0)) 251 | 252 | return - grad * weight, None, None 253 | 254 | 255 | class SmoothSquare(torch.autograd.Function): 256 | """ 257 | Behaves like the square in the forward pass, but clips gradients at +-1 in the backward pass 258 | """ 259 | @staticmethod 260 | def forward(ctx, x): 261 | ctx.save_for_backward(x) 262 | return x ** 2 263 | 264 | @staticmethod 265 | def backward(ctx, grad): 266 | x, = ctx.saved_tensors 267 | x = x.clamp(min=-1, max=1) 268 | return grad * 2 * x 269 | 270 | 271 | class VarianceLoss(nn.Module): 272 | def forward(self, scores): 273 | """ 274 | :param scores: Shape batch_size x D, dtype float 275 | :return: Shape BS. D times the variance of scores over dimension 1. 276 | 277 | Note: Gradients are clipped in backward pass. 278 | """ 279 | mean = scores.mean(1).unsqueeze(1) 280 | return (SmoothSquare.apply(scores - mean)).sum() 281 | 282 | 283 | class SmoothnessPenalty(nn.Module): 284 | """ 285 | This loss penalizes fluctuations of a feature vector within superpixels. 286 | """ 287 | def __init__(self): 288 | super(SmoothnessPenalty, self).__init__() 289 | self.loss = VarianceLoss() 290 | 291 | def forward(self, embedding, superpixel): 292 | """ 293 | :param embedding: Shape D x H x W, dtype float 294 | :param superpixel: Shape H x W, dtype int 295 | :return: Scalar 296 | 297 | Each integer in the H x W tensor superpixels denotes a region in the image found by a superpixel algorithm 298 | from sklearn. The loss penalizes the variance of the embedding vector over each superpixel. 299 | """ 300 | if len(embedding.shape) == 2: 301 | embedding = embedding.unsqueeze(0) 302 | losses = [] 303 | for i in torch.unique(superpixel): 304 | scores = embedding[:, superpixel == i] 305 | losses.append(self.loss(scores)) 306 | return sum(losses) 307 | -------------------------------------------------------------------------------- /source/losses/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from config import ObjectnessLossConfig, global_config 5 | from losses.loss_utils import focal_loss_function, SmoothnessPenalty 6 | 7 | 8 | class LossFunction: 9 | def __init__(self): 10 | self.loss_summary_length = 0 # the length of the output of the loss function 11 | self.prioritized_replay = False 12 | 13 | def __call__(self, model_predictions: tuple, targets: tuple, weights, *superpixels): 14 | """ 15 | :param model_predictions: The output of the model's forward 16 | :param targets: The targets supplied by the memory's iterator 17 | :param weights: The weights for weighting each datapoint's gradient (for bias reduction in prioritized replay) 18 | :param superpixels: Optionally, the list of superpixels for the images corresponding to the model_predictions. 
19 | :return: The loss, that is one or several scalars 20 | """ 21 | raise NotImplementedError 22 | 23 | def process_feedback(self, actions: list, feedback: list, *superpixels: list): 24 | """ 25 | :param actions: for each data point, a list of poking locations used in its collection 26 | :param feedback: for each data point, the feedback received for each of the poking locations 27 | :param superpixels: Optionally, the superpixels corresponding to each data point 28 | :return: The targets for each data point, in a format ready to be added to the memory 29 | 30 | This function post-processes the feedback obtained from the Actors, in a format ready to be added to the memory. 31 | """ 32 | raise NotImplementedError 33 | 34 | def compute_priorities(self, losses: list): 35 | """ 36 | :param losses: list of losses for each data point 37 | :return: list of priorities for each data point, to be used by the replay memory for optional prioritized replay 38 | """ 39 | raise NotImplementedError 40 | 41 | 42 | class ObjectnessLoss(LossFunction): 43 | """ 44 | A superclass that implements certain utilities used by models that learn an objectness score for greedy poking. 45 | It is not a stand alone loss, and does not implement the __call__ method. It does implement the process_feedback 46 | method. 47 | 48 | The loss encourages the model to achieve the following 49 | - Pixel objectness logits are large for pixels that are likely to move when poked. 50 | - Optionally, features are encouraged to be constant along a superpixel. 51 | - An optional filter logit is large when an image is likely to contain easy to poke objects. 52 | 53 | If the filter logits are used, images whose filter logit is too negative will not be used to learn 54 | objectness or features used in instance mask prediction. 
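A concrete subclass only has to implement __call__ on top of the helpers defined here, roughly along the
lines of the following sketch (simplified and hypothetical; see ObjectnessRPNLoss in rpn_losses.py for a
full implementation):

    class MyObjectnessLoss(ObjectnessLoss):
        def __call__(self, model_predictions, targets, weights, *superpixels):
            objectness_logits = model_predictions[0]      # BS x (1 or 2) x grid_size x grid_size
            _, foregrounds, backgrounds = targets
            losses = [self.compute_objectness_loss(o, f, b, w)
                      for o, f, b, w in zip(objectness_logits, foregrounds, backgrounds, weights)]
            return self.stack_losses([losses], objectness_logits.device)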
55 | """ 56 | 57 | def __init__(self, loss_config: ObjectnessLossConfig): 58 | super(ObjectnessLoss, self).__init__() 59 | self.prioritized_replay = loss_config.prioritized_replay 60 | self.config = loss_config 61 | self.focal_loss = focal_loss_function(loss_config.objectness_weight) 62 | if global_config.superpixels and loss_config.smoothness_weight > 0: 63 | self.smoothness_loss = SmoothnessPenalty() 64 | 65 | def __call__(self, model_predictions: tuple, targets: tuple, weights: list, *superpixels): 66 | raise NotImplementedError 67 | 68 | def process_feedback(self, actions: list, feedback: list, superpixels=None): 69 | targets = [] 70 | num_successes = 0 71 | if not superpixels: 72 | superpixels = [None] * len(actions) 73 | for act, fb, superpixel in zip(actions, feedback, superpixels): 74 | target, new_successes = self.process_single_feedback(act, fb, superpixel) 75 | targets.append(target) 76 | num_successes += new_successes 77 | return targets, num_successes 78 | 79 | def process_single_feedback(self, actions, feedbacks, superpixel): 80 | object_masks = np.zeros((global_config.max_pokes, 81 | global_config.grid_size, global_config.grid_size), dtype=bool) 82 | foreground_mask = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 83 | background_mask = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 84 | successes = 0 85 | 86 | for i, (action, mask) in enumerate(zip(actions, feedbacks)): 87 | weights = self.get_neighbourhood(action['point'], superpixel) 88 | score = self.get_score(mask, action['point']) 89 | if score > self.config.foreground_threshold: 90 | object_masks[i, mask] = True 91 | foreground_mask += weights 92 | successes += 1 93 | else: 94 | background_mask += weights 95 | return (object_masks, foreground_mask, background_mask), successes 96 | 97 | def get_score(self, mask, action): 98 | if not self.config.localize_object_around_poking_point: 99 | return mask.sum() 100 | x, y = action 101 | dx1 = min(x, self.config.kernel_size) 102 | dx2 = min(global_config.grid_size - 1 - x, self.config.kernel_size) + 1 103 | dy1 = min(y, self.config.kernel_size) 104 | dy2 = min(global_config.grid_size - 1 - y, self.config.kernel_size) + 1 105 | x1, x2, y1, y2 = x - dx1, x + dx2, y - dy1, y + dy2 106 | return (mask[x1:x2, y1:y2] * self.config.check_change_kernel[self.config.kernel_size - dx1: 107 | self.config.kernel_size + dx2, 108 | self.config.kernel_size - dy1: 109 | self.config.kernel_size + dy2]).sum() 110 | 111 | def get_neighbourhood(self, action, superpixel): 112 | x, y = action 113 | if self.config.point_feedback_for_action: 114 | weights = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 115 | weights[x, y] = 1 116 | elif superpixel is not None and self.config.superpixel_for_action_feedback: 117 | weights = (superpixel == superpixel[x, y]).astype(np.float32) 118 | else: 119 | weights = np.zeros((global_config.grid_size, global_config.grid_size), dtype=np.float32) 120 | dx1 = min(x, self.config.kernel_size) 121 | dx2 = min(global_config.grid_size - 1 - x, self.config.kernel_size) + 1 122 | dy1 = min(y, self.config.kernel_size) 123 | dy2 = min(global_config.grid_size - 1 - y, self.config.kernel_size) + 1 124 | x1, x2, y1, y2 = x - dx1, x + dx2, y - dy1, y + dy2 125 | weights[x1:x2, y1:y2] = self.config.kernel[self.config.kernel_size - dx1: 126 | self.config.kernel_size + dx2, 127 | self.config.kernel_size - dy1: 128 | self.config.kernel_size + dy2] 129 | return weights 130 | 131 | def 
compute_filter_and_smoothness_loss(self, filter_logits, features_for_smoothing: tuple, 132 | object_masks, weights, superpixels): 133 | filter_and_smoothness_losses = [] 134 | if self.config.filter: 135 | filter_logit_positives = (object_masks.sum(dim=(1, 2, 3)) > 0).float() 136 | filter_logit_negatives = 1 - filter_logit_positives 137 | filter_logits_losses = torch.stack([self.focal_loss(fl, flp, fln, weight) for weight, fl, flp, fln 138 | in zip(weights, filter_logits, filter_logit_positives, 139 | filter_logit_negatives)]) 140 | filter_and_smoothness_losses.append(filter_logits_losses.sum()) 141 | 142 | if superpixels is not None and self.config.smoothness_weight > 0: 143 | smoothness_losses = [] 144 | features_for_smoothing = list(zip(*features_for_smoothing)) 145 | 146 | for weight, features, superpixel in zip(weights, features_for_smoothing, superpixels): 147 | smoothness_loss = sum(self.smoothness_loss(feature, superpixel) for feature in features) 148 | smoothness_losses.append(smoothness_loss * self.config.smoothness_weight * weight) 149 | filter_and_smoothness_losses.append(torch.stack(smoothness_losses).sum()) 150 | 151 | return filter_and_smoothness_losses 152 | 153 | def compute_objectness_loss(self, objectness, foreground, background, weight): 154 | b = objectness.shape[0] > 1 155 | objectness, uncertainty = objectness[0], objectness[1] if b else None 156 | objectness_loss = self.focal_loss(objectness, foreground, background, weight) 157 | if b: 158 | uncertainty_foreground = foreground * (objectness < 0) + background * (objectness > 0) 159 | uncertainty_background = foreground * (objectness >= 0) + background * (objectness <= 0) 160 | uncertainty_loss = self.focal_loss(uncertainty, uncertainty_foreground, uncertainty_background, weight) 161 | return objectness_loss + uncertainty_loss 162 | return objectness_loss 163 | 164 | @staticmethod 165 | def stack_losses(losses, device): 166 | losses = [torch.stack(l) if len(l) > 0 else torch.tensor(0.).to(device) for l in losses] 167 | losses = [l.sum() / (l > 0).sum().clamp(min=1) for l in losses] 168 | return losses 169 | 170 | def compute_priorities(self, losses: list): 171 | priorities = [] 172 | for loss in zip(*losses): 173 | score = min(iou.item() if iou > 0 else self.config.prioritize_default for iou in loss) 174 | priorities.append(self.config.prioritize_function(score)) 175 | return priorities 176 | -------------------------------------------------------------------------------- /source/losses/rpn_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import pad 3 | 4 | from config import ObjectnessRPNLossConfig 5 | from losses.losses import ObjectnessLoss 6 | from config import global_config 7 | from models.rpn_models import RoIModule 8 | 9 | 10 | class ObjectnessRPNLoss(ObjectnessLoss): 11 | """ 12 | This loss assumes that the model predicts: anchor box objectness logits, mask logits for a choice of anchor boxes 13 | (anchor box selection is greedy / not differentiable), pixel objectness logits, (optional filter logit) 14 | 15 | The losses for objectness and filter logit are described in the parent class. 16 | - The loss for anchor box objectness is focal loss, with positive/negative targets if IoU of anchor box with 17 | poking mask is sufficiently large/small. I.e. standard loss for RPN 18 | - The loss for mask logits is focal loss, with targets the poking mask that has highest IoU with anchor box 19 | corresponding to the mask logits. 
This is almost (but not quite) like in Mask R-CNN. 20 | """ 21 | 22 | def __init__(self, loss_config: ObjectnessRPNLossConfig): 23 | super(ObjectnessRPNLoss, self).__init__(loss_config) 24 | self.loss_summary_length = 3 + loss_config.regression + global_config.superpixels + loss_config.filter 25 | self.roi_module = RoIModule(loss_config.roi_config) 26 | self.anchor_regression_loss = anchor_regression_loss_function(loss_config.regression_weight) 27 | 28 | def __call__(self, model_predictions: tuple, targets: tuple, weights, superpixels=None): 29 | poking_scores, masks, anchor_box_scores, anchor_box_regressions, _, selected_anchors = model_predictions 30 | device = masks.device 31 | object_masks, foreground_masks, background_masks = targets 32 | 33 | filter_and_smoothness_losses = \ 34 | self.compute_filter_and_smoothness_loss(None, (poking_scores,), object_masks, weights, superpixels) 35 | 36 | anchor_score_losses, mask_losses, objectness_losses = [], [], [] 37 | anchor_regression_losses = [] if self.config.regression else None 38 | 39 | for roi_masks, scores, regressions, anchors, objectness, masks, foreground, background, weight in zip( 40 | masks, anchor_box_scores, anchor_box_regressions, selected_anchors, poking_scores, 41 | object_masks, foreground_masks, background_masks, weights): 42 | 43 | poking_locations = foreground + background 44 | with torch.no_grad(): 45 | masks_cum = pad(masks.cumsum(-1).cumsum(-2), [1, 0, 1, 0]) 46 | poking_cum = pad(poking_locations.cumsum(-1).cumsum(-2), [1, 0, 1, 0]) 47 | 48 | # RPN stage loss 49 | regressed_anchors = self.make_anchors_for_loss(regressions) 50 | anchor_score_targets, regression_matches = self.roi_module.match_anchors(regressed_anchors, regressions, 51 | masks_cum, poking_cum) 52 | positives, negatives = anchor_score_targets 53 | 54 | anchor_score_loss = self.focal_loss(scores, positives.float(), negatives.float(), weight) 55 | anchor_score_losses.append(anchor_score_loss) 56 | 57 | if self.config.regression: 58 | anchor_regression_loss = self.compute_anchor_regression_loss(regression_matches, weight, device) 59 | anchor_regression_losses.append(anchor_regression_loss) 60 | 61 | # Mask stage loss 62 | mask_loss = self.compute_mask_loss(roi_masks, anchors, masks, masks_cum, weight) 63 | mask_losses.append(mask_loss) 64 | 65 | # Poking loss 66 | objectness_loss = self.compute_objectness_loss(objectness, foreground, background, weight) 67 | objectness_losses.append(objectness_loss) 68 | 69 | losses = [anchor_score_losses, mask_losses, objectness_losses] 70 | if anchor_regression_losses is not None: 71 | losses.append(anchor_regression_losses) 72 | 73 | losses = self.stack_losses(losses, device) 74 | 75 | losses += filter_and_smoothness_losses 76 | 77 | if self.prioritized_replay: 78 | priorities = self.compute_priorities([mask_losses]) 79 | return losses, priorities 80 | 81 | return losses 82 | 83 | def compute_anchor_regression_loss(self, matches, weight, device): 84 | if matches is None: 85 | anchor_regression_loss = torch.tensor(0, dtype=torch.float32).to(device) 86 | else: 87 | anchor_regression_loss = self.anchor_regression_loss(*(matches + (weight,))) / matches[0].shape[0] 88 | 89 | return anchor_regression_loss 90 | 91 | def compute_mask_loss(self, roi_masks, selected_anchors, masks, masks_cum, weight): 92 | ious, _ = self.roi_module.compute_ious(masks_cum.unsqueeze(0), anchor_boxes=selected_anchors.unsqueeze(0)) 93 | matched_masks = self.roi_module.match_masks(ious.squeeze(0), masks) 94 | if self.config.robustify is not None: 95 | 
matched_masks = self.robustify_targets(roi_masks, matched_masks) 96 | positives = matched_masks.float() 97 | negatives = 1 - positives 98 | return self.focal_loss(roi_masks, positives, negatives, weight) 99 | 100 | def robustify_targets(self, roi_masks, matched_masks, step=0.1, max_iter=25): 101 | # This is similar to using the "robust set loss" proposed in 102 | # Pathak et al. "Learning instance segmentation by interaction" 103 | iou_threshold = self.config.robustify 104 | with torch.no_grad(): 105 | log_probs = roi_masks.clone() 106 | indexer = torch.zeros_like(matched_masks, dtype=torch.bool) 107 | 108 | for i in range(max_iter): 109 | iou_orig = self.iou(log_probs > 0, matched_masks) 110 | unconverged = ~(iou_orig > iou_threshold) 111 | if not torch.any(unconverged): 112 | break 113 | 114 | indexer *= False 115 | indexer[unconverged] = matched_masks[unconverged] 116 | log_probs[indexer] += step 117 | iou_up_in = self.iou(log_probs > 0, matched_masks) 118 | 119 | log_probs[indexer] -= step 120 | indexer *= False 121 | indexer[unconverged] = ~ matched_masks[unconverged] 122 | log_probs[indexer] -= step 123 | iou_down_out = self.iou(log_probs > 0, matched_masks) 124 | 125 | indexer *= False 126 | indexer[unconverged] = matched_masks[unconverged] 127 | log_probs[indexer] += step 128 | 129 | improved_in = (iou_up_in > iou_orig) * unconverged 130 | indexer *= False 131 | indexer[improved_in] = ~ matched_masks[improved_in] 132 | log_probs[indexer] += step 133 | 134 | improved_out = (~ improved_in) * (iou_down_out > iou_orig) * unconverged 135 | indexer *= False 136 | indexer[improved_out] = matched_masks[improved_out] 137 | log_probs[indexer] -= step 138 | 139 | new_masks = log_probs > 0 140 | 141 | return new_masks 142 | 143 | @staticmethod 144 | def iou(mask1, mask2): 145 | intersection = (mask1 * mask2).sum(dim=(-2, -1)).float() 146 | union = (mask1 | mask2).sum(dim=(-2, -1)).float().clamp(min=1) 147 | return intersection / union 148 | 149 | def make_anchors_for_loss(self, regressions): 150 | anchors = [] 151 | for delta in self.config.deltas: 152 | anchors.append(self.roi_module.make_regressed_anchors(regressions, delta)) 153 | return torch.stack(anchors) 154 | 155 | def cuda(self, k: int): 156 | self.roi_module.cuda(k) 157 | return self 158 | 159 | 160 | def anchor_regression_loss_function(hyperparameter): 161 | class AnchorRegressionLoss(torch.autograd.Function): 162 | @staticmethod 163 | def forward(ctx, logits, boxes, masks, weight): 164 | space = tuple(masks.shape[-2:]) 165 | masks = masks.unsqueeze(0).expand(9, -1, -1, -1) 166 | mask_areas = masks[..., -1, -1] 167 | box_areas = (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1]) 168 | anch0 = boxes[..., 0].reshape(-1) 169 | anch1 = boxes[..., 1].reshape(-1) 170 | anch2 = boxes[..., 2].reshape(-1) 171 | anch3 = boxes[..., 3].reshape(-1) 172 | flat_masks = masks.reshape(*((-1,) + space)) 173 | m = torch.arange(anch0.shape[0]) 174 | int0 = flat_masks[m, anch0, anch1] 175 | int1 = flat_masks[m, anch2, anch3] 176 | int2 = flat_masks[m, anch0, anch3] 177 | int3 = flat_masks[m, anch2, anch1] 178 | intersections = (int0 + int1 - int2 - int3).view(9, -1).float() 179 | ious = intersections / (mask_areas + box_areas - intersections).clamp(min=1) 180 | mask = ious[1:] < ious[0].unsqueeze(0) 181 | ctx.save_for_backward(logits, mask, weight) 182 | return mask.float().mean() 183 | 184 | @staticmethod 185 | def backward(ctx, dummy_grad): 186 | logits, mask, weight = ctx.saved_tensors 187 | mask = mask.view(4, 2, -1) 188 | sign = 
(mask[:, 0, :].float() - mask[:, 1, :].float()) 189 | return - sign.t() * (-logits ** 2).exp() * hyperparameter * weight, None, None, None 190 | 191 | return AnchorRegressionLoss.apply 192 | -------------------------------------------------------------------------------- /source/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/learning_from_interaction/a266bc16d682832aa854348fa557a30d86b84674/source/models/__init__.py -------------------------------------------------------------------------------- /source/models/backbones.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | # from detectron2 import model_zoo 4 | # from detectron2.config import get_cfg 5 | # from detectron2.modeling import build_backbone 6 | 7 | from config import BackboneConfig, global_config 8 | from models.model_utils import ResBlock, UpConv 9 | 10 | 11 | class UNetBackbone(nn.Module): 12 | def __init__(self, model_config: BackboneConfig): 13 | super(UNetBackbone, self).__init__() 14 | self.config = model_config 15 | 16 | self.conv1 = nn.Sequential(nn.Conv2d(3 + global_config.depth, 32, 17 | kernel_size=5, stride=3, padding=2, bias=False), 18 | nn.BatchNorm2d(32), 19 | nn.ReLU()) 20 | if model_config.small: 21 | self.block1 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, padding=2, stride=2, bias=False), 22 | nn.BatchNorm2d(64), nn.ReLU()) 23 | self.block2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=5, padding=2, stride=2, bias=False), 24 | nn.BatchNorm2d(128), nn.ReLU()) 25 | self.block3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=5, padding=2, stride=2, bias=False), 26 | nn.BatchNorm2d(256), nn.ReLU()) 27 | else: 28 | self.block1 = ResBlock(32, first_kernel_size=3) 29 | self.block2 = ResBlock(64, first_kernel_size=3) 30 | self.block3 = ResBlock(128, first_kernel_size=3) 31 | 32 | self.upconv1 = UpConv(256, 128, 128) 33 | self.upconv2 = UpConv(128, 64, 64) 34 | self.upconv3 = UpConv(64, 32, 64) 35 | 36 | def forward(self, x): 37 | """ 38 | :param x: Shape BS x 3(+1) x resolution x resolution 39 | :return: (Shape BS x D x grid_size x grid_size, BS x grid_size x grid_size), BS x D' (optional) 40 | """ 41 | x = self.conv1(x) 42 | x1 = self.block1(x) 43 | x2 = self.block2(x1) 44 | x3 = self.block3(x2) 45 | y3 = self.upconv1(x3, x2) 46 | y2 = self.upconv2(y3, x1) 47 | y1 = self.upconv3(y2, x) 48 | y = torch.nn.functional.interpolate(y1, size=(global_config.grid_size, global_config.grid_size), 49 | mode='bilinear') 50 | 51 | return y, x3.mean(dim=(2, 3)).unsqueeze(2).unsqueeze(3) 52 | 53 | 54 | # class R50FPNBackbone(nn.Module): 55 | # def __init__(self): 56 | # super(R50FPNBackbone, self).__init__() 57 | # self.backbone = build_backbone(make_rpn50_fpn_config()) 58 | # 59 | # def forward(self, x): 60 | # x = torch.nn.functional.interpolate(x, size=(800, 800), mode='bilinear') 61 | # y = self.backbone(x) 62 | # y_filter = y['p6'].mean(dim=(2, 3)).unsqueeze(2).unsqueeze(3) 63 | # y = y['p3'] 64 | # if global_config.grid_size != 100: 65 | # y = torch.nn.functional.interpolate(y, size=(global_config.grid_size, 66 | # global_config.grid_size), mode='bilinear') 67 | # return y, y_filter 68 | # 69 | # 70 | # def make_rpn50_fpn_config(): 71 | # cfg = get_cfg() 72 | # cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml")) 73 | # cfg.MODEL.PIXEL_MEAN = cfg.MODEL.PIXEL_MEAN + [1.5] * global_config.depth 74 | # 
cfg.MODEL.PIXEL_STD = cfg.MODEL.PIXEL_STD + [1.] * global_config.depth 75 | # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0. 76 | # cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 77 | # cfg.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = True 78 | # return cfg 79 | -------------------------------------------------------------------------------- /source/models/clustering_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from random import sample 5 | 6 | from config import ClusteringModelConfig, global_config 7 | from models.backbones import UNetBackbone # , R50FPNBackbone 8 | from models.model import Model 9 | 10 | 11 | class ClusteringModel(Model): 12 | """ 13 | This model consists of a backbone that produces features at resolution grid_size x grid_size, and three heads. 14 | 15 | The first head predicts a D x grid_size x grid_size tensor of D dimensional grid cell feature vectors. 16 | Segmentation masks are computed by clustering these feature vectors around the feature vectors of seed grid cells. 17 | 18 | The second head predicts confidence scores for the seed grid cells. Seed grid cells are greedily picked according 19 | to their confidence scores as poking locations, and to determine masks. 20 | 21 | The third head predicts three mass logits. These are sampled at poking locations, and determine the force used 22 | for poking. 23 | 24 | Optionally, the network also outputs a single scalar for each image, predicting whether the image contains any 25 | pokeable objects. This could be used in combination with an appropriate loss function that only backpropagates 26 | gradients into the feature vectors / confidence scores for those images that actually contain objects. 
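Inference sketch (illustrative; a threshold of 0. corresponds to a confidence of 0.5, and the config is
assumed to be constructible with its defaults):

    model = ClusteringModel(ClusteringModelConfig())
    actions, masks, outputs, scores = model.compute_masks(images, threshold=0.)
    # masks[i] is a list of grid_size x grid_size boolean arrays for image i, actions[i][k]['point'] and
    # actions[i][k]['force'] give the seed cell and force class of proposal k, and scores[i][k] its
    # sigmoid confidence.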
27 | """ 28 | 29 | def __init__(self, model_config: ClusteringModelConfig): 30 | super(ClusteringModel, self).__init__() 31 | self.config = model_config 32 | if model_config.backbone == 'r50fpn': 33 | self.backbone = R50FPNBackbone() 34 | else: 35 | self.backbone = UNetBackbone(model_config) 36 | 37 | backbone_out_dim = 256 if model_config.backbone == 'r50fpn' else 64 38 | self.head = nn.Sequential(nn.Conv2d(backbone_out_dim + 2, 128, kernel_size=1), nn.ReLU(), 39 | nn.Conv2d(128, model_config.out_dim + 1 + self.config.uncertainty, 1)) 40 | 41 | self.mass_head = nn.Sequential(nn.Conv2d(backbone_out_dim, 128, kernel_size=1), nn.ReLU(), 42 | nn.Conv2d(128, 3, 1)) 43 | 44 | if model_config.filter: 45 | self.filter_head = nn.Sequential(nn.Conv2d(256, 1024, kernel_size=1, stride=1, padding=0, bias=False), 46 | nn.BatchNorm2d(1024), 47 | nn.ReLU(), 48 | nn.Conv2d(1024, 1, kernel_size=1, stride=1, padding=0)) 49 | 50 | self.poking_grid = [dict(point=(i, j), force=force) for i in range(global_config.grid_size) 51 | for j in range(global_config.grid_size) for force in [0, 1, 2]] 52 | coordinate_embedding_grid = np.array([[[x, y] for y in range(global_config.grid_size)] 53 | for x in range(global_config.grid_size)]).transpose((2, 0, 1)) 54 | self.register_buffer('coordinate_embedding_grid', torch.from_numpy(coordinate_embedding_grid).float()) 55 | 56 | if model_config.distance_function == 'L2': 57 | self.distance_function = lambda x, y: ((x - y) ** 2).sum(0) 58 | elif model_config.distance_function == 'Cosine': 59 | distance_function = lambda x, y: 1 - (x * y).sum(0) / (np.linalg.norm(x, axis=0) 60 | * np.linalg.norm(y, axis=0) + 1e-4) 61 | inv = lambda x: x 62 | self.distance_function = lambda x, y: inv(distance_function(x, y)) 63 | 64 | def forward(self, x, targets=None): 65 | """ 66 | :param x: Shape BS x 3(+1) x resolution x resolution 67 | :param targets: Not used. This model is a single shot detection model. 
68 | :return: (Shape BS x grid_size x grid_size, BS x 3 x grid_size x grid_size, BS x D x grid_size x grid_size), 69 | BS (optional) 70 | """ 71 | 72 | z, x3 = self.backbone(x) 73 | y = self.head(torch.cat([z, self.coordinate_embedding_grid.repeat(z.shape[0], 1, 1, 1)], dim=1)) 74 | 75 | obj = y[:, self.config.out_dim:] 76 | mass = self.mass_head(z) 77 | 78 | ret = (obj, mass, y[:, :self.config.out_dim]) 79 | 80 | if self.config.filter: 81 | ret = ret + (self.filter_head(x3).squeeze(),) 82 | 83 | return ret 84 | 85 | def compute_actions(self, images, pokes, episode, episodes): 86 | obj_pokes, random_pokes = pokes // 2, pokes - pokes // 2 # We chose half of actions random for exploration 87 | actions = [] 88 | 89 | with torch.no_grad(): 90 | out = self.forward(images) 91 | obj, m, emb = out[:3] 92 | embeddings = emb.cpu().numpy() 93 | objectnesss = obj[:, self.config.uncertainty].cpu().numpy() 94 | masss = m.argmax(dim=1).cpu().numpy() 95 | 96 | for embedding, objectness, mass in zip(embeddings, objectnesss, masss): 97 | action, i = [], 0 98 | while i < obj_pokes and objectness.max() > -10000: 99 | i += 1 100 | a, _ = self.action_and_mask(embedding, objectness, -10000) 101 | action.append(dict(point=a, force=mass[a[0], a[1]])) 102 | action += sample(self.poking_grid, random_pokes) 103 | actions.append(action) 104 | return actions, out 105 | 106 | def compute_masks(self, images, threshold): 107 | masks, scores, actions = [], [], [] 108 | self.eval() 109 | with torch.no_grad(): 110 | out = self.forward(images) 111 | obj, m, emb = out[:3] 112 | embeddings = emb.cpu().numpy() 113 | objectnesss = obj[:, 0].cpu().numpy() 114 | masss = m.argmax(dim=1).cpu().numpy() 115 | 116 | for embedding, objectness, mass in zip(embeddings, objectnesss, masss): 117 | action, mask, new_scores = [], [], [] 118 | score = objectness.max() 119 | i = 0 120 | while score > threshold and i < global_config.max_pokes: 121 | new_scores.append(float(1 / (1 + np.exp(-score)))) 122 | a, m = self.action_and_mask(embedding, objectness, threshold) 123 | mask.append(m) 124 | action.append(dict(point=a, force=mass[a[0], a[1]])) 125 | score = objectness.max() 126 | i += 1 127 | masks.append(mask) 128 | scores.append(new_scores) 129 | actions.append(action) 130 | return actions, masks, out, scores 131 | 132 | def action_and_mask(self, embedding, objectness, threshold): 133 | argmax = objectness.argmax() 134 | action = (argmax // global_config.grid_size, argmax % global_config.grid_size) 135 | mask = self.make_mask(embedding, objectness, action, threshold) 136 | return action, mask 137 | 138 | def make_mask(self, embedding, objectness, action, threshold): 139 | margin_threshold = self.config.margin_threshold[threshold > -10000] 140 | center = embedding[:, action[0], action[1]][:, None, None] 141 | distances = self.distance_function(embedding, center) 142 | mask = distances < self.config.threshold 143 | center = embedding[:, mask].mean(axis=1)[:, None, None] 144 | distances = self.distance_function(embedding, center) 145 | mask = distances < self.config.threshold 146 | if self.config.threshold != margin_threshold: 147 | mask2 = distances < margin_threshold 148 | else: 149 | mask2 = mask 150 | objectness[mask2] = threshold - 1 151 | if not self.config.overlapping_objects: 152 | embedding[:, mask2] = self.config.reset_value * np.ones_like(embedding[:, mask2]) 153 | return mask 154 | 155 | @staticmethod 156 | def upsample(mask): 157 | return mask.repeat(global_config.stride, axis=0).repeat(global_config.stride, axis=1) 158 | 159 | def 
load(self, path): 160 | state_dict = torch.load(path, map_location='cuda:%d' % global_config.model_gpu) 161 | print(self.load_state_dict(state_dict, strict=False)) 162 | if self.config.freeze: 163 | self.freeze_detection_net(False) 164 | 165 | def toggle_detection_net(self, freeze): 166 | for param in self.backbone.parameters(): 167 | param.requires_grad = freeze 168 | for param in self.head.parameters(): 169 | param.requires_grad = freeze 170 | 171 | def toggle_mass_head(self, freeze): 172 | for param in self.mass_head.parameters(): 173 | param.requires_grad = freeze 174 | 175 | 176 | class ClusteringModelPooled(Model): 177 | """ 178 | Similar to ClusteringModel, but mass predictions are instance-wise, computed on the mean pooled features of the 179 | instance mask. In some sense, this is a two stage detection approach. 180 | """ 181 | 182 | def __init__(self, model_config: ClusteringModelConfig): 183 | super(ClusteringModelPooled, self).__init__() 184 | self.config = model_config 185 | self.backbone = UNetBackbone(model_config) 186 | 187 | backbone_out_dim = 64 188 | self.head = nn.Sequential(nn.Conv2d(backbone_out_dim + 2, 128, kernel_size=1), nn.ReLU(), 189 | nn.Conv2d(128, model_config.out_dim + 1 + self.config.uncertainty, 1)) 190 | 191 | self.mass_head = nn.Sequential(nn.Linear(64, 256, bias=False), nn.ReLU(), 192 | nn.Linear(256, 3)) 193 | 194 | self.poking_grid = [((i, j), mass) for i in range(global_config.grid_size) 195 | for j in range(global_config.grid_size) for mass in [0, 1, 2]] 196 | coordinate_embedding_grid = np.array([[[x, y] for y in range(global_config.grid_size)] 197 | for x in range(global_config.grid_size)]).transpose((2, 0, 1)) 198 | self.register_buffer('coordinate_embedding_grid', torch.from_numpy(coordinate_embedding_grid).float()) 199 | 200 | if model_config.distance_function == 'L2': 201 | self.distance_function = lambda x, y: ((x - y) ** 2).sum(0) 202 | elif model_config.distance_function == 'Cosine': 203 | distance_function = lambda x, y: 1 - (x * y).sum(0) / (np.linalg.norm(x, axis=0) 204 | * np.linalg.norm(y, axis=0) + 1e-4) 205 | inv = lambda x: x 206 | self.distance_function = lambda x, y: inv(distance_function(x, y)) 207 | 208 | def forward(self, x, targets=None): 209 | z, x3 = self.backbone(x) 210 | y = self.head(torch.cat([z, self.coordinate_embedding_grid.repeat(z.shape[0], 1, 1, 1)], dim=1)) 211 | 212 | obj = y[:, self.config.out_dim:] 213 | 214 | mass_logits = self.compute_mass_logits(z, targets) 215 | 216 | return obj, y[:, :self.config.out_dim], mass_logits, z 217 | 218 | def compute_actions(self, images, pokes, episode, episodes): 219 | obj_pokes, random_pokes = pokes // 2, pokes - pokes // 2 220 | actions = [] 221 | 222 | with torch.no_grad(): 223 | out = self.forward(images) 224 | obj, emb, _, z = out 225 | embeddings = emb.cpu().numpy() 226 | objectnesss = obj[:, self.config.uncertainty].cpu().numpy() 227 | 228 | for embedding, objectness, features in zip(embeddings, objectnesss, z): 229 | action, i = [], 0 230 | while i < obj_pokes and objectness.max() > -10000: 231 | i += 1 232 | a, m = self.action_and_mask(embedding, objectness, -10000) 233 | force_logits = self.compute_single_action_mass_logits(m, features) 234 | action.append(dict(point=a, force=force_logits.argmax().item())) 235 | action += sample(self.poking_grid, random_pokes) 236 | actions.append(action) 237 | return actions, out 238 | 239 | def compute_masks(self, images, threshold): 240 | masks, scores, actions = [], [], [] 241 | self.eval() 242 | with torch.no_grad(): 243 | out = 
self.forward(images) 244 | obj, emb, _, z = out 245 | embeddings = emb.cpu().numpy() 246 | objectnesss = obj[:, self.config.uncertainty].cpu().numpy() 247 | 248 | for embedding, objectness, features in zip(embeddings, objectnesss, z): 249 | action, mask, new_scores = [], [], [] 250 | score = objectness.max() 251 | i = 0 252 | while score > threshold and i < global_config.max_pokes: 253 | new_scores.append(float(1 / (1 + np.exp(-score)))) 254 | a, m = self.action_and_mask(embedding, objectness, threshold) 255 | force_logits = self.compute_single_action_mass_logits(m, features) 256 | mask.append(m) 257 | action.append(dict(point=a, force=force_logits.argmax().item())) 258 | score = objectness.max() 259 | i += 1 260 | masks.append(mask) 261 | scores.append(new_scores) 262 | actions.append(action) 263 | return actions, masks, out, scores 264 | 265 | def action_and_mask(self, embedding, objectness, threshold): 266 | argmax = objectness.argmax() 267 | action = (argmax // global_config.grid_size, argmax % global_config.grid_size) 268 | mask = self.make_mask(embedding, objectness, action, threshold) 269 | return action, mask 270 | 271 | def make_mask(self, embedding, objectness, action, threshold): 272 | margin_threshold = self.config.margin_threshold[threshold > -10000] 273 | center = embedding[:, action[0], action[1]][:, None, None] 274 | distances = self.distance_function(embedding, center) 275 | mask = distances < self.config.threshold 276 | center = embedding[:, mask].mean(axis=1)[:, None, None] 277 | distances = self.distance_function(embedding, center) 278 | mask = distances < self.config.threshold 279 | if self.config.threshold != margin_threshold: 280 | mask2 = distances < margin_threshold 281 | else: 282 | mask2 = mask 283 | objectness[mask2] = threshold - 1 284 | if not self.config.overlapping_objects: 285 | embedding[:, mask2] = self.config.reset_value * np.ones_like(embedding[:, mask2]) 286 | return mask 287 | 288 | def compute_mass_logits(self, z, targets): 289 | targets = targets[0] if targets is not None else [None] * z.shape[0] 290 | mass_logits = [] 291 | for features, masks in zip(z, targets): 292 | logits = [] 293 | if masks is None: 294 | logits.append(torch.zeros(3, device=z.device)) 295 | else: 296 | for mask in masks: 297 | logits.append(self.compute_single_action_mass_logits(mask, features)) 298 | mass_logits.append(torch.stack(logits)) 299 | return torch.stack(mass_logits) 300 | 301 | def compute_single_action_mass_logits(self, mask, features): 302 | if mask.sum() > 0: 303 | pooled_features = features[:, mask].mean(dim=1).unsqueeze(0) 304 | logits = self.mass_head(pooled_features).squeeze(0) 305 | else: 306 | logits = torch.zeros(3, dtype=torch.float32).to(features.device) 307 | return logits 308 | 309 | @staticmethod 310 | def upsample(mask): 311 | return mask.repeat(global_config.stride, axis=0).repeat(global_config.stride, axis=1) 312 | 313 | def load(self, path): 314 | state_dict = torch.load(path, map_location='cuda:%d' % global_config.model_gpu) 315 | print(self.load_state_dict(state_dict, strict=False)) 316 | if self.config.freeze: 317 | self.toggle_detection_net(False) 318 | 319 | def toggle_detection_net(self, freeze): 320 | for param in self.backbone.parameters(): 321 | param.requires_grad = freeze 322 | for param in self.head.parameters(): 323 | param.requires_grad = freeze 324 | 325 | def toggle_mass_head(self, freeze): 326 | for param in self.mass_head.parameters(): 327 | param.requires_grad = freeze 328 | 
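# Usage note (an illustrative sketch, not part of the module): the pooled variant is called with targets
# during training so that mass logits are pooled over the supplied object masks ("teacher forcing" style,
# see models/model.py), while compute_masks derives its own masks at inference and pools over those:
#
#     model = ClusteringModelPooled(ClusteringModelConfig())           # assumes a default-constructible config
#     obj, emb, mass_logits, feats = model(images, targets)            # targets[0]: per-image object masks
#     actions, masks, outputs, scores = model.compute_masks(images, 0.)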
-------------------------------------------------------------------------------- /source/models/fgbg_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from skimage.segmentation import felzenszwalb 5 | from skimage.measure import label 6 | from torchvision import transforms 7 | 8 | from config import ModelConfigFgBg, global_config 9 | from models.backbones import UNetBackbone 10 | from models.model import Model 11 | 12 | 13 | class FgBgModel(Model): 14 | """ 15 | This is a foreground-background segmentation model rewired to propose instances by post-processing the foreground 16 | with superpixels and extracting connected components. 17 | 18 | It is not designed to be used as an active model (in the sense that its compute_action) is inefficient. 19 | """ 20 | def __init__(self, model_config: ModelConfigFgBg): 21 | super(FgBgModel, self).__init__() 22 | assert model_config.uncertainty is False, 'This is a passive model' 23 | self.config = model_config 24 | self.backbone = UNetBackbone(model_config) 25 | self.head = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1), nn.ReLU(), 26 | nn.Conv2d(128, 1 + self.config.uncertainty, 1)) 27 | self.poking_grid = [(i, j) for i in range(global_config.grid_size) for j in range(global_config.grid_size)] 28 | self.to_pil = transforms.ToPILImage() 29 | 30 | def forward(self, images: torch.tensor, targets=None): 31 | y, _ = self.backbone(images) 32 | y = self.head(y) 33 | return y, None 34 | 35 | def compute_actions(self, images: torch.tensor, num_pokes: int, episode: int, episodes: int): 36 | with torch.no_grad(): 37 | out, _ = self.forward(images) 38 | objectness = out[:, 0].sigmoid().cpu().numpy() 39 | 40 | x = [self.connected_components(mask, scores) 41 | for mask, scores in zip(objectness > .5, objectness)] 42 | 43 | actions = [[z[1] for z in y] for y in x] 44 | 45 | return 'this should never have run', actions, (out,) 46 | 47 | def compute_masks(self, images: torch.tensor, threshold: float): 48 | if type(images) == tuple and self.config.superpixel: 49 | images, superpixels = images[0], images[1] 50 | else: 51 | superpixels = self.compute_superpixels(images) 52 | 53 | with torch.no_grad(): 54 | out, _ = self.forward(images) 55 | objectness = out[:, 0].sigmoid().cpu().numpy() 56 | 57 | x = [self.connected_components(mask, scores) 58 | for mask, scores in zip(objectness > threshold, objectness)] 59 | 60 | pred_masks = [[z[0] for z in y] for y in x] 61 | actions = [[z[1] for z in y] for y in x] 62 | pred_scores = [[z[2] for z in y] for y in x] 63 | 64 | pred_masks = [[self.postprocess_mask_with_sp(mask, superpixels) for mask in masks] 65 | for masks, superpixels in zip(pred_masks, superpixels)] 66 | 67 | return actions, pred_masks, (out,), pred_scores 68 | 69 | def connected_components(self, mask, scores): 70 | fat_mask = self.fatten(mask) if self.config.fatten else mask 71 | labels = label(fat_mask) * mask 72 | labels[labels == 0] = -1 73 | 74 | mps = [] 75 | for i in np.unique(labels): 76 | if i == -1: 77 | continue 78 | m = labels == i 79 | if m.sum() > 5: 80 | s = float((m * scores).max()) 81 | p = (m * scores).argmax() 82 | p = dict(point=(p // global_config.grid_size, p % global_config.grid_size)) 83 | mps.append((m, p, s)) 84 | if len(mps) > 0: 85 | mps = sorted(mps, key=lambda x: x[-1], reverse=True) 86 | return mps 87 | 88 | def compute_superpixels(self, images): 89 | images = [np.array(self.to_pil(image)) for image in images[:, :3].cpu()] 90 
| return [felzenszwalb(image, scale=200, sigma=.5, min_size=200)[::global_config.stride, 91 | ::global_config.stride].astype(np.int32) 92 | for image in images] 93 | 94 | @staticmethod 95 | def postprocess_mask_with_sp(mask, superpixels): 96 | smoothed_mask = np.zeros_like(mask) 97 | superpixels = [superpixels == i for i in np.unique(superpixels)] 98 | for superpixel in superpixels: 99 | if mask[superpixel].sum() / superpixel.sum() > .25: 100 | smoothed_mask[superpixel] = True 101 | return smoothed_mask 102 | 103 | @staticmethod 104 | def fatten(mask): 105 | fat_mask = mask.copy() 106 | fat_mask[:-1] = fat_mask[:-1] | mask[1:] 107 | fat_mask[1:] = fat_mask[1:] | mask[:-1] 108 | fat_mask[:, :-1] = fat_mask[:, :-1] | mask[:, 1:] 109 | fat_mask[:, 1:] = fat_mask[:, 1:] | mask[:, :-1] 110 | return fat_mask 111 | -------------------------------------------------------------------------------- /source/models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Model(nn.Module): 6 | """ 7 | Abstract Model class for instance segmentation. 8 | """ 9 | 10 | def __init__(self): 11 | super(Model, self).__init__() 12 | 13 | def forward(self, images: torch.tensor, *targets): 14 | """ 15 | As usual, this function feeds directly into the loss function, from which it also receives its gradients. 16 | The targets are an optional input, used for example in multi stage detectors with "teacher forcing" style 17 | training. 18 | """ 19 | raise NotImplementedError 20 | 21 | def compute_actions(self, images: torch.tensor, num_pokes: int, episode: int, episodes: int): 22 | """ 23 | This is the function through which the model is used during data collection. It is run with the model in 24 | eval mode, and no gradients are computed. The function should return num_pokes poking locations and forces for 25 | each image. It should also return the output of the forward pass, from which statistics can be computed to 26 | monitor the training. 27 | """ 28 | raise NotImplementedError 29 | 30 | def compute_masks(self, images: torch.tensor, threshold: float): 31 | """ 32 | This is the function through which the model performs inference. It is run in eval mode and no gradients are 33 | computed. For each image, it should predict poking locations of all objects, segmentation masks and relative 34 | masses for these objects, and confidence scores for these proposals. It should also return the output of the 35 | forward pass. 36 | """ 37 | raise NotImplementedError 38 | 39 | 40 | -------------------------------------------------------------------------------- /source/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class UpConv(nn.Module): 6 | """ 7 | This module upconvolutes by a factor of 2, and also has a lateral connection typical of UNets. 
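Shape example (illustrative values; crop() centre-crops whichever of the two inputs is spatially larger
before concatenation):

    up = UpConv(din=256, d_horizontal=128, dout=128)
    x = torch.randn(1, 256, 13, 13)    # deep, low-resolution features
    y = torch.randn(1, 128, 26, 26)    # lateral features from the encoder
    out = up(x, y)                     # torch.Size([1, 128, 24, 24]); the 3x3 conv uses no padding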
8 | """ 9 | def __init__(self, din, d_horizontal, dout): 10 | super(UpConv, self).__init__() 11 | self.upconv = nn.ConvTranspose2d(din, dout, kernel_size=2, stride=2, padding=0) 12 | self.conv = nn.Sequential(nn.Conv2d(d_horizontal + dout, dout, kernel_size=3, stride=1, padding=0), 13 | nn.BatchNorm2d(dout), 14 | nn.ReLU()) 15 | 16 | def forward(self, x, y): 17 | x = self.upconv(x) 18 | x, y = self.crop(x, y) 19 | return self.conv(torch.cat([x, y], dim=1)) 20 | 21 | @staticmethod 22 | def crop(x, y): 23 | width_x, width_y = x.shape[-1], y.shape[-1] 24 | if width_x == width_y: return x, y 25 | if width_x > width_y: 26 | low = (width_x - width_y) // 2 27 | high = width_x - width_y - low 28 | return x[:, :, low:-high, low:-high], y 29 | low = (width_y - width_x) // 2 30 | high = width_y - width_x - low 31 | return x, y[:, :, low:-high, low:-high] 32 | 33 | 34 | class ResBlock(nn.Module): 35 | """ 36 | This downconvolutes by a factor of 2. It is more light-weight than the original ResBlock. 37 | """ 38 | def __init__(self, din, first_kernel_size=1, small=0): 39 | super(ResBlock, self).__init__() 40 | self.first_conv = nn.Sequential( 41 | nn.Conv2d(din, 2 * din, kernel_size=first_kernel_size, padding=(first_kernel_size - 1) // 2, stride=1, 42 | bias=False, groups=din // small if small > 0 else 1), 43 | nn.BatchNorm2d(2 * din), 44 | nn.ReLU(), 45 | nn.Conv2d(2 * din, din, kernel_size=1, padding=0, stride=1, 46 | bias=False, groups=1), 47 | nn.BatchNorm2d(din), 48 | nn.ReLU() 49 | ) 50 | self.second_conv = nn.Sequential( 51 | nn.Conv2d(din, 2 * din, kernel_size=3, padding=1, stride=2, 52 | bias=False, groups=din // small if small > 0 else 1), 53 | nn.BatchNorm2d(2 * din), 54 | nn.ReLU() 55 | ) 56 | 57 | def forward(self, x): 58 | y = self.first_conv(x) 59 | return self.second_conv(x + y) 60 | -------------------------------------------------------------------------------- /source/models/poke_rcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from random import sample 5 | from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN 6 | from detectron2.modeling.postprocessing import detector_postprocess 7 | from detectron2.structures import Instances, Boxes, BitMasks 8 | from detectron2.utils.events import EventStorage 9 | 10 | from losses.losses import ObjectnessLossConfig, ObjectnessLoss 11 | from losses.loss_utils import focal_loss_function 12 | from models.model import Model 13 | from models.backbones import make_rpn50_fpn_config 14 | from config import global_config 15 | 16 | 17 | class PokeRCNN(Model): 18 | """ 19 | This wraps a standard detectron2 MaskRCNN (including standard losses) for instance segmentation, but also predicts 20 | objectness logits like the clustering models, and can therefore be used for fully self-supervised training. 
21 | """ 22 | 23 | def __init__(self, uncertainty=False): 24 | super(PokeRCNN, self).__init__() 25 | self.mask_rcnn = MaskRCNNWithPokeHead(uncertainty) 26 | self.poking_grid = [(i, j) for i in range(global_config.grid_size) for j in range(global_config.grid_size)] 27 | self.register_buffer('background_mask', torch.zeros(1, 800, 800, dtype=torch.int)) 28 | self.register_buffer('background_box', torch.tensor([[1, 1, 799, 799]])) 29 | 30 | def forward(self, images: torch.tensor, targets=None): 31 | batched_inputs = self.rescale_and_zip(images, targets) 32 | if targets is None: 33 | return self.mask_rcnn.inference(batched_inputs) 34 | return self.mask_rcnn(batched_inputs) 35 | 36 | def rescale_and_zip(self, images, targets=None): 37 | with torch.no_grad(): 38 | if targets is None: 39 | targets = [None] * images.shape[0] 40 | else: 41 | targets = list(zip(*targets)) 42 | batched_output = [] 43 | for image, target in zip(images, targets): 44 | d = {"image": nn.functional.interpolate(image.unsqueeze(0), (800, 800), mode='bilinear').squeeze(0)} 45 | if target is not None: 46 | masks, foreground, background = target 47 | instances = self.scale_and_process_masks(masks) 48 | d["instances"] = instances 49 | d["poking_targets"] = torch.stack([foreground, background]) 50 | batched_output.append(d) 51 | return batched_output 52 | 53 | def scale_and_process_masks(self, masks): 54 | device = masks.device 55 | dummy_mask = torch.zeros_like(masks[0]).unsqueeze(0) 56 | non_emptys = masks.sum(dim=(1, 2)) > 0 57 | non_empty = non_emptys.sum().item() 58 | masks = torch.cat([masks[non_emptys], dummy_mask], dim=0) if non_empty else dummy_mask 59 | masks = (nn.functional.interpolate(masks.float().unsqueeze(1), 60 | size=(800, 800)) > .5).squeeze(1) 61 | 62 | if non_empty: 63 | box_coordinates = [torch.where(mask) for mask in masks[:-1]] 64 | box_coordinates = torch.tensor([[x[1].min(), x[0].min(), x[1].max(), x[0].max()] for x in box_coordinates]) 65 | box_coordinates = torch.cat([box_coordinates.to(device), self.background_box], dim=0) 66 | else: 67 | box_coordinates = self.background_box 68 | 69 | instances = Instances((800, 800)) 70 | instances.gt_boxes = Boxes(box_coordinates) 71 | instances.gt_masks = BitMasks(masks) 72 | classes = torch.zeros(non_empty + 1, dtype=torch.int64) 73 | classes[-1] = 1 74 | instances.gt_classes = classes 75 | return instances.to(device) 76 | 77 | @staticmethod 78 | def select_largest_on_mask(mask, scores): 79 | mask = mask.reshape(global_config.grid_size, global_config.stride, 80 | global_config.grid_size, global_config.stride).mean(axis=(1, 3)) > .5 81 | argmax = (mask * scores).argmax() if np.any(mask) else scores.argmax() 82 | return argmax // global_config.grid_size, argmax % global_config.grid_size 83 | 84 | def compute_actions(self, images: torch.tensor, num_pokes: int, episode: int, episodes: int): 85 | with torch.no_grad(): 86 | results, poking_scores = self.forward(images) 87 | poking_scores_numpy = poking_scores[:, 0].sigmoid().cpu().numpy() 88 | detections = [result['instances'].pred_masks.cpu().numpy() for result in results] 89 | 90 | actions = [] 91 | for masks, poking_score in zip(detections, poking_scores_numpy): 92 | action = [] 93 | for mask in masks[:num_pokes // 2]: 94 | point = self.select_largest_on_mask(mask, poking_score) 95 | action.append(dict(point=point)) 96 | action += sample(self.poking_grid, num_pokes - len(action)) 97 | actions.append(action) 98 | 99 | return actions, (poking_scores,) 100 | 101 | def compute_masks(self, images: torch.tensor, 
threshold: float): 102 | ret_masks, scores, actions = [], [], [] 103 | self.eval() 104 | with torch.no_grad(): 105 | results, poking_scores = self.forward(images) 106 | poking_scores_numpy = poking_scores[:, 0].sigmoid().cpu().numpy() 107 | detections = [(result['instances'].pred_masks.cpu().numpy(), result['instances'].scores.cpu().numpy(), 108 | result['instances'].pred_classes.cpu().numpy()) 109 | for result in results] 110 | 111 | for (masks, mask_scores, classes), poking_score in zip(detections, poking_scores_numpy): 112 | action, new_masks, new_scores = [], [], [] 113 | for mask, mask_score, cl, _ in zip(masks, mask_scores, classes, [None] * global_config.max_pokes): 114 | if cl > 0: 115 | continue 116 | if mask_score < threshold: 117 | break 118 | point = self.select_largest_on_mask(mask, poking_score) 119 | action.append(dict(point=point)) 120 | new_masks.append(mask) 121 | new_scores.append(mask_score) 122 | ret_masks.append(new_masks) 123 | scores.append(new_scores) 124 | actions.append(action) 125 | return actions, ret_masks, (poking_scores,), scores 126 | 127 | 128 | class MaskRCNNWithPokeHead(GeneralizedRCNN): 129 | def __init__(self, uncertainty=True): 130 | super(MaskRCNNWithPokeHead, self).__init__(make_rpn50_fpn_config()) 131 | self.poking_head = nn.Sequential(nn.Conv2d(256, 64, kernel_size=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(), 132 | nn.Conv2d(64, 2, kernel_size=1)) 133 | 134 | self.poking_loss = MaskPokingLoss(uncertainty) 135 | self.event_storage = EventStorage() 136 | 137 | def forward(self, batched_inputs): 138 | with self.event_storage: 139 | images = self.preprocess_image(batched_inputs) 140 | gt_instances = [x["instances"] for x in batched_inputs] 141 | features = self.backbone(images.tensor) 142 | proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 143 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances) 144 | poking_scores = self.poking_head(features['p3']) 145 | poking_targets = torch.stack([x['poking_targets'] for x in batched_inputs]) 146 | poking_losses = self.poking_loss(poking_scores, poking_targets) 147 | 148 | losses = list(detector_losses.values()) + list(proposal_losses.values()) + [poking_losses] 149 | return losses 150 | 151 | def inference(self, batched_inputs, *kwargs): 152 | images = self.preprocess_image(batched_inputs) 153 | features = self.backbone(images.tensor) 154 | proposals, _ = self.proposal_generator(images, features, None) 155 | results, _ = self.roi_heads(images, features, proposals, None) 156 | poking_scores = self.poking_head(features['p3']) 157 | return self.postprocess(results, batched_inputs), poking_scores 158 | 159 | @staticmethod 160 | def postprocess(instances, batched_inputs): 161 | processed_results = [] 162 | for results_per_image, input_per_image in zip(instances, batched_inputs): 163 | r = detector_postprocess(results_per_image.to('cpu'), 300, 300) 164 | processed_results.append({"instances": r}) 165 | return processed_results 166 | 167 | 168 | class DummyObjectnessLoss(ObjectnessLoss): 169 | def __init__(self, conf: ObjectnessLossConfig): 170 | super(DummyObjectnessLoss, self).__init__(conf) 171 | assert conf.prioritized_replay is False 172 | self.loss_summary_length = 6 173 | # NOTE: Since per-image mask losses are not accessible in MaskRCNN, prioritized replay is not supported. 
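# As a consequence, __call__ below simply passes through the list of losses already computed inside
# MaskRCNNWithPokeHead.forward, and otherwise falls back to a zero summary vector of length loss_summary_length.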
174 | 175 | def __call__(self, losses, targets, weights, superpixels=None): 176 | if type(losses) == list: 177 | return losses 178 | return torch.zeros(self.loss_summary_length) 179 | 180 | 181 | class MaskPokingLoss(nn.Module): 182 | def __init__(self, uncertainty): 183 | super(MaskPokingLoss, self).__init__() 184 | self.uncertainty = uncertainty 185 | self.loss = focal_loss_function(1) 186 | self.register_buffer('dummy_weight', torch.tensor(1, dtype=torch.float32)) 187 | 188 | def forward(self, poking_scores, poking_targets): 189 | foreground, background = poking_targets[:, 0], poking_targets[:, 1] 190 | objectness, uncertainty = poking_scores[:, 0], poking_scores[:, 1] 191 | objectness_loss = self.loss(objectness, foreground, background, self.dummy_weight) 192 | if self.uncertainty: 193 | unc_foreground = (objectness >= 0) * background + (objectness <= 0) * foreground 194 | unc_background = (objectness > 0) * foreground + (objectness < 0) * background 195 | else: 196 | unc_foreground = foreground 197 | unc_background = background 198 | uncertainty_loss = self.loss(uncertainty, unc_foreground, unc_background, self.dummy_weight) 199 | return objectness_loss + uncertainty_loss 200 | -------------------------------------------------------------------------------- /source/models/rpn_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from random import sample 4 | 5 | from config import ROIModuleConfig, global_config, RPNModelConfig 6 | from models.backbones import UNetBackbone 7 | from models.model import Model 8 | 9 | 10 | class RPNWithMask(Model): 11 | """ 12 | This is a DeepMask inspired region proposal network with instance mask and objectness score predictions. It can 13 | be trained self-supervised. Anchor box regression is optional. 
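    Output sketch of forward (illustrative; exact sizes depend on RPNModelConfig and global_config):

        poking_scores, masks, anchor_scores, anchor_reg, selected_scores, selected_anchors = model(images)
        # poking_scores:    (B, 1 + uncertainty, grid_size, grid_size) objectness (and optional uncertainty) logits
        # masks:            (B, num_rois, grid_size, grid_size) mask logits for the selected anchors
        # selected_anchors: (B, num_rois, 4) selected anchor boxes in grid coordinates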
14 | """ 15 | def __init__(self, model_config: RPNModelConfig): 16 | super(RPNWithMask, self).__init__() 17 | self.config = model_config 18 | 19 | self.backbone = UNetBackbone(model_config) 20 | 21 | self.anchor_box_scoring_head = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, stride=2), 22 | nn.BatchNorm2d(128), 23 | nn.ReLU(), 24 | nn.Conv2d(128, (1 + 4 * model_config.regression) 25 | * model_config.num_anchors, kernel_size=2)) 26 | 27 | self.masking_head = nn.Sequential(nn.Conv2d(64, 32, kernel_size=5, padding=2, bias=False), 28 | nn.BatchNorm2d(32), 29 | nn.ReLU(), 30 | nn.Conv2d(32, 32, kernel_size=5, padding=2, bias=False), 31 | nn.BatchNorm2d(32), 32 | nn.ReLU(), 33 | nn.Conv2d(32, 1, kernel_size=5, padding=2)) 34 | 35 | self.poking_head = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1), 36 | nn.BatchNorm2d(64), 37 | nn.ReLU(), 38 | nn.Conv2d(64, 1 + self.config.uncertainty, kernel_size=1)) 39 | 40 | self.roi_module = RoIModule(model_config.roi_config) 41 | 42 | self.poking_grid = [(i, j) for i in range(global_config.grid_size) for j in range(global_config.grid_size)] 43 | 44 | def forward(self, images: torch.tensor, targets=None): 45 | 46 | y, x3 = self.backbone(images) 47 | 48 | scores_and_regression = self.anchor_box_scoring_head(x3) 49 | shape = tuple(scores_and_regression.shape) 50 | shape = (shape[0], -1, 1 + 4 * self.config.regression) + shape[2:] 51 | scores_and_regression = scores_and_regression.view(*shape) 52 | shape = tuple(scores_and_regression.shape) 53 | anchor_box_scores = scores_and_regression[:, :, 0] 54 | if self.config.regression: 55 | anchor_box_regression = scores_and_regression[:, :, 1:].transpose(2, 3).transpose(3, 4).contiguous() 56 | else: 57 | anchor_box_regression = torch.zeros(shape[:2] + shape[3:] + (4,), dtype=torch.float32).to(y.device) 58 | poking_scores = self.poking_head(y) 59 | 60 | anchors = self.roi_module.make_and_regress_anchors(anchor_box_regression) 61 | 62 | selected_anchors, selected_scores = self.roi_module.select_anchors(anchor_box_scores, anchors) 63 | 64 | masked_features = self.mask_features(y, selected_anchors) 65 | 66 | masks = self.masking_head(masked_features).reshape(y.shape[0], -1, 67 | global_config.grid_size, global_config.grid_size) 68 | 69 | return poking_scores, masks, anchor_box_scores, anchor_box_regression, selected_scores, selected_anchors 70 | 71 | def compute_actions(self, images: torch.tensor, num_pokes: int, episode: int, episodes: int): 72 | obj_pokes, random_pokes = num_pokes // 2, num_pokes - num_pokes // 2 73 | actions = [] 74 | 75 | with torch.no_grad(): 76 | out = self.forward(images) 77 | poking_scores, _, _, _, _, selected_anchors = out 78 | poking_scores = poking_scores[:, self.config.uncertainty] 79 | for anchors, scores in zip(selected_anchors, poking_scores): 80 | action = [] 81 | scores = scores.sigmoid() 82 | for i in range(obj_pokes): 83 | argmax = (scores * self.make_anchor_mask(anchors[i])).argmax().item() 84 | action.append(dict(point=(argmax // global_config.grid_size, argmax % global_config.grid_size))) 85 | action += sample(self.poking_grid, random_pokes) 86 | actions.append(action) 87 | return actions, out 88 | 89 | def compute_masks(self, images: torch.tensor, threshold: float): 90 | pred_masks, pred_scores, actions = [], [], [] 91 | self.eval() 92 | with torch.no_grad(): 93 | out = self.forward(images) 94 | poking_scores, masks, _, _, selected_scores, selected_anchors = out 95 | poking_scores = poking_scores[:, 0] 96 | 97 | for mask, scores, anchors, poking_score in zip(masks, 
selected_scores, selected_anchors, poking_scores): 98 | new_masks, new_scores, new_actions = [], [], [] 99 | poking_score = poking_score.sigmoid() 100 | for m, s, a in zip(mask, scores, anchors): 101 | if s < threshold: 102 | break 103 | new_masks.append((m > 0).cpu().numpy()) 104 | new_scores.append(s.sigmoid().item()) 105 | argmax = (self.make_anchor_mask(a) * poking_score).argmax().item() 106 | new_actions.append(dict(point=(argmax // global_config.grid_size, 107 | argmax % global_config.grid_size))) 108 | pred_masks.append(new_masks) 109 | pred_scores.append(new_scores) 110 | actions.append(new_actions) 111 | return actions, pred_masks, out, pred_scores 112 | 113 | def mask_features(self, features, anchors): 114 | anchor_mask = self.make_anchor_mask(anchors) 115 | return (features.unsqueeze(1) * anchor_mask.unsqueeze(2)).view(-1, features.shape[1], 116 | global_config.grid_size, 117 | global_config.grid_size) 118 | 119 | @staticmethod 120 | def make_anchor_mask(anchors): 121 | shape = tuple(anchors.shape[:-1]) + (global_config.grid_size, global_config.grid_size) 122 | anchors = anchors.view(-1, 4) 123 | anchor_mask = torch.zeros((anchors.shape[0], global_config.grid_size, global_config.grid_size), 124 | dtype=torch.bool).to(anchors.device) 125 | for i in range(anchors.shape[0]): 126 | anchor_mask[i, anchors[i, 0]:anchors[i, 2], anchors[i, 1]:anchors[i, 3]] = True 127 | 128 | return anchor_mask.view(*shape) 129 | 130 | 131 | class RoIModule(nn.Module): 132 | def __init__(self, model_config: ROIModuleConfig): 133 | super(RoIModule, self).__init__() 134 | self.config = model_config 135 | self.clamp = lambda x: min(max(x, 0), global_config.grid_size - 1) 136 | self.coarse_grid_stride = global_config.grid_size // self.config.coarse_grid_size 137 | 138 | self._init_anchors() 139 | 140 | def _init_anchors(self): 141 | boxes = torch.from_numpy(self.config.boxes).int() 142 | self.register_buffer('anchor_offsets_and_sizes', boxes.unsqueeze(1).unsqueeze(2)) 143 | anchor_sizes = boxes[:, 2:].repeat(1, 2).contiguous().unsqueeze(1).unsqueeze(2).float() 144 | self.register_buffer('anchor_sizes', anchor_sizes) 145 | positive_thresholds = torch.tensor(self.config.positive_thresholds) 146 | negative_thresholds = torch.tensor(self.config.negative_thresholds) 147 | self.register_buffer('positive_thresholds', 148 | positive_thresholds.unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4)) 149 | self.register_buffer('negative_thresholds', 150 | negative_thresholds.unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4)) 151 | coarse_grid_coordinates = torch.tensor([[[x * self.coarse_grid_stride, y * self.coarse_grid_stride] 152 | for y in range(self.config.coarse_grid_size)] 153 | for x in range(self.config.coarse_grid_size)]).int() 154 | self.register_buffer('coarse_grid_coordinates', coarse_grid_coordinates.contiguous().unsqueeze(0)) 155 | anchors = self.make_regressed_anchors(torch.zeros(self.config.num_anchors, self.config.coarse_grid_size, 156 | self.config.coarse_grid_size, 4)) 157 | self.register_buffer('anchor_boxes', anchors) 158 | 159 | def make_regressed_anchors(self, regression_logits, delta=(0, 0, 0, 0)): 160 | with torch.no_grad(): 161 | regression = (regression_logits.sigmoid() * self.anchor_sizes / 3).int() + self.anchor_offsets_and_sizes 162 | delta = torch.tensor(list(delta)).unsqueeze(0).unsqueeze(1).unsqueeze(2).to(regression.device) 163 | anchors = regression + delta 164 | anchors[:, :, :, :2] = anchors[:, :, :, :2] + self.coarse_grid_coordinates 165 | anchors[:, :, :, 2:] = anchors[:, :, :, 
:2] + anchors[:, :, :, 2:] 166 | return anchors.clamp(min=0, max=global_config.grid_size) 167 | 168 | def select_anchors(self, anchor_box_scores: torch.tensor, anchors=None): 169 | stride1 = self.config.coarse_grid_size ** 2 170 | stride2 = self.config.coarse_grid_size 171 | if anchors is None: 172 | anchors = [None] * anchor_box_scores.shape[0] 173 | selected_anchors = [] 174 | selected_scores = [] 175 | for scores, anchs in zip(anchor_box_scores, anchors): 176 | scores_numpy = scores.detach().cpu().numpy() 177 | indices = [] 178 | for _ in range(self.config.num_rois): 179 | ind = scores_numpy.argmax() 180 | ind = (ind // stride1, (ind % stride1) // stride2, ind % stride2) 181 | if anchs is not None: 182 | with torch.no_grad(): 183 | mask = self.box_iou_mask(anchs[ind[0], ind[1], ind[2]], anchs) 184 | scores_numpy[mask] = scores_numpy[mask] - 5 185 | else: 186 | scores_numpy[ind[0], ind[1], ind[2]] = - 1000 187 | indices.append(ind) 188 | 189 | selected_anchors.append(torch.stack([anchs[i[0], i[1], i[2]] for i in reversed(indices)])) 190 | selected_scores.append(torch.stack([scores[i[0], i[1], i[2]] for i in reversed(indices)])) 191 | 192 | return torch.stack(selected_anchors), torch.stack(selected_scores) 193 | 194 | def box_iou_mask(self, box, boxes): 195 | with torch.no_grad(): 196 | area_box = (box[2] - box[0]) * (box[3] - box[1]) 197 | area_boxes = (boxes[:, :, :, 2] - boxes[:, :, :, 0]) * (boxes[:, :, :, 3] - boxes[:, :, :, 1]) 198 | min_xy = torch.max(box[:2], boxes[:, :, :, :2]) 199 | max_xy = torch.min(box[2:], boxes[:, :, :, 2:]) 200 | diff = (max_xy - min_xy).clamp(min=0) 201 | intersections = (diff[:, :, :, 0] * diff[:, :, :, 1]).float() 202 | ious = intersections / (area_box + area_boxes - intersections).clamp(min=1) 203 | mask = (ious > self.config.nms_threshold).cpu().numpy() 204 | return mask 205 | 206 | def refine_indices(self, masks, indices, poking_locations): 207 | ious, intersections = self.compute_anchor_targets(masks, poking_locations) 208 | with torch.no_grad(): 209 | positive_anchors = ((ious / self.positive_thresholds) > 1).sum(dim=1) > 0 210 | positive_anchors = positive_anchors * (intersections > self.config.poking_filter_threshold) 211 | positive_indices = torch.nonzero(positive_anchors) 212 | refined_indices = [index for index in indices if index in positive_indices][-self.config.num_rois:] 213 | if len(refined_indices) < self.config.num_rois: 214 | refined_indices = [index for index in indices if index not in 215 | positive_indices][-(self.config.num_rois - len(refined_indices)):] + \ 216 | refined_indices 217 | return refined_indices 218 | 219 | def compute_ious(self, masks, poking_locations=None, anchor_boxes=None): 220 | if anchor_boxes is None: 221 | anchor_boxes = self.anchor_boxes.unsqueeze(0) 222 | with torch.no_grad(): 223 | size = tuple(masks.shape[:2]) + tuple(anchor_boxes.shape[1:-1]) 224 | space = tuple(masks.shape[-2:]) 225 | for _ in range(len(anchor_boxes.shape[1:-1])): 226 | masks = masks.unsqueeze(2) 227 | masks = masks.expand(*(size + space)) 228 | anchor_boxes_unsqueeze = anchor_boxes.unsqueeze(1).expand(*(size + (4,))) 229 | mask_areas = masks[..., -1, -1] 230 | box_areas = (anchor_boxes_unsqueeze[..., 2] - anchor_boxes_unsqueeze[..., 0]) * \ 231 | (anchor_boxes_unsqueeze[..., 3] - anchor_boxes_unsqueeze[..., 1]) 232 | anch0 = anchor_boxes_unsqueeze[..., 0].reshape(-1) 233 | anch1 = anchor_boxes_unsqueeze[..., 1].reshape(-1) 234 | anch2 = anchor_boxes_unsqueeze[..., 2].reshape(-1) 235 | anch3 = anchor_boxes_unsqueeze[..., 3].reshape(-1) 
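# The masks indexed below are expected to be 2D cumulative sums (integral images) over the grid, so the
# number of mask pixels inside each anchor box can be recovered from its four corner values by
# inclusion-exclusion; this is also why mask_areas above reads the bottom-right entry of each mask.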
236 | flat_masks = masks.reshape(*((-1,) + space)) 237 | m = torch.arange(anch0.shape[0]) 238 | int0 = flat_masks[m, anch0, anch1] 239 | int1 = flat_masks[m, anch2, anch3] 240 | int2 = flat_masks[m, anch0, anch3] 241 | int3 = flat_masks[m, anch2, anch1] 242 | intersections = (int0 + int1 - int2 - int3).view(*size).float() 243 | ious = intersections / (mask_areas + box_areas - intersections).clamp(min=1) 244 | 245 | intersections = None 246 | 247 | if poking_locations is not None: 248 | size = tuple(anchor_boxes.shape[:-1]) 249 | poking_locations = poking_locations.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(*(size + space)) 250 | anch0 = anchor_boxes[..., 0].reshape(-1) 251 | anch1 = anchor_boxes[..., 1].reshape(-1) 252 | anch2 = anchor_boxes[..., 2].reshape(-1) 253 | anch3 = anchor_boxes[..., 3].reshape(-1) 254 | flat_poke = poking_locations.reshape(*((-1,) + space)) 255 | m = torch.arange(anch0.shape[0]) 256 | int0 = flat_poke[m, anch0, anch1] 257 | int1 = flat_poke[m, anch2, anch3] 258 | int2 = flat_poke[m, anch0, anch3] 259 | int3 = flat_poke[m, anch2, anch1] 260 | intersections = (int0 + int1 - int2 - int3).view(*size).float() 261 | 262 | return ious, intersections 263 | 264 | def make_and_regress_anchors(self, regressions): 265 | anchors = [] 266 | for regression_logits in regressions: 267 | anchors.append(self.make_regressed_anchors(regression_logits)) 268 | return torch.stack(anchors) 269 | 270 | def match_anchors(self, regressed_anchors, regressions, masks, poking_locations): 271 | 272 | ious, intersections = self.compute_ious(masks.unsqueeze(0), 273 | poking_locations.unsqueeze(0), 274 | regressed_anchors[0].unsqueeze(0)) 275 | ious, intersections = ious.squeeze(0), intersections.squeeze(0) 276 | 277 | with torch.no_grad(): 278 | positives = ((ious / self.positive_thresholds) > 1).sum(dim=1) > 0 279 | negatives = ((ious / self.negative_thresholds) > 1).sum(dim=1) == 0 280 | positives = positives * (intersections > self.config.poking_filter_threshold) 281 | negatives = negatives * (intersections > self.config.poking_filter_threshold) 282 | positives, negatives = positives.squeeze(0), negatives.squeeze(0) 283 | anchor_targets = (positives, negatives) 284 | 285 | matches = None 286 | 287 | if anchor_targets[0].sum() > 0: 288 | positive_regressions = regressions[positives] 289 | ious = ious[:, positives] 290 | matched_masks = self.match_masks(ious, masks) 291 | positive_boxes = regressed_anchors[:, positives] 292 | matches = (positive_regressions, positive_boxes, matched_masks) 293 | 294 | return anchor_targets, matches 295 | 296 | @staticmethod 297 | def match_masks(ious, masks): 298 | mask_shape = len(masks.shape[:-2]) 299 | resolution = masks.shape[-1] 300 | anchor_shape = tuple(ious.shape[mask_shape:]) 301 | ious = ious.view(*((-1,) + anchor_shape)) 302 | masks = masks.view(-1, resolution, resolution) 303 | matches = ious.argmax(0).view(-1) 304 | return masks[matches].view(*(anchor_shape + (resolution, resolution))) 305 | -------------------------------------------------------------------------------- /source/pipeline/evaluator.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime 3 | import json 4 | from collections import OrderedDict 5 | import os 6 | from multiprocessing import Process, Pipe 7 | 8 | from ai2thor.controller import Controller 9 | from pycocotools.cocoeval import COCOeval 10 | from pycocotools.coco import COCO 11 | import numpy as np 12 | import torch 13 | 14 | from pipeline.tester 
import Tester 15 | from config import global_config, actor_config 16 | from tools.logger import LOGGER, init_logging 17 | from tools.data_utils import ActiveDataset 18 | 19 | 20 | class Evaluator(Tester): 21 | @staticmethod 22 | def save_predictions_to_json(path: str, predictions: tuple, image_ids: list = None, 23 | interactable_classes: int or list = 0): 24 | data = [] 25 | images = [] 26 | k = 0 27 | save_time = time.time() 28 | last_eta = save_time 29 | nitems = len(predictions[0]) 30 | 31 | for i, preds in enumerate(zip(*predictions)): 32 | imid = i if image_ids is None else image_ids[i] 33 | d = dict(width=global_config.resolution, height=global_config.resolution, id=imid) 34 | images.append(d) 35 | for mask, score, action in zip(*preds): 36 | if isinstance(interactable_classes, list): 37 | current_category = interactable_classes[action["force"].item()] 38 | else: 39 | current_category = interactable_classes 40 | k += 1 41 | segmentation, bbox, ar = Evaluator.compute_annotations( 42 | Evaluator.upsample(mask), encoding='utf-8', transpose=False 43 | ) 44 | d = dict(image_id=imid, category_id=current_category, score=score, 45 | segmentation=segmentation, bbox=bbox) 46 | data.append(d) 47 | new_time = time.time() 48 | if new_time - last_eta >= 10: 49 | curtime = new_time - save_time 50 | eta = curtime / (i + 1) * (nitems - i - 1) 51 | LOGGER.info("save {}/{} spent {} s ETA {} s".format(i + 1, nitems, curtime, eta)) 52 | last_eta = new_time 53 | 54 | with open(path, 'w') as file: 55 | json.dump(dict(annotations=data, images=images), file) 56 | 57 | @staticmethod 58 | def run_load(dataset_path, memory, conn): 59 | init_logging() 60 | controller = Controller( 61 | x_display='0.%d' % global_config.actor_gpu, 62 | visibilityDistance=actor_config.visibilityDistance, 63 | renderDepthImage=global_config.depth 64 | ) 65 | dataset = ActiveDataset(dataset_path, memory, controller, conn) 66 | dataset.process() 67 | controller.stop() 68 | 69 | def push_pull_items(self, begin_next, loaders, ndata): 70 | begin_item, next_item = begin_next 71 | last_item = next_item - 1 72 | 73 | # Preload data 74 | for last_item in range(next_item, min(next_item + self.config.bs, ndata)): 75 | loaders[last_item % len(loaders)][1].send(last_item) 76 | 77 | # Read former data 78 | ims, names = [], [] 79 | if begin_item >= 0: 80 | for read_item in range(begin_item, min(begin_item + self.config.bs, ndata)): 81 | im, name = loaders[read_item % len(loaders)][1].recv() 82 | ims.append(im) 83 | names.append(name) 84 | 85 | # Update pointers 86 | begin_item, next_item = next_item, last_item + 1 87 | 88 | return (begin_item, next_item), ims, names 89 | 90 | def dataset_forward(self, dataset_path: str, interaction_thres: float): 91 | loaders = [] 92 | try: 93 | for c in range(self.config.num_actors): 94 | parent, child = Pipe() 95 | proc = Process(target=Evaluator.run_load, args=(dataset_path, self.memory, child)) 96 | proc.start() 97 | loaders.append((proc, parent)) 98 | dataset = ActiveDataset(dataset_path, self.memory, None, None) 99 | ndata = len(dataset) 100 | 101 | begin_next, _, _ = self.push_pull_items((-1, 0), loaders, ndata) 102 | 103 | image_ids = [] 104 | masks, scores, actions = [], [], [] 105 | eval_time = time.time() 106 | last_eta = eval_time 107 | while begin_next[0] < ndata: 108 | begin_next, ims, names = self.push_pull_items(begin_next, loaders, ndata) 109 | 110 | batch_torch = torch.stack(ims).cuda(self.model_device) 111 | image_ids.extend(names) 112 | 113 | new_actions, new_masks, _, new_scores = 
self.model.compute_masks(batch_torch, interaction_thres) 114 | 115 | masks += new_masks 116 | scores += new_scores 117 | actions += new_actions 118 | 119 | new_time = time.time() 120 | if new_time - last_eta >= 10: 121 | curtime = new_time - eval_time 122 | nepisodes = begin_next[0] 123 | eta = curtime / nepisodes * (ndata - nepisodes) 124 | LOGGER.info("forward {}/{} spent {} s ETA {} s".format(nepisodes, ndata, curtime, eta)) 125 | last_eta = new_time 126 | finally: 127 | for it, (loader, conn) in enumerate(loaders): 128 | conn.send(None) 129 | for it, (loader, conn) in enumerate(loaders): 130 | LOGGER.info("Joining loader {}".format(it)) 131 | loader.join() 132 | 133 | return (masks, scores, actions), image_ids 134 | 135 | @staticmethod 136 | def coco_eval(coco_gt: COCO, coco_dt: COCO, annotation_type: str, use_categories: bool = True): 137 | coco_eval = COCOeval(coco_gt, coco_dt, annotation_type) 138 | 139 | if not use_categories: 140 | coco_eval.params.useCats = 0 141 | 142 | coco_eval.evaluate() 143 | coco_eval.accumulate() 144 | coco_eval.summarize() 145 | 146 | metrics = ["AP", "AP50", "AP75", "APs", "APm", "APl", "AR1", "AR10", "AR100", "ARs", "ARm", "ARl"] 147 | results = OrderedDict( 148 | (metric, float(coco_eval.stats[idx] if coco_eval.stats[idx] >= 0 else "nan")) 149 | for idx, metric in enumerate(metrics) 150 | ) 151 | 152 | return results 153 | 154 | def inference(self, dataset_path: str, output_folder: str, interaction_thres: float, id: int or str = 0, 155 | interactable_classes: int or list = 0): 156 | date_time = datetime.now().strftime("%Y%m%d%H%M%S") 157 | data_basename = os.path.basename(dataset_path) 158 | output_path = os.path.join(output_folder, data_basename.replace( 159 | ".json", "__{}__thres{}__{}__inf.json".format(id, interaction_thres, date_time) 160 | )) 161 | LOGGER.info("Using inference results path {}".format(output_path)) 162 | model_predictions, image_ids = self.dataset_forward(dataset_path, interaction_thres) 163 | self.save_predictions_to_json(output_path, model_predictions, image_ids, interactable_classes) 164 | 165 | return output_path 166 | 167 | @staticmethod 168 | def load_coco_detections(coco_gt, coco_dets_path: str, annotation_type='segm'): 169 | with open(coco_dets_path, "r") as f: 170 | raw_inf = json.load(f) 171 | 172 | if not isinstance(raw_inf, list): 173 | assert raw_inf["images"][0]["id"] in coco_gt.imgs,\ 174 | "ensure {} was generated with Evaluator's inference".format(coco_dets_path) 175 | raw_inf = raw_inf["annotations"] 176 | for entry in raw_inf: 177 | entry.pop("id", None) 178 | entry.pop("area", None) 179 | 180 | if annotation_type == 'segm': # pop bounding boxes to avoid using bbox areas 181 | for entry in raw_inf: 182 | entry.pop("bbox", None) 183 | 184 | # LOGGER.info("{} annotations in {}".format(len(raw_inf), coco_dets_path)) 185 | return coco_gt.loadRes(raw_inf) 186 | 187 | @staticmethod 188 | def results_string(res, labs=None): 189 | if labs is None: 190 | labs = res.keys() 191 | return ", ".join(["%s %4.2f%%" % (m, res[m] * 100) for m in labs]) 192 | 193 | @staticmethod 194 | def evaluate(coco_gt_path: str, coco_dets_path: str, annotation_type='bbox', labs=("AP50", "AP"), 195 | use_categories=False): 196 | if annotation_type == 'mass': 197 | return Evaluator.evaluate_mass(coco_gt_path, coco_dets_path) 198 | 199 | coco_gt = COCO(coco_gt_path) 200 | coco_dt = Evaluator.load_coco_detections(coco_gt, coco_dets_path, annotation_type) 201 | 202 | res = Evaluator.coco_eval(coco_gt, coco_dt, annotation_type, 
use_categories=use_categories) 203 | res_str = Evaluator.results_string(res, labs) 204 | LOGGER.info("RESULTS {} {} {}".format( 205 | coco_dets_path, 206 | annotation_type, 207 | res_str 208 | )) 209 | return {annotation_type: res} 210 | 211 | @staticmethod 212 | def confusion_matrix(coco_gt, coco_dt, nclasses=3): 213 | eval = COCOeval(coco_gt, coco_dt, "bbox") 214 | eval.params.useCats = 0 215 | eval.params.iouThrs = [0.5] 216 | eval.params.areaRng = eval.params.areaRng[:1] 217 | eval.params.areaRngLbl = eval.params.areaRngLbl[:1] 218 | eval.evaluate() 219 | 220 | conf_mat = np.zeros((nclasses, nclasses), dtype=np.int64) 221 | for it, im in enumerate(eval.evalImgs): 222 | if im is not None: 223 | for gt, dt in zip(im['gtIds'], list(im['gtMatches'][0].astype(np.int32))): 224 | if dt > 0: 225 | row = eval.cocoDt.anns[dt]["category_id"] 226 | col = eval.cocoGt.anns[gt]["category_id"] 227 | conf_mat[row, col] += 1 228 | 229 | return conf_mat 230 | 231 | @staticmethod 232 | def evaluate_mass(coco_gt_path: str, coco_dets_path: str): 233 | annotation_type = 'bbox' 234 | coco_gt = COCO(coco_gt_path) 235 | coco_dt = Evaluator.load_coco_detections(coco_gt, coco_dets_path, annotation_type) 236 | 237 | res = Evaluator.coco_eval(coco_gt, coco_dt, annotation_type, use_categories=True) 238 | 239 | conf_mat = Evaluator.confusion_matrix(coco_gt, coco_dt) 240 | accuracies = conf_mat.diagonal() / conf_mat.sum(axis=0) 241 | mean_accuracy = accuracies.mean().item() 242 | 243 | res.update({"accuracy": mean_accuracy}) 244 | 245 | res_str = Evaluator.results_string(res, labs=("AP50", "accuracy")) 246 | LOGGER.info("RESULTS {} {} {}".format( 247 | coco_dets_path, 248 | "mass", 249 | res_str 250 | )) 251 | 252 | return {"mass": res} 253 | -------------------------------------------------------------------------------- /source/pipeline/tester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import numpy as np 5 | import torch 6 | from copy import deepcopy 7 | from PIL.ImageDraw import Draw 8 | from PIL.Image import fromarray 9 | from pycocotools.mask import encode, area, toBbox 10 | from pycocotools.cocoeval import COCOeval 11 | from pycocotools.coco import COCO 12 | 13 | from config import TestingConfig, global_config 14 | from models.model import Model 15 | from replay_memory.replay_memory import Memory 16 | from pipeline.actor import ActorPool, actor_config 17 | from losses.losses import LossFunction 18 | 19 | 20 | class Tester: 21 | """ 22 | The main testing / illustration pipeline. 
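    Typical use (illustrative only; the dataset size and thresholds below are placeholder values):

        tester = Tester(model, memory, loss_function, TestingConfig())
        tester.evaluate_coco_metrics(500, 0.5)                             # builds a 500-image eval set and scores it
        tester.illustrate_poking_and_predictions('./illustrations', 20, 0.5)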
23 | """ 24 | 25 | def __init__(self, model: Model, memory: Memory, loss_function: LossFunction, tester_config: TestingConfig): 26 | self.config = tester_config 27 | 28 | self.model_device = torch.cuda.current_device() if global_config.distributed else global_config.model_gpu 29 | 30 | self.model = model 31 | self.model.eval() 32 | 33 | self.loss_function = loss_function 34 | self.memory = memory 35 | 36 | self.actors = None 37 | 38 | def evaluate_coco_metrics(self, dataset: str or int, threshold: float, annotation_type='segm', iou=.5): # or 'bbox' 39 | if type(dataset) == int: 40 | dataset = self.make_eval_dataset(dataset) 41 | path = './tmpdataset%d' % np.random.randint(0, 10 ** 10, 1) 42 | self.save_pycoco_compatible_json(path + 'no_mass.json', dataset, use_mass=0) 43 | self.save_pycoco_compatible_json(path + 'mass.json', dataset, use_mass=1) 44 | gt = True 45 | else: 46 | path = dataset 47 | dataset = self.load_dataset(path) 48 | gt = False 49 | 50 | model_predictions = self.predict_masks(dataset[0], threshold) 51 | inference_path = path + 'inf' 52 | self.save_predictions_to_json(inference_path + 'no_mass.json', model_predictions, use_mass=0) 53 | self.save_predictions_to_json(inference_path + 'mass.json', model_predictions, use_mass=1) 54 | 55 | confusion_matrix, corcoef = self.get_confusion_matrix(model_predictions, dataset) 56 | print(corcoef) 57 | print(confusion_matrix) 58 | confusion_matrix = confusion_matrix / confusion_matrix.sum() 59 | print(confusion_matrix) 60 | marginal = confusion_matrix.sum(axis=0) 61 | accuracy = sum(confusion_matrix[i, i] for i in range(3)) 62 | class_wise_accuracies = [confusion_matrix[i, i] / confusion_matrix[:, i].sum() for i in range(3)] 63 | print(accuracy) 64 | print(marginal) 65 | print(class_wise_accuracies) 66 | print(sum(class_wise_accuracies) / 3) 67 | 68 | coco_gt_no_mass = COCO(path + 'no_mass.json') 69 | coco_gt_mass = COCO(path + 'mass.json') 70 | coco_inf_no_mass = COCO(inference_path + 'no_mass.json') 71 | coco_inf_mass = COCO(inference_path + 'mass.json') 72 | coco_eval_no_mass = COCOeval(coco_gt_no_mass, coco_inf_no_mass, annotation_type) 73 | coco_eval_mass = COCOeval(coco_gt_mass, coco_inf_mass, annotation_type) 74 | coco_eval_no_mass.params.iouThrs = np.array([iou] + list(coco_eval_no_mass.params.iouThrs)[1:]) 75 | coco_eval_mass.params.iouThrs = np.array([iou] + list(coco_eval_mass.params.iouThrs)[1:]) 76 | coco_eval_no_mass.evaluate() 77 | coco_eval_mass.evaluate() 78 | coco_eval_no_mass.accumulate() 79 | coco_eval_mass.accumulate() 80 | pr_curve_no_mass = coco_eval_no_mass.eval['precision'][0, :, 0, 0, -1] 81 | pr_curve_mass = coco_eval_mass.eval['precision'][0, :, 0, 0, -1] 82 | coco_eval_no_mass.summarize() 83 | coco_eval_mass.summarize() 84 | 85 | os.remove(inference_path + 'no_mass.json') 86 | os.remove(inference_path + 'mass.json') 87 | if gt: 88 | os.remove(path + 'no_mass.json') 89 | os.remove(path + 'mass.json') 90 | print(pr_curve_no_mass) 91 | print(pr_curve_mass) 92 | return pr_curve_no_mass 93 | 94 | def make_eval_dataset(self, dataset_size): 95 | self.actors = ActorPool(self.model, self.loss_function, self.memory, self.config.num_actors, 5) 96 | self.actors.start() 97 | images, metadata = self.actors.make_batch(dataset_size, None, None, None) 98 | self.actors.stop() 99 | masks, metadata = [m[0][0] for m in metadata], [(m[0][1:],) + m[1:] for m in metadata] 100 | return images, masks, metadata 101 | 102 | def predict_masks(self, images: list, threshold: float): 103 | batches = [images[i:i + self.config.bs] for i in 
range(0, len(images), self.config.bs)] 104 | masks, scores, masses = [], [], [] 105 | for batch in batches: 106 | batch_torch = torch.stack([self.memory.base_image_transform(image) 107 | for image in batch]).cuda(self.model_device) 108 | actions, new_masks, _, new_scores = self.model.compute_masks(batch_torch, threshold) 109 | masks += new_masks 110 | scores += new_scores 111 | masses += [[a['force'] for a in action] for action in actions] 112 | return masks, scores, masses 113 | 114 | def save_predictions_to_json(self, path: str, predictions: tuple, use_mass=0): 115 | data = [] 116 | images = [] 117 | predicted_masses = [0, 0, 0] 118 | k = 0 119 | for i, (masks, scores, masses) in enumerate(zip(*predictions)): 120 | d = dict(width=global_config.resolution, height=global_config.resolution, id=i) 121 | images.append(d) 122 | for mask, score, mass in zip(masks, scores, masses): 123 | assert mass in [0, 1, 2], 'masses wrong in predictions' 124 | predicted_masses[mass] += 1 125 | k += 1 126 | segmentation, bbox, ar = self.compute_annotations(self.upsample(mask)) 127 | d = dict(image_id=i, category_id=int(use_mass * mass + 1), score=score, 128 | segmentation=segmentation, bbox=bbox, area=ar, id=k) 129 | data.append(d) 130 | 131 | file = open(path, 'w') 132 | json.dump(dict(annotations=data, images=images), file) 133 | file.close() 134 | print(predicted_masses) 135 | 136 | def save_pycoco_compatible_json(self, path: str, dataset: tuple, use_mass=0): 137 | coco_dataset = dict() 138 | coco_dataset['info'] = {'dataset_size': len(dataset), 139 | 'depth': global_config.depth, 140 | 'mass_threshold': actor_config.mass_threshold, 141 | 'min_pixel_threshold': actor_config.min_pixel_threshold, 142 | 'max_pixel_threshold': actor_config.max_pixel_threshold} 143 | coco_dataset['licenses'] = {} 144 | if use_mass: 145 | coco_dataset['categories'] = [{'supercategory': 'object', 146 | 'id': 1, 147 | 'name': 'light'}, 148 | {'supercategory': 'object', 149 | 'id': 2, 150 | 'name': 'medium'}, 151 | {'supercategory': 'object', 152 | 'id': 3, 153 | 'name': 'heavy'} 154 | ] 155 | else: 156 | coco_dataset['categories'] = [{'supercategory': 'object', 157 | 'id': 1, 158 | 'name': 'object'}] 159 | coco_dataset['images'] = [] 160 | coco_dataset['annotations'] = [] 161 | k = 0 162 | instance_areas = [] 163 | for i, (_, masks, metadata) in enumerate(zip(*dataset)): 164 | coco_dataset['images'].append( 165 | {'width': global_config.resolution, 'height': global_config.resolution, 'id': i}) 166 | masses = metadata[0][1] 167 | for instance_mask, mass in zip(masks, masses): 168 | assert mass in [0, 1, 2], 'masses from metadata wrong' 169 | k += 1 170 | segmentation, bbox, ar = self.compute_annotations(instance_mask) 171 | instance_areas.append(ar) 172 | coco_dataset['annotations'].append( 173 | {'iscrowd': 0, 'image_id': i, 'id': k, 'category_id': use_mass * mass + 1, 174 | 'area': ar, 'segmentation': segmentation, 'bbox': bbox}) 175 | 176 | file = open(path, 'w') 177 | json.dump(coco_dataset, file) 178 | file.close() 179 | print('NUMBER OF INSTANCES IN DATASET: %d' % k) 180 | mean_area = sum(instance_areas) / max(len(instance_areas), 1) 181 | print('MEAN AREA OF INSTANCE IN DATASET: %f' % mean_area) 182 | 183 | def make_and_save_dataset(self, dataset_size, path): 184 | dataset = self.make_eval_dataset(dataset_size) 185 | self.save_data(dataset, path + '.pickle') 186 | self.save_pycoco_compatible_json(path + 'no_mass.json', dataset, use_mass=0) 187 | self.save_pycoco_compatible_json(path + 'mass.json', dataset, use_mass=1) 188 | 
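# Note: make_and_save_dataset above writes three files per call: '<path>.pickle' with the raw episode data,
# '<path>no_mass.json' (class-agnostic COCO ground truth) and '<path>mass.json' (light/medium/heavy categories).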
189 | @staticmethod 190 | def save_data(dataset, path): 191 | file = open(path, 'wb') 192 | pickle.dump(dataset, file) 193 | file.close() 194 | 195 | @staticmethod 196 | def load_dataset(path): 197 | return pickle.load(open(path + '.pickle', 'rb')) 198 | 199 | @staticmethod 200 | def compute_annotations(instance, encoding='ascii', transpose=True): 201 | if transpose: 202 | segmentation = encode(instance.astype(np.uint8).T) 203 | else: 204 | segmentation = encode(np.asarray(instance.astype(np.uint8), order="F")) 205 | segmentation['counts'] = segmentation['counts'].decode(encoding) 206 | bbox = list(toBbox(segmentation)) 207 | ar = int(area(segmentation)) 208 | return segmentation, bbox, ar 209 | 210 | @staticmethod 211 | def upsample(mask): 212 | return mask.repeat(global_config.stride, axis=0).repeat(global_config.stride, axis=1) 213 | 214 | def illustrate_poking_and_predictions(self, path: str, num_images: int, threshold: float): 215 | """ 216 | :param path: The path to the folder where illustrations are to be saved 217 | :param num_images: Number of images to illustrate 218 | :param threshold: The value of the hyperparameter that controls the number of object proposals of the model. 219 | :return: Nothing. It fills the folder with images 220 | 221 | Attention: The code clears the current contents of the folder! 222 | 223 | This only works for models that provide pixel-wise predictions. 224 | 225 | The illustrations include: 226 | a) Image with GT overlay 227 | b) Image with predicted masks and masses overlay. Masses are red=light, green=medium, blue=heavy 228 | c) interaction score heat map 229 | d) mass logits map (same color coding as above) 230 | e) self supervised ground truth 231 | f) raw image 232 | g) depth image 233 | 234 | """ 235 | if path[-1] != '/': 236 | path = path + '/' 237 | if os.path.isdir(path): 238 | for file in os.listdir(path): 239 | if file[-1] == 'g': 240 | os.remove(path + file) 241 | else: 242 | os.mkdir(path) 243 | rgb = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] 244 | 245 | self.actors = ActorPool(self.model, self.loss_function, self.memory, self.config.num_actors, 4) 246 | self.actors.start() 247 | batch, metadatas, predictions = self.actors.make_batch(num_images, None, None, None, threshold=threshold) 248 | self.actors.stop() 249 | 250 | images_and_depths, targets = batch[:2] 251 | images = [image[0] if global_config.depth else image for image in images_and_depths] 252 | depths = [image[1] if global_config.depth else None for image in images_and_depths] 253 | targets = [target[0] for target in targets] 254 | gt_masks = [metadata[0][0] for metadata in metadatas] 255 | 256 | i = 0 257 | for image, depth, target, p, gtm in zip(images, depths, targets, predictions, gt_masks): 258 | objectness, masses, pred_mask, action = p['predictions'][0], p['predictions'][1], p['masks'], p['action'] 259 | if depth is not None: 260 | depth = ((1 - depth / actor_config.handDistance).clip(min=0, max=1) * 255).astype(np.uint8) 261 | depth = fromarray(depth) 262 | image = fromarray(image) 263 | raw_image = deepcopy(image) 264 | image = image.convert('LA').convert('RGB') 265 | im_heatmap = deepcopy(image) 266 | im_massmap = deepcopy(image) 267 | im_poking_mask = deepcopy(image) 268 | im_gt_mask = deepcopy(image) 269 | draw = Draw(image) 270 | draw_heatmap = Draw(im_heatmap) 271 | draw_massmap = Draw(im_massmap) 272 | draw_poking_mask = Draw(im_poking_mask) 273 | 274 | poking_masks = [t for t in target if t.sum() > 0] 275 | 276 | masses = masses.argmax(dim=0).detach().cpu().numpy() 
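# At this point masses holds one class id per grid cell (0=light, 1=medium, 2=heavy, matching the rgb
# palette above), and objectness is squashed to per-cell probabilities for the interaction-score heat map.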
277 | objectness = objectness[0].sigmoid() 278 | 279 | for mask, color in zip(gtm, self.config.colors): 280 | mask = fromarray(mask.astype(np.uint8) * 170) 281 | im_gt_mask.paste(color, (0, 0), mask) 282 | 283 | for x in range(global_config.grid_size): 284 | for y in range(global_config.grid_size): 285 | corners = self.corners(x, y) 286 | 287 | draw_heatmap.rectangle(corners[0], outline=(0, int(255 * objectness[x, y]), 0)) 288 | draw_massmap.rectangle(corners[0], outline=rgb[masses[x, y]]) 289 | 290 | for color, poking_mask in zip(self.config.colors, poking_masks): 291 | if poking_mask[x, y]: 292 | draw_poking_mask.rectangle(corners[0], outline=color) 293 | 294 | for color, m in zip(self.config.colors, pred_mask): 295 | if m[x, y]: 296 | draw.rectangle(corners[0], outline=color) 297 | 298 | for ac in action: 299 | point, force = ac['point'], ac['force'] 300 | corners = self.corners(*point) 301 | draw_poking_mask.rectangle(corners[1], outline=(0, 255, 0), fill=(0, 255, 0)) 302 | draw.rectangle(corners[1], outline=rgb[force], fill=rgb[force]) 303 | 304 | i += 1 305 | raw_image.save(path + '%d_f.png' % i) 306 | im_gt_mask.save(path + '%d_a.png' % i) 307 | image.save(path + '%d_b.png' % i) 308 | im_heatmap.save(path + '%d_c.png' % i) 309 | im_massmap.save(path + '%d_d.png' % i) 310 | im_poking_mask.save(path + '%d_e.png' % i) 311 | if depth is not None: 312 | depth.save(path + '%d_g.png' % i) 313 | del draw 314 | del draw_heatmap 315 | del draw_massmap 316 | del draw_poking_mask 317 | 318 | @staticmethod 319 | def corners(x, y): 320 | x, y = y, x 321 | return [x * global_config.stride, y * global_config.stride, 322 | (x + 1) * global_config.stride, (y + 1) * global_config.stride], \ 323 | [x * global_config.stride - global_config.stride // 3, 324 | y * global_config.stride - global_config.stride // 3, 325 | (x + 1) * global_config.stride + global_config.stride // 3, 326 | (y + 1) * global_config.stride + global_config.stride // 3] 327 | 328 | def get_confusion_matrix(self, predictions, dataset): 329 | cm = np.zeros((3, 3)) 330 | preds, matched_gts = [], [] 331 | gtms = [md[0][1] for md in dataset[2]] 332 | for pred_masks, pred_masses, gt_masks, gt_masses in zip(predictions[0], predictions[2], dataset[1], gtms): 333 | if len(pred_masks) == 0 or len(gt_masks) == 0: 334 | continue 335 | gt_masks = np.stack(gt_masks) 336 | assert gt_masks.shape[0] == len(gt_masses), 'different num gt_masses vs gt_masks' 337 | assert len(pred_masks) == len(pred_masses), 'different num pred_masses vs pred_masks' 338 | for mask, mass in zip(pred_masks, pred_masses): 339 | mask = self.upsample(mask)[None, ...] 
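# Each predicted mask is matched to the ground-truth instance with the highest IoU; the prediction only
# enters the mass confusion matrix when that best IoU is at least 0.5.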
340 | intersections = (mask * gt_masks).sum(axis=(1, 2)) 341 | unions = (mask | gt_masks).sum(axis=(1, 2)).clip(min=1) 342 | ious = intersections.astype(np.float32) / unions 343 | best_match = ious.argmax() 344 | if ious[best_match] >= 0.5: 345 | cm[mass, gt_masses[best_match]] += 1 346 | preds.append(mass) 347 | matched_gts.append(gt_masses[best_match]) 348 | return cm, np.corrcoef(np.array([preds, matched_gts])) 349 | -------------------------------------------------------------------------------- /source/pipeline/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from time import time 4 | import os 5 | import logging 6 | 7 | from config import TrainerConfig, TestingConfig, global_config 8 | from models.model import Model 9 | from replay_memory.replay_memory import Memory 10 | from pipeline.actor import ActorPool 11 | from losses.losses import LossFunction 12 | from pipeline.tester import Tester 13 | 14 | 15 | class Trainer: 16 | """ 17 | The main training pipeline. 18 | """ 19 | def __init__(self, model: Model, memory: Memory, loss_function: LossFunction, 20 | trainer_config: TrainerConfig): 21 | if os.path.exists(trainer_config.log_path): 22 | os.remove(trainer_config.log_path) 23 | logging.basicConfig(filename=trainer_config.log_path, level=logging.DEBUG, 24 | format='%(asctime)s %(message)s', datefmt='%H:%M:%S') 25 | self.model_device = torch.cuda.current_device() if global_config.distributed else global_config.model_gpu 26 | 27 | self.config = trainer_config 28 | 29 | self.model = model 30 | 31 | self.memory = memory 32 | 33 | self.loss_function = loss_function 34 | 35 | self.optimizer = self.init_optimizer() 36 | 37 | self.actors = ActorPool(self.model, self.loss_function, self.memory, 38 | trainer_config.num_actors, trainer_config.ground_truth) 39 | 40 | self.model.eval() 41 | 42 | def train(self): 43 | 44 | if self.config.eval_during_train: 45 | whichval = global_config.val_scenes 46 | global_config.val_scenes = 1 47 | tester_val = Tester(self.model, self.memory, self.loss_function, TestingConfig()) 48 | tester_val_path = './tmpvaldata%d' % np.random.randint(0, 10 ** 10, 1) 49 | tester_val.make_and_save_dataset(500, tester_val_path) 50 | global_config.val_scenes = 0 51 | tester_train = Tester(self.model, self.memory, self.loss_function, TestingConfig()) 52 | tester_train_path = './tmptraindata%d' % np.random.randint(0, 10 ** 10, 1) 53 | tester_train.make_and_save_dataset(500, tester_train_path) 54 | global_config.val_scenes = whichval 55 | else: 56 | tester_val, tester_val_path = None, None 57 | tester_train, tester_train_path = None, None 58 | 59 | self.actors.start() 60 | evals = [] 61 | running_val_stats = [0] * (self.loss_function.loss_summary_length + 1) 62 | 63 | t_train = time() 64 | 65 | if self.config.prefill_memory > 0: 66 | print('prefilling memory') 67 | num_pokes = self.config.poking_schedule(1, self.config.episodes) 68 | self.add_batch(self.config.prefill_memory, num_pokes, 1, self.config.episodes) 69 | 70 | self.memory.initialize_loader(self.config.batch_size) 71 | 72 | print('starting training') 73 | 74 | for episode in range(1, self.config.episodes + 1): 75 | if episode == self.config.unfreeze: 76 | self.model.toggle_detection_net(True) 77 | if episode % 50 == 1 and self.config.eval_during_train: 78 | tester_val.evaluate_coco_metrics(tester_val_path, -2, 'bbox', iou=.5) 79 | tester_train.evaluate_coco_metrics(tester_train_path, -2, 'bbox', iou=.5) 80 | 81 | if 
self.config.save_frequency > 0 and episode % self.config.save_frequency == 0: 82 | torch.save(self.model.state_dict(), 83 | self.config.checkpoint_path + '%d.pth' % episode) 84 | 85 | current_pokes = self.config.poking_schedule(episode, self.config.episodes) 86 | current_num_updates = self.config.update_schedule(episode, self.config.episodes) 87 | 88 | self.optimizer.param_groups[0]['lr'] = self.config.lr_schedule(episode, self.config.episodes) 89 | 90 | new_val_stats = self.add_batch(self.config.new_datapoints_per_episode, current_pokes, 91 | episode, self.config.episodes) 92 | 93 | running_val_stats = [0.7 * old + 0.3 * new for old, new in zip(running_val_stats, new_val_stats)] 94 | stats = running_val_stats[:] 95 | 96 | stats += self.instance_segmentation_update(current_num_updates) 97 | 98 | print(episode, current_pokes, ', %.4f' * len(stats) % tuple(stats)) 99 | logging.info('%d' % episode + ', %d' % current_pokes + ', %.3f' * len(stats) % tuple(stats)) 100 | evals.append(stats) 101 | 102 | self.actors.stop() 103 | print(time() - t_train) 104 | 105 | if self.config.eval_during_train: 106 | os.remove(tester_val_path + '.json') 107 | os.remove(tester_val_path + '.pickle') 108 | os.remove(tester_train_path + '.json') 109 | os.remove(tester_train_path + '.pickle') 110 | 111 | return evals 112 | 113 | def instance_segmentation_update(self, num_updates): 114 | stats = [] 115 | batch_iterator = self.memory.iterator() 116 | for _ in range(num_updates): 117 | batch = next(batch_iterator) 118 | 119 | images = batch['images'].cuda(self.model_device) 120 | targets = tuple(t.cuda(self.model_device) for t in batch['targets']) 121 | 122 | superpixels = batch['superpixels'] if global_config.superpixels else None 123 | 124 | indices = batch['indices'] if self.memory.prioritized_replay else None 125 | 126 | weights = batch['weights'].cuda(self.model_device) 127 | 128 | self.model.train() 129 | model_preds = self.model(images, targets) 130 | 131 | losses = self.loss_function(model_preds, targets, weights, superpixels) 132 | 133 | if self.loss_function.prioritized_replay: 134 | losses, priorities = losses 135 | else: 136 | priorities = None 137 | 138 | if priorities is not None and indices is not None: 139 | self.memory.update_priorities(indices, priorities) 140 | 141 | stats.append(np.array([loss.item() for loss in losses])) 142 | loss = sum(losses) 143 | 144 | self.optimizer.zero_grad() 145 | loss.backward() 146 | self.optimizer.step() 147 | self.model.eval() 148 | 149 | stats = list(np.stack(stats).mean(axis=0)) 150 | return stats 151 | 152 | def add_batch(self, batch_size, num_pokes, episode, episodes): 153 | batch, stats = self.actors.make_batch(batch_size, num_pokes, episode, episodes) 154 | self.memory.add_batch(batch) 155 | return stats 156 | 157 | def init_optimizer(self): 158 | return torch.optim.Adam(self.model.parameters(), lr=5e-4, weight_decay=self.config.weight_decay) 159 | -------------------------------------------------------------------------------- /source/replay_memory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/learning_from_interaction/a266bc16d682832aa854348fa557a30d86b84674/source/replay_memory/__init__.py -------------------------------------------------------------------------------- /source/replay_memory/replay_memory.py: -------------------------------------------------------------------------------- 1 | class Memory: 2 | """ 3 | The Trainer class requires a Memory, which supplies the 
trainer with batches of images and targets 4 | to train the model on. The memory is filled during training from interactions with the iTHOR environment. 5 | """ 6 | 7 | def __init__(self): 8 | self.prioritized_replay = False 9 | 10 | def add_batch(self, batch: tuple): 11 | """ 12 | :param batch: a tuple of iterables, such as (batched images, batched targets) 13 | :return: Nothing 14 | 15 | Possibly preprocesses the batch, and adds it to the memory. If the memory is full, it overwrites the oldest 16 | entries in the memory. 17 | """ 18 | raise NotImplementedError 19 | 20 | def initialize_loader(self, batch_size: int): 21 | """ 22 | :param batch_size: the size of the mini-batches produced by the iterator 23 | :return: Nothing 24 | 25 | The memory loader is initialized with the correct batch size. The loader supplies Trainer classes with batches 26 | to train the model on. 27 | """ 28 | raise NotImplementedError 29 | 30 | def iterator(self): 31 | """ 32 | :return: The iterator from which batches are generated, as in: 33 | for images, targets in iterator: train_model(images, targets) 34 | The batch should be returned as a dictionary 35 | """ 36 | raise NotImplementedError 37 | 38 | def load_memory(self, path): 39 | """ 40 | :param path: path to pickle file 41 | :return: Nothing 42 | 43 | Load warm start memory from pickle file 44 | """ 45 | raise NotImplementedError 46 | 47 | def save_memory(self, path): 48 | """ 49 | :param path: path to pickle file 50 | :return: Nothing 51 | 52 | Save current memory to pickle file 53 | """ 54 | raise NotImplementedError 55 | 56 | def base_image_transform(self, image): 57 | """ 58 | :param image: The image as it is returned from the Actor class (possibly image / depth pair) 59 | :return: The image as it would be output by the memory's iterator, minus any data augmentation. 60 | 61 | """ 62 | raise NotImplementedError 63 | 64 | def update_priorities(self, batch_indices: list, priorities: list): 65 | """ 66 | :param batch_indices: The indices of the images in the mini-batch on which the loss was just evaluated 67 | :param priorities: The losses corresponding to the individual images of the minibatch 68 | :return: Nothing. Update the priorities of the images, in case prioritized replay is used. 69 | """ 70 | raise NotImplementedError 71 | -------------------------------------------------------------------------------- /source/replay_memory/replay_pil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision import transforms 4 | from torchvision.transforms import functional as tf 5 | import pickle 6 | 7 | from config import MemoryConfigPIL, global_config 8 | from replay_memory.replay_memory import Memory 9 | 10 | 11 | class ReplayPILDataset(Memory): 12 | """ 13 | A DIY memory class where images are saved in PIL format, allowing for flipping and color jittering 14 | as data augmentation during training. Actual memory use is not very efficient/fast.
15 | """ 16 | 17 | def __init__(self, memory_config: MemoryConfigPIL): 18 | super(ReplayPILDataset, self).__init__() 19 | self.prioritized_replay = memory_config.prioritized_replay 20 | self.config = memory_config 21 | self.images = [None] * memory_config.capacity 22 | self.depths = [None] * memory_config.capacity if global_config.depth else None 23 | self.targets = [None] * memory_config.capacity 24 | self.superpixels = [None] * memory_config.capacity if global_config.superpixels else None 25 | self.priorities = np.zeros(memory_config.capacity) if self.prioritized_replay else None 26 | self.last = 0 27 | self.max = 0 28 | self.to_pil = transforms.ToPILImage() 29 | self.to_tensor = transforms.ToTensor() 30 | self.image_transform = transforms.Compose([ 31 | transforms.RandomApply([transforms.ColorJitter(*(memory_config.jitter,) * 4)], 32 | p=memory_config.jitter_prob), self.to_tensor]) 33 | self.depth_image_transform = transforms.ToTensor() 34 | self.batch_size = None 35 | 36 | def add_batch(self, batch): 37 | if global_config.superpixels: 38 | images, targets, superpixels = batch 39 | else: 40 | images, targets = batch 41 | superpixels = [None] * len(images) 42 | 43 | for image, target, superpixel in zip(images, targets, superpixels): 44 | if self.last == self.config.capacity: 45 | self.last = 0 46 | 47 | if global_config.depth: 48 | image, depth = image 49 | self.images[self.last] = self.to_pil(image) 50 | self.depths[self.last] = self.to_pil(depth) 51 | else: 52 | self.images[self.last] = self.to_pil(image) 53 | 54 | self.targets[self.last] = target 55 | 56 | if global_config.superpixels: 57 | self.superpixels[self.last] = superpixel 58 | 59 | if self.prioritized_replay: 60 | self.priorities[self.last] = self.config.initial_priority 61 | 62 | self.last += 1 63 | self.max = max(self.max, self.last) 64 | 65 | def initialize_loader(self, batch_size): 66 | self.batch_size = batch_size 67 | 68 | def iterator(self): 69 | probs = self.priorities[:self.max] / self.priorities[:self.max].sum() if self.prioritized_replay else None 70 | index_batches = np.random.choice(self.max, self.max // self.batch_size * self.batch_size, 71 | replace=False, p=probs).reshape((-1, self.batch_size)) 72 | return (self.make_batch(indices) for indices in index_batches) 73 | 74 | def make_batch(self, indices): 75 | images, targets = [], [] 76 | superpixels = [] if global_config.superpixels else None 77 | for i in indices: 78 | image = (self.images[i], self.depths[i]) if global_config.depth else self.images[i] 79 | 80 | image, target, superpixel = self.flip(image, self.targets[i], 81 | self.superpixels[i] if global_config.superpixels else None) 82 | 83 | if global_config.depth: 84 | image = torch.cat([self.image_transform(image[0]).float(), 85 | self.depth_image_transform(image[1])], dim=0) 86 | else: 87 | image = self.image_transform(image).float() 88 | 89 | images.append(image) 90 | 91 | target = [torch.from_numpy(t) for t in target] 92 | targets.append(target) 93 | 94 | if global_config.superpixels: 95 | superpixels.append(torch.from_numpy(superpixel)) 96 | 97 | images = torch.stack(images) 98 | targets = (list(x) for x in zip(*targets)) 99 | targets = tuple(torch.stack(x) for x in targets) 100 | 101 | batch = dict(images=images, targets=targets) 102 | 103 | if global_config.superpixels: 104 | batch['superpixels'] = torch.stack(superpixels) 105 | 106 | if self.prioritized_replay: 107 | batch['indices'] = indices 108 | if self.prioritized_replay and self.config.bias_correct: 109 | batch['weights'] = 
torch.tensor([1/self.priorities[i] for i in indices]) 110 | else: 111 | batch['weights'] = torch.ones(len(indices)) 112 | 113 | return batch 114 | 115 | def flip(self, image, target, superpixel): 116 | if np.random.random(1) < self.config.flip_prob: 117 | if global_config.depth: 118 | image = (tf.hflip(image[0]), tf.hflip(image[1])) 119 | else: 120 | image = tf.hflip(image) 121 | 122 | target = tuple(np.flip(t, axis=-1).copy() for t in target) 123 | 124 | if global_config.superpixels: 125 | superpixel = np.flip(superpixel, axis=-1).copy() 126 | 127 | return image, target, superpixel 128 | 129 | def load_memory(self, path): 130 | arrays = pickle.load(open(path, 'rb')) 131 | 132 | images, targets = arrays[0][-self.config.capacity:], arrays[1][-self.config.capacity:] 133 | size = len(images) 134 | self.images[:size] = images 135 | self.targets[:size] = targets 136 | 137 | if global_config.depth: 138 | depths = arrays[2][-self.config.capacity:] 139 | self.depths[:size] = depths 140 | if global_config.superpixels: 141 | superpixels = arrays[3][-self.config.capacity:] 142 | self.superpixels[:size] = superpixels 143 | 144 | self.max = size 145 | self.last = size 146 | 147 | def save_memory(self, path): 148 | arrays = (self.images[:self.max], self.targets[:self.max]) 149 | if global_config.depth: 150 | arrays = arrays + (self.depths[:self.max],) 151 | if global_config.superpixels: 152 | arrays = arrays + (self.superpixels[:self.max],) 153 | file = open(path, 'wb') 154 | pickle.dump(arrays, file) 155 | file.close() 156 | 157 | def base_image_transform(self, image): 158 | if global_config.depth: 159 | image = tuple(self.to_tensor(self.to_pil(im)) for im in image) 160 | else: 161 | image = self.to_tensor(self.to_pil(image)) 162 | 163 | if global_config.depth: 164 | image = torch.cat(list(image), dim=0) 165 | 166 | return image 167 | 168 | def update_priorities(self, batch_indices: list, priorities: list): 169 | for index, priority in zip(batch_indices, priorities): 170 | self.priorities[index] = priority 171 | -------------------------------------------------------------------------------- /source/replay_memory/replay_tensor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision import transforms 4 | import pickle 5 | 6 | from config import MemoryConfigTensor, global_config 7 | from replay_memory.replay_memory import Memory 8 | 9 | 10 | class ReplayTensorDataset(Memory): 11 | """ 12 | In this class, all output of the Actor (images and targets) are saved as tensors. No data augmentation happens 13 | in the memory's iterator. This should be the fastest (most lightweight) way to implement the Memory class. 
14 | """ 15 | 16 | def __init__(self, memory_config: MemoryConfigTensor): 17 | super(ReplayTensorDataset, self).__init__() 18 | self.config = memory_config 19 | self.to_tensor = transforms.ToTensor() 20 | memory = tuple(torch.zeros((memory_config.capacity,) + shape, dtype=dtype) for shape, dtype in 21 | zip(memory_config.sizes, memory_config.dtypes)) 22 | self.memory = GrowingTensorDataset(memory, self) 23 | self.loader = None 24 | self.last = 0 25 | self.max = 0 26 | if memory_config.warm_start_memory is not None: 27 | self.load_memory(memory_config.warm_start_memory) 28 | 29 | def add_batch(self, batch): 30 | for datapoint in zip(*batch): 31 | if self.last == self.config.capacity: 32 | self.last = 0 33 | datapoint = self.base_transform(datapoint) 34 | for tensor, new_entry in zip(self.memory.tensors, datapoint): 35 | tensor[self.last] = new_entry 36 | self.last += 1 37 | self.max = max(self.last, self.max) 38 | 39 | def initialize_loader(self, batch_size): 40 | self.loader = torch.utils.data.DataLoader(self.memory, shuffle=True, batch_size=batch_size, pin_memory=True, 41 | num_workers=self.config.num_workers) 42 | 43 | def iterator(self): 44 | if global_config.superpixels: 45 | return iter(dict(images=batch[0], targets=batch[1:-1], superpixels=batch[-1]) for batch in self.loader) 46 | return iter(dict(images=batch[0], targets=batch[1:]) for batch in self.loader) 47 | 48 | def load_memory(self, path): 49 | tensors = pickle.load(open(path, 'rb')) 50 | tensors = [tensor[-self.config.capacity:] for tensor in tensors] 51 | size = tensors[0].shape[0] 52 | for t1, t2 in zip(self.memory.tensors, tensors): 53 | t1[:size] = t2 54 | self.max = size 55 | self.last = size 56 | 57 | def save_memory(self, path): 58 | tensors = [tensor[:self.max] for tensor in self.memory.tensors] 59 | file = open(path, 'wb') 60 | pickle.dump(tensors, file) 61 | file.close() 62 | 63 | def base_image_transform(self, image): 64 | if global_config.depth: 65 | image = np.concatenate([image[0], image[1][..., None]], axis=2) 66 | return self.to_tensor(image) 67 | 68 | def base_transform(self, datapoint): 69 | image = self.base_image_transform(datapoint[0]) 70 | targets_and_superpixels = [x for x in datapoint[1]] + ([datapoint[2]] if global_config.superpixels else []) 71 | return [image] + [torch.from_numpy(x) for x in targets_and_superpixels] 72 | 73 | def update_priorities(self, batch_indices: list, priorities: list): 74 | """ 75 | :param batch_indices: The indices of the images in the mini-batch on which the loss was just evaluated 76 | :param priorities: The losses corresponding to the individual images of the minibatch 77 | :return: Nothing. Update the priorities of the images, in case prioritized replay is used. 
78 | """ 79 | raise NotImplementedError 80 | 81 | 82 | class GrowingTensorDataset(torch.utils.data.TensorDataset): 83 | """ 84 | Trivial modification of pytorch's inbuilt TensorDataset to allow for dataset size that grows during training 85 | """ 86 | 87 | def __init__(self, memory, owner): 88 | super(GrowingTensorDataset, self).__init__(*memory) 89 | self.owner = owner 90 | 91 | def __len__(self): 92 | return self.owner.max 93 | -------------------------------------------------------------------------------- /source/requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | numpy==1.18.1 3 | git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI 4 | Pillow==6.2.2 5 | torch==1.3.1 6 | torchvision==0.4.2 7 | scikit-image==0.16.2 8 | scikit-learn==0.22.2.post1 9 | -e git+git://github.com/allenai/ai2thor.git@405d8ec4846208b8e6bff2b379df60c4cb80b55b#egg=ai2thor 10 | -------------------------------------------------------------------------------- /source/startx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import subprocess 4 | import shlex 5 | import re 6 | import platform 7 | import tempfile 8 | import os 9 | import sys 10 | 11 | 12 | def pci_records(): 13 | records = [] 14 | command = shlex.split('lspci -vmm') 15 | output = subprocess.check_output(command).decode() 16 | 17 | for devices in output.strip().split("\n\n"): 18 | record = {} 19 | records.append(record) 20 | for row in devices.split("\n"): 21 | key, value = row.split("\t") 22 | record[key.split(':')[0]] = value 23 | 24 | return records 25 | 26 | 27 | def generate_xorg_conf(devices): 28 | xorg_conf = [] 29 | 30 | device_section = """ 31 | Section "Device" 32 | Identifier "Device{device_id}" 33 | Driver "nvidia" 34 | VendorName "NVIDIA Corporation" 35 | BusID "{bus_id}" 36 | EndSection 37 | """ 38 | server_layout_section = """ 39 | Section "ServerLayout" 40 | Identifier "Layout0" 41 | {screen_records} 42 | EndSection 43 | """ 44 | screen_section = """ 45 | Section "Screen" 46 | Identifier "Screen{screen_id}" 47 | Device "Device{device_id}" 48 | DefaultDepth 24 49 | Option "AllowEmptyInitialConfiguration" "True" 50 | SubSection "Display" 51 | Depth 24 52 | Virtual 1024 768 53 | EndSubSection 54 | EndSection 55 | """ 56 | screen_records = [] 57 | for i, bus_id in enumerate(devices): 58 | xorg_conf.append(device_section.format(device_id=i, bus_id=bus_id)) 59 | xorg_conf.append(screen_section.format(device_id=i, screen_id=i)) 60 | screen_records.append('Screen {screen_id} "Screen{screen_id}" 0 0'.format(screen_id=i)) 61 | 62 | xorg_conf.append(server_layout_section.format(screen_records="\n ".join(screen_records))) 63 | 64 | output = "\n".join(xorg_conf) 65 | print(output) 66 | return output 67 | 68 | 69 | def startx(display): 70 | if platform.system() != 'Linux': 71 | raise Exception("Can only run startx on linux") 72 | 73 | devices = [] 74 | for r in pci_records(): 75 | if r.get('Vendor', '') == 'NVIDIA Corporation' \ 76 | and r['Class'] in ['VGA compatible controller', '3D controller']: 77 | bus_id = 'PCI:' + ':'.join(map(lambda x: str(int(x, 16)), re.split(r'[:\.]', r['Slot']))) 78 | devices.append(bus_id) 79 | 80 | if not devices: 81 | raise Exception("no nvidia cards found") 82 | 83 | try: 84 | fd, path = tempfile.mkstemp() 85 | with open(path, "w") as f: 86 | f.write(generate_xorg_conf(devices)) 87 | command = shlex.split("Xorg -noreset -logfile xorg.log -logverbose -config %s :%s" % (path, 
display)) 88 | subprocess.call(command) 89 | finally: 90 | os.close(fd) 91 | os.unlink(path) 92 | 93 | 94 | if __name__ == '__main__': 95 | display = 0 96 | if len(sys.argv) > 1: 97 | display = int(sys.argv[1]) 98 | print("Starting X on DISPLAY=:%s" % display) 99 | startx(display) -------------------------------------------------------------------------------- /source/tools/coco_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from datetime import datetime 4 | 5 | from pycocotools.mask import area, toBbox 6 | 7 | from tools.logger import LOGGER 8 | 9 | 10 | def save_coco_dataset(dataset_file, output_folder, classes=("light", "medium", "heavy"), force=False): 11 | def get_dicts(jsonfile): 12 | with open(jsonfile, "r") as f: 13 | res = json.load(f) 14 | return res 15 | 16 | def data_to_coco(data, classes): 17 | res = dict( 18 | info=dict( 19 | date_created=datetime.now().strftime("%Y%m%d%H%M%S"), 20 | description="Automatically generated COCO json file", 21 | ), 22 | categories=[dict(id=it, name=cl) for it, cl in enumerate(classes)], 23 | images=[], 24 | annotations=[], 25 | ) 26 | 27 | for ep in data: 28 | res["images"].append(dict( 29 | id=ep["image_id"], 30 | width=ep["width"], 31 | height=ep["height"], 32 | file_name="" 33 | )) 34 | 35 | for ann in ep["annotations"]: 36 | seg = ann["segmentation"] 37 | res["annotations"].append(dict( 38 | id=len(res["annotations"]) + 1, 39 | image_id=ep["image_id"], 40 | bbox=list(toBbox(seg)), 41 | area=float(area(seg)), 42 | iscrowd=0, 43 | category_id=ann["category_id"], 44 | segmentation=seg, 45 | )) 46 | 47 | return res 48 | 49 | dataset_base = os.path.basename(dataset_file) 50 | json_file_name = os.path.join(output_folder, dataset_base.replace(".json", "__coco_format.json")) 51 | if os.path.exists(json_file_name) and not force: 52 | LOGGER.info("skipping conversion; {} already exists".format(json_file_name)) 53 | return json_file_name 54 | 55 | json_dict = data_to_coco(get_dicts(dataset_file), classes) 56 | with open(json_file_name, "w") as f: 57 | json.dump(json_dict, f) 58 | LOGGER.info("COCO gt annotations saved to {}".format(json_file_name)) 59 | 60 | return json_file_name 61 | -------------------------------------------------------------------------------- /source/tools/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from torch.utils.data import Dataset 5 | 6 | from config import global_config 7 | from pipeline.actor import Actor 8 | 9 | 10 | class EvalDataset(Dataset): 11 | def __init__(self, dataset_file, memory, controller): 12 | with open(dataset_file, "r") as f: 13 | self.dataset = json.load(f) 14 | self.folder = os.path.dirname(dataset_file) 15 | self.memory = memory 16 | self.lut = Actor._make_depth_correction(global_config.resolution, global_config.resolution, 90) 17 | self.controller = controller 18 | 19 | def __len__(self): 20 | return len(self.dataset) 21 | 22 | def load_meta(self, thor_meta): 23 | scene = thor_meta["scene"] 24 | seed = thor_meta["seed"] 25 | position = thor_meta["position"] 26 | rotation = thor_meta["rotation"] 27 | horizon = thor_meta["horizon"] 28 | 29 | self.controller.reset(scene) 30 | self.controller.step(action='InitialRandomSpawn', seed=seed, 31 | forceVisible=True, numPlacementAttempts=5) 32 | self.controller.step(action='MakeAllObjectsMoveable') 33 | event = self.controller.step(action='TeleportFull', x=position['x'], y=position['y'], 34 | 
z=position['z'], rotation=rotation, horizon=horizon) 35 | 36 | return event 37 | 38 | def __getitem__(self, item): 39 | entry = self.dataset[item] 40 | evt = self.load_meta(entry["thor_meta"]) 41 | rgb = evt.frame.copy() 42 | if global_config.depth: 43 | dist = (evt.depth_frame.copy() - .1) * self.lut 44 | rgbd = self.memory.base_image_transform((rgb, dist)) 45 | else: 46 | rgbd = self.memory.base_image_transform(rgb) 47 | 48 | return rgbd, entry["image_id"] 49 | 50 | 51 | class ActiveDataset(EvalDataset): 52 | def __init__(self, dataset_file, memory, controller, conn): 53 | super().__init__(dataset_file, memory, controller) 54 | self.conn = conn 55 | 56 | def process(self): 57 | while True: 58 | item = self.conn.recv() 59 | if item is None: 60 | break 61 | self.conn.send(self.__getitem__(item)) 62 | -------------------------------------------------------------------------------- /source/tools/dist_training_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | from models.model import Model 6 | from pipeline.trainer import Trainer 7 | 8 | 9 | class DistributedWrapper(Model): 10 | def __init__(self, shared_model): 11 | super(DistributedWrapper, self).__init__() 12 | self.shared_model = shared_model 13 | 14 | def forward(self, images: torch.tensor, *targets): 15 | return self.shared_model(images, *targets) 16 | 17 | def compute_actions(self, images: torch.tensor, num_pokes: int, episode: int, episodes: int): 18 | return self.shared_model.module.compute_actions(images, num_pokes, episode, episodes) 19 | 20 | def compute_masks(self, images: torch.tensor, threshold: float): 21 | return self.shared_model.module.compute_masks(images, threshold) 22 | 23 | 24 | class DummySharedWrapper(torch.nn.Module): 25 | def __init__(self, model): 26 | super(DummySharedWrapper, self).__init__() 27 | self.module = model 28 | 29 | def forward(self, x, *y): 30 | return self.module(x, *y) 31 | 32 | 33 | def do_setup_and_start_training(modules, configs, rank, size, device_list, single=False): 34 | if not single: 35 | os.environ['MASTER_ADDR'] = '127.0.0.1' 36 | os.environ['MASTER_PORT'] = '29509' 37 | dist.init_process_group('nccl', rank=rank, world_size=size) 38 | 39 | with torch.cuda.device(device_list[rank]): 40 | print('initializing model on rank %d'%rank) 41 | if single: 42 | shared_model = DummySharedWrapper(modules['model'](configs['model'])).cuda() 43 | else: 44 | shared_model = torch.nn.parallel.DistributedDataParallel(modules['model'](configs['model']).cuda(), 45 | device_ids=[device_list[rank]], 46 | find_unused_parameters=True) 47 | model = DistributedWrapper(shared_model) 48 | memory = modules['memory'](configs['memory']) 49 | loss_function = modules['loss'](configs['loss']) 50 | trainer_config = configs['trainer'] 51 | trainer_config.log_path = trainer_config.log_path + '%d.log'%rank 52 | if rank > 0: 53 | trainer_config.save_frequency = 0 54 | trainer = Trainer(model, memory, loss_function, trainer_config) 55 | print('starting training process on rank %d'%rank) 56 | stats = trainer.train() 57 | 58 | print('done %d' % rank) 59 | return stats 60 | 61 | -------------------------------------------------------------------------------- /source/tools/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | LOGGER = logging.getLogger("console_logger") 5 | 6 | 7 | def excepthook(*args): 8 | 
LOGGER.error("Uncaught exception:", exc_info=args) 9 | 10 | 11 | class StreamToLogger: 12 | def __init__(self): 13 | self.linebuf = '' 14 | 15 | def write(self, buf): 16 | temp_linebuf = self.linebuf + buf 17 | self.linebuf = '' 18 | for line in temp_linebuf.splitlines(True): 19 | if line[-1] == '\n': 20 | LOGGER.info(line.rstrip()) 21 | else: 22 | self.linebuf += line 23 | 24 | def flush(self): 25 | if self.linebuf != '': 26 | LOGGER.info(self.linebuf.rstrip()) 27 | self.linebuf = '' 28 | 29 | 30 | def init_logging(log_format="default", log_level="debug"): 31 | if len(LOGGER.handlers) > 0: 32 | return 33 | 34 | if log_level == "debug": 35 | log_level = logging.DEBUG 36 | elif log_level == "info": 37 | log_level = logging.INFO 38 | elif log_level == "warning": 39 | log_level = logging.WARNING 40 | elif log_level == "error": 41 | log_level = logging.ERROR 42 | assert log_level in [logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR], \ 43 | "unknown log_level {}".format(log_level) 44 | 45 | ch = logging.StreamHandler() 46 | ch.setLevel(log_level) 47 | 48 | if log_format == "default": 49 | formatter = logging.Formatter( 50 | fmt="%(asctime)s: %(levelname)s: %(message)s\t[%(filename)s: %(lineno)d]", 51 | datefmt="%m/%d %H:%M:%S", 52 | ) 53 | elif log_format == "defaultMilliseconds": 54 | formatter = logging.Formatter( 55 | fmt="%(asctime)s: %(levelname)s: %(message)s\t[%(filename)s: %(lineno)d]" 56 | ) 57 | else: 58 | formatter = logging.Formatter(fmt=log_format, datefmt="%m/%d %H:%M:%S") 59 | ch.setFormatter(formatter) 60 | 61 | LOGGER.setLevel(log_level) 62 | LOGGER.addHandler(ch) 63 | 64 | sys.excepthook = excepthook 65 | sys.stdout = StreamToLogger() 66 | -------------------------------------------------------------------------------- /source/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from models.clustering_models import ClusteringModel 5 | from pipeline.trainer import Trainer 6 | from pipeline.tester import Tester 7 | from replay_memory.replay_pil import ReplayPILDataset 8 | from losses.clustering_losses import MaskAndMassLoss 9 | from config import MemoryConfigPIL, MaskAndMassLossConfig, TrainerConfig, TestingConfig, ClusteringModelConfig 10 | from config import global_config, actor_config 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser( 15 | description="self-supervised-objects training", 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 17 | ) 18 | parser.add_argument( 19 | "output_folder", 20 | type=str, 21 | help="required output model folder name", 22 | ) 23 | parser.add_argument( 24 | "dataset_folder", 25 | type=str, 26 | help="required dataset folder name", 27 | ) 28 | parser.add_argument( 29 | "dataset", 30 | type=int, 31 | help="required dataset type: should be 0 (NovelObjects) or 1 (NovelSpaces)", 32 | ) 33 | parser.add_argument( 34 | "-g", 35 | "--model_gpu", 36 | required=False, 37 | default=0, 38 | type=int, 39 | help="gpu id to run model", 40 | ) 41 | parser.add_argument( 42 | "-a", 43 | "--actors_gpu", 44 | required=False, 45 | default=1, 46 | type=int, 47 | help="gpu id to run AI2-THOR actors", 48 | ) 49 | parser.add_argument( 50 | "-p", 51 | "--checkpoint_prefix", 52 | required=False, 53 | default="clustering_model_weights_", 54 | type=str, 55 | help="prefix for checkpoints in output folder", 56 | ) 57 | args = parser.parse_args() 58 | 59 | return args 60 | 61 | 62 | if __name__ == '__main__': 63 | args = 
get_args() 64 | 65 | print("Running train with args {}".format(args)) 66 | 67 | output_folder = os.path.normpath(args.output_folder) 68 | dataset_folder = os.path.normpath(args.dataset_folder) 69 | dataset = args.dataset 70 | 71 | os.makedirs(output_folder, exist_ok=True) 72 | 73 | assert os.path.isdir(output_folder), 'Output folder does not exist' 74 | assert os.path.isdir(dataset_folder), 'Dataset folder does not exist' 75 | assert dataset in [0, 1], 'Dataset argument should be either 0 (NovelObjects) or 1 (NovelSpaces)' 76 | 77 | global_config.model_gpu = args.model_gpu 78 | global_config.actor_gpu = args.actors_gpu 79 | 80 | actor_config.data_files = [['NovelObjects__train.json', 81 | 'NovelObjects__test.json'], 82 | ['NovelSpaces__train.json', 83 | 'NovelSpaces__test.json']][dataset] 84 | actor_config.data_files = [os.path.join(dataset_folder, fn) for fn in actor_config.data_files] 85 | 86 | loss_function = MaskAndMassLoss(MaskAndMassLossConfig()) 87 | model = ClusteringModel(ClusteringModelConfig()).cuda(global_config.model_gpu) 88 | 89 | trainer_config = TrainerConfig() 90 | trainer_config.log_path = os.path.join(output_folder, 'training_log.log') 91 | 92 | trainer = Trainer(model, ReplayPILDataset(MemoryConfigPIL()), loss_function, trainer_config) 93 | 94 | print('Running instance segmentation only pre-training') 95 | 96 | actor_config.instance_only = True 97 | loss_function.config.instance_only = True 98 | trainer_config.checkpoint_path = os.path.join(output_folder, args.checkpoint_prefix + 'inst_only_') 99 | 100 | model.toggle_mass_head(False) 101 | 102 | trainer.train() 103 | 104 | print('Training with force prediction') 105 | 106 | actor_config.instance_only = False 107 | loss_function.config.instance_only = False 108 | trainer_config.checkpoint_path = os.path.join(output_folder, args.checkpoint_prefix) 109 | trainer_config.update_schedule = lambda episode, episodes: int(15 + 20 * episode / episodes) 110 | trainer_config.poking_schedule = lambda episode, episodes: 10 111 | trainer.memory = ReplayPILDataset(MemoryConfigPIL()) 112 | 113 | model.toggle_mass_head(True) 114 | model.toggle_detection_net(False) 115 | trainer_config.unfreeze = 100 116 | 117 | trainer.train() 118 | 119 | print('Testing on small subset of Val and preparing some illustrations') 120 | 121 | tester = Tester(model, ReplayPILDataset(MemoryConfigPIL()), loss_function, TestingConfig()) 122 | 123 | tester.illustrate_poking_and_predictions(os.path.join(output_folder, 'illustrations/'), 50, -.2) 124 | tester.evaluate_coco_metrics(300, threshold=-100, annotation_type='bbox', iou=.5) 125 | -------------------------------------------------------------------------------- /trained_model_novel_objects/clustering_model_weights_900.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/learning_from_interaction/a266bc16d682832aa854348fa557a30d86b84674/trained_model_novel_objects/clustering_model_weights_900.pth -------------------------------------------------------------------------------- /trained_model_novel_spaces/clustering_model_weights_900.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/learning_from_interaction/a266bc16d682832aa854348fa557a30d86b84674/trained_model_novel_spaces/clustering_model_weights_900.pth -------------------------------------------------------------------------------- /training_video.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/learning_from_interaction/a266bc16d682832aa854348fa557a30d86b84674/training_video.gif --------------------------------------------------------------------------------
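
A minimal usage sketch of the Memory/Trainer batch contract defined above (a sketch only, assuming default configs; the batch size is a placeholder, and the keys read below are those produced by ReplayPILDataset.make_batch()):

from config import MemoryConfigPIL
from replay_memory.replay_pil import ReplayPILDataset

# Build the PIL-based replay memory and prepare its loader (batch size is a placeholder).
memory = ReplayPILDataset(MemoryConfigPIL())
memory.initialize_loader(batch_size=16)

# In training, the memory is first filled via add_batch() from actor interactions.
# Each yielded batch is a dictionary, mirroring the accesses in Trainer.instance_segmentation_update().
for batch in memory.iterator():
    images = batch['images']                # stacked image tensors (plus a depth channel if global_config.depth)
    targets = batch['targets']              # tuple of stacked per-image target tensors
    weights = batch['weights']              # per-image loss weights; all ones unless bias correction is enabled
    superpixels = batch.get('superpixels')  # present only when global_config.superpixels is True
    indices = batch.get('indices')          # present only when prioritized replay is enabled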