├── LICENSE
├── README.md
├── common
│   ├── evaluation.py
│   ├── logger.py
│   ├── supervision.py
│   └── utils.py
├── data
│   ├── caltech.py
│   ├── dataset.py
│   ├── download.py
│   ├── pfpascal.py
│   ├── pfwillow.py
│   └── spair.py
├── model
│   ├── base
│   │   ├── correlation.py
│   │   ├── geometry.py
│   │   ├── norm.py
│   │   └── resnet.py
│   ├── dhpf.py
│   ├── gating.py
│   ├── objective.py
│   └── rhm.py
├── test.py
└── train.py

/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 - Juhong Min 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-to-compose-hypercolumns-for-visual/semantic-correspondence-on-spair-71k)](https://paperswithcode.com/sota/semantic-correspondence-on-spair-71k?p=learning-to-compose-hypercolumns-for-visual) 2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-to-compose-hypercolumns-for-visual/semantic-correspondence-on-pf-pascal)](https://paperswithcode.com/sota/semantic-correspondence-on-pf-pascal?p=learning-to-compose-hypercolumns-for-visual) 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-to-compose-hypercolumns-for-visual/semantic-correspondence-on-pf-willow)](https://paperswithcode.com/sota/semantic-correspondence-on-pf-willow?p=learning-to-compose-hypercolumns-for-visual) 4 | 5 | # Learning to Compose Hypercolumns for Visual Correspondence 6 | This is the implementation of the paper "Learning to Compose Hypercolumns for Visual Correspondence" by J. Min, J. Lee, J. Ponce and M. Cho. Implemented in Python 3.7 with PyTorch 1.0.1. 7 | 8 | ![](https://juhongm999.github.io/pic/dhpf.png) 9 | 10 | For more information, check out the project [[website](http://cvlab.postech.ac.kr/research/DHPF/)] and the paper on [[arXiv](https://arxiv.org/abs/2007.10587)]. 11 | 12 | 13 | ## Requirements 14 | 15 | - Python 3.7 16 | - PyTorch 1.0.1 17 | - tensorboardX 18 | - scipy 19 | - pandas 20 | - requests 21 | - scikit-image 22 | 23 | Conda environment settings: 24 | ```bash 25 | conda create -n dhpf python=3.7 26 | conda activate dhpf 27 | 28 | conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 -c pytorch 29 | pip install tensorboardX 30 | conda install -c anaconda scipy 31 | conda install -c anaconda pandas 32 | conda install -c anaconda requests 33 | conda install -c anaconda scikit-image 34 | conda install -c anaconda "pillow<7" 35 | ``` 36 | 37 | ## Training 38 | 39 | Training DHPF with strong supervision (keypoint annotations) on PF-PASCAL and SPair-71k
40 | (reproducing strongly-supervised results in Tab. 1 and 2): 41 | ```bash 42 | python train.py --supervision strong \ 43 | --lr 0.03 \ 44 | --bsz 8 \ 45 | --niter 100 \ 46 | --selection 0.5 \ 47 | --benchmark pfpascal \ 48 | --backbone {resnet50, resnet101} 49 | 50 | python train.py --supervision strong \ 51 | --lr 0.03 \ 52 | --bsz 8 \ 53 | --niter 5 \ 54 | --selection 0.5 \ 55 | --benchmark spair \ 56 | --backbone {resnet50, resnet101} 57 | ``` 58 | Training DHPF with weak supervision (image-level labels) on PF-PASCAL
59 | (reproducing weakly-supervised results in Tab. 1): 60 | ```bash 61 | python train.py --supervision weak \ 62 | --lr 0.1 \ 63 | --bsz 4 \ 64 | --niter 30 \ 65 | --selection 0.5 \ 66 | --benchmark pfpascal \ 67 | --backbone {resnet50, resnet101} 68 | ``` 69 | 70 | ## Testing 71 | 72 | Trained models are available on [[Google drive](https://drive.google.com/drive/folders/1aoKQlvHOb7vZIFK8pDJsQnC7SOyEjXVF?usp=sharing)]. 73 | 74 | PCK @ αimg=0.1 on PF-PASCAL at different μ: 75 | 76 | | Trained models
at different μ | 0.3 | 0.4 | 0.5 | 0.6 | 0.7 | 0.8 | 0.9 | 1 | 77 | |:--------------------------------------------------:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:| 78 | | [weak (res50)](https://drive.google.com/drive/folders/1WykysKyy9PAsX-DpC5UuZILokCMToJWH?usp=sharing) | 77.3 | 79 | 79 | 79.3 | 79.6 | 80.7 | 81.1 | 80.7 | 79 | | [weak (res101)](https://drive.google.com/drive/folders/1IjjoFgrIZzys2YDEGhLQrOg0bTG29-Pl?usp=sharing) | 80.3 | 81.2 | 82.1 | 80.1 | 81.7 | 80.9 | 81.3 | 81.3 | 80 | | [strong (res50)](https://drive.google.com/drive/folders/1RC9EbVhk8QOjpF3NIO-tidIsKcY399S8?usp=sharing) | 87.7 | 89.1 | 88.9 | 88.5 | 89.4 | 89.1 | 89 | 89.5 | 81 | | [strong (res101)](https://drive.google.com/drive/folders/1QDYOxqF-BsWKjKbwLKfbcfxaS5OHlbVT?usp=sharing) | 88.7 | 90 | 90.7 | 90.2 | 90.1 | 90.6 | 90.6 | 90.4 | 82 | 83 | PCK @ αimg=0.1 on SPair-71k at μ=0.5: 84 | 85 | | Trained models
at μ=0.5 | PCK | 86 | |:---------------------------------------------:|:----:| 87 | | [weak (res101)](https://drive.google.com/file/d/1uDfONwSiAzDsxW9wbhdlYKf8auqAVXoM/view?usp=sharing) | 27.7 | 88 | | [strong (res101)](https://drive.google.com/file/d/1DnsDhttMIImAcupdjuANowlgZqVSx_5E/view?usp=sharing) | 37.3 | 89 | 90 | Reproducing results in Tab. 1, 2 and 3: 91 | ```bash 92 | python test.py --backbone {resnet50, resnet101} \ 93 | --benchmark {pfpascal, pfwillow, caltech, spair} \ 94 | --load "path_to_trained_model" 95 | ``` 96 | 97 | 98 | ## BibTeX 99 | If you use this code for your research, please consider citing: 100 | ````BibTeX 101 | @InProceedings{min2020dhpf, 102 | title={Learning to Compose Hypercolumns for Visual Correspondence}, 103 | author={Juhong Min and Jongmin Lee and Jean Ponce and Minsu Cho}, 104 | booktitle={ECCV}, 105 | year={2020} 106 | } 107 | ```` 108 | -------------------------------------------------------------------------------- /common/evaluation.py: -------------------------------------------------------------------------------- 1 | """For quantitative evaluation of DHPF""" 2 | from skimage import draw 3 | import numpy as np 4 | import torch 5 | 6 | from . import utils 7 | 8 | 9 | class Evaluator: 10 | r"""Computes evaluation metrics of PCK, LT-ACC, IoU""" 11 | @classmethod 12 | def initialize(cls, benchmark, alpha=0.1): 13 | if benchmark == 'caltech': 14 | cls.eval_func = cls.eval_mask_transfer 15 | else: 16 | cls.eval_func = cls.eval_kps_transfer 17 | cls.alpha = alpha 18 | 19 | @classmethod 20 | def evaluate(cls, prd_kps, batch): 21 | r"""Compute evaluation metric""" 22 | return cls.eval_func(prd_kps, batch) 23 | 24 | @classmethod 25 | def eval_kps_transfer(cls, prd_kps, batch): 26 | r"""Compute percentage of correct key-points (PCK) based on prediction""" 27 | 28 | easy_match = {'src': [], 'trg': [], 'dist': []} 29 | hard_match = {'src': [], 'trg': []} 30 | 31 | pck = [] 32 | for idx, (pk, tk) in enumerate(zip(prd_kps, batch['trg_kps'])): 33 | thres = batch['pckthres'][idx] 34 | npt = batch['n_pts'][idx] 35 | correct_dist, correct_ids, incorrect_ids = cls.classify_prd(pk[:, :npt], tk[:, :npt], thres) 36 | 37 | # Collect easy and hard match feature index & store pck to buffer 38 | easy_match['dist'].append(correct_dist) 39 | easy_match['src'].append(batch['src_kpidx'][idx][:npt][correct_ids]) 40 | easy_match['trg'].append(batch['trg_kpidx'][idx][:npt][correct_ids]) 41 | hard_match['src'].append(batch['src_kpidx'][idx][:npt][incorrect_ids]) 42 | hard_match['trg'].append(batch['trg_kpidx'][idx][:npt][incorrect_ids]) 43 | pck.append((len(correct_ids) / npt.item()) * 100) 44 | 45 | eval_result = {'easy_match': easy_match, 46 | 'hard_match': hard_match, 47 | 'pck': pck} 48 | 49 | return eval_result 50 | 51 | @classmethod 52 | def eval_mask_transfer(cls, prd_kps, batch): 53 | r"""Compute LT-ACC and IoU based on transferred points""" 54 | 55 | ltacc = [] 56 | iou = [] 57 | 58 | for idx, prd in enumerate(prd_kps): 59 | trg_n_pts = (batch['trg_kps'][idx] > 0)[0].sum() 60 | prd_kp = prd[:, :batch['n_pts'][idx]] 61 | trg_kp = batch['trg_kps'][idx][:, :trg_n_pts] 62 | 63 | imsize = list(batch['trg_img'].size())[2:] 64 | trg_xstr, trg_ystr = cls.pts2ptstr(trg_kp) 65 | trg_mask = cls.ptstr2mask(trg_xstr, trg_ystr, imsize[0], imsize[1]) 66 | prd_xstr, pred_ystr = cls.pts2ptstr(prd_kp) 67 | prd_mask = cls.ptstr2mask(prd_xstr, pred_ystr, imsize[0], imsize[1]) 68 | 69 | ltacc.append(cls.label_transfer_accuracy(prd_mask, trg_mask)) 70 | 
iou.append(cls.intersection_over_union(prd_mask, trg_mask)) 71 | 72 | eval_result = {'ltacc': ltacc, 73 | 'iou': iou} 74 | 75 | return eval_result 76 | 77 | @classmethod 78 | def classify_prd(cls, prd_kps, trg_kps, pckthres): 79 | r"""Compute the number of correctly transferred key-points""" 80 | l2dist = (prd_kps - trg_kps).pow(2).sum(dim=0).pow(0.5) 81 | thres = pckthres.expand_as(l2dist).float() * cls.alpha 82 | correct_pts = torch.le(l2dist, thres) 83 | 84 | correct_ids = utils.where(correct_pts == 1) 85 | incorrect_ids = utils.where(correct_pts == 0) 86 | correct_dist = l2dist[correct_pts] 87 | 88 | return correct_dist, correct_ids, incorrect_ids 89 | 90 | @classmethod 91 | def intersection_over_union(cls, mask1, mask2): 92 | r"""Computes IoU between two masks""" 93 | rel_part_weight = torch.sum(torch.sum(mask2.gt(0.5).float(), 2, True), 3, True) / \ 94 | torch.sum(mask2.gt(0.5).float()) 95 | part_iou = torch.sum(torch.sum((mask1.gt(0.5) & mask2.gt(0.5)).float(), 2, True), 3, True) / \ 96 | torch.sum(torch.sum((mask1.gt(0.5) | mask2.gt(0.5)).float(), 2, True), 3, True) 97 | weighted_iou = torch.sum(torch.mul(rel_part_weight, part_iou)).item() 98 | 99 | return weighted_iou 100 | 101 | @classmethod 102 | def label_transfer_accuracy(cls, mask1, mask2): 103 | r"""LT-ACC measures the overlap with emphasis on the background class""" 104 | return torch.mean((mask1.gt(0.5) == mask2.gt(0.5)).double()).item() 105 | 106 | @classmethod 107 | def pts2ptstr(cls, pts): 108 | r"""Convert tensor of points to string""" 109 | x_str = str(list(pts[0].cpu().numpy())) 110 | x_str = x_str[1:len(x_str)-1] 111 | y_str = str(list(pts[1].cpu().numpy())) 112 | y_str = y_str[1:len(y_str)-1] 113 | 114 | return x_str, y_str 115 | 116 | @classmethod 117 | def pts2mask(cls, x_pts, y_pts, shape): 118 | r"""Builds a binary mask tensor based on given xy-points""" 119 | x_idx, y_idx = draw.polygon(x_pts, y_pts, shape) 120 | mask = np.zeros(shape, dtype=bool) 121 | mask[x_idx, y_idx] = True 122 | 123 | return mask 124 | 125 | @classmethod 126 | def ptstr2mask(cls, x_str, y_str, out_h, out_w): 127 | r"""Convert xy-point mask (string) to tensor mask""" 128 | x_pts = np.fromstring(x_str, sep=',') 129 | y_pts = np.fromstring(y_str, sep=',') 130 | mask_np = cls.pts2mask(y_pts, x_pts, [out_h, out_w]) 131 | mask = torch.tensor(mask_np.astype(np.float32)).unsqueeze(0).unsqueeze(0).float() 132 | 133 | return mask 134 | --------------------------------------------------------------------------------
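To make the PCK test in `classify_prd` above concrete, here is a minimal standalone sketch of the same thresholding rule on toy keypoints (the tensors, `alpha` and `pckthres` values below are hypothetical, chosen only for illustration):

```python
import torch

# Toy 2xN keypoint tensors: predicted vs. ground-truth target keypoints.
prd_kps = torch.tensor([[10., 50., 90.], [20., 60., 100.]])
trg_kps = torch.tensor([[12., 58., 90.], [21., 66., 140.]])

alpha = 0.1                    # PCK tolerance ratio (cls.alpha above)
pckthres = torch.tensor(100.)  # hypothetical max(bbox_w, bbox_h) or image side

# A prediction is "correct" if its L2 error is within alpha * threshold.
l2dist = (prd_kps - trg_kps).pow(2).sum(dim=0).pow(0.5)  # [2.24, 10.0, 40.0]
correct = torch.le(l2dist, pckthres * alpha)             # [True, True, False]
print('PCK: %.1f%%' % (correct.float().mean().item() * 100))  # PCK: 66.7%
```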
/common/logger.py: -------------------------------------------------------------------------------- 1 | r"""Logging""" 2 | import datetime 3 | import logging 4 | import os 5 | 6 | from tensorboardX import SummaryWriter 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | import numpy as np 10 | import torch 11 | 12 | 13 | class Logger: 14 | r"""Writes results of training/testing""" 15 | @classmethod 16 | def initialize(cls, args): 17 | logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S') 18 | logpath = args.logpath 19 | 20 | cls.logpath = os.path.join('logs', logpath + logtime + '.log') 21 | cls.benchmark = args.benchmark 22 | os.makedirs(cls.logpath) 23 | 24 | logging.basicConfig(filemode='w', 25 | filename=os.path.join(cls.logpath, 'log.txt'), 26 | level=logging.INFO, 27 | format='%(message)s', 28 | datefmt='%m-%d %H:%M:%S') 29 | 30 | # Console log config 31 | console = logging.StreamHandler() 32 | console.setLevel(logging.INFO) 33 | formatter = logging.Formatter('%(message)s') 34 | console.setFormatter(formatter) 35 | logging.getLogger('').addHandler(console) 36 | 37 | # Tensorboard writer 38 | cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs')) 39 | 40 | # Log arguments 41 | logging.info('\n+=========== Dynamic Hyperpixel Flow ============+') 42 | for arg_key in args.__dict__: 43 | logging.info('| %20s: %-24s |' % (arg_key, str(args.__dict__[arg_key]))) 44 | logging.info('+================================================+\n') 45 | 46 | @classmethod 47 | def info(cls, msg): 48 | r"""Writes message to .txt""" 49 | logging.info(msg) 50 | 51 | @classmethod 52 | def save_model(cls, model, epoch, val_pck): 53 | torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt')) 54 | cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck)) 55 | 56 | @classmethod 57 | def visualize_selection(cls, catwise_sel): 58 | r"""Visualize (class-wise) layer selection frequency""" 59 | if cls.benchmark == 'pfpascal': 60 | sort_ids = [17, 8, 10, 19, 4, 15, 0, 3, 6, 5, 18, 13, 1, 14, 12, 2, 11, 7, 16, 9] 61 | elif cls.benchmark == 'pfwillow': 62 | sort_ids = np.arange(10) 63 | elif cls.benchmark == 'caltech': 64 | sort_ids = np.arange(101) 65 | elif cls.benchmark == 'spair': 66 | sort_ids = np.arange(18) 67 | 68 | for key in catwise_sel: 69 | catwise_sel[key] = torch.stack(catwise_sel[key]).mean(dim=0).cpu().numpy() 70 | 71 | category = np.array(list(catwise_sel.keys()))[sort_ids] 72 | values = np.array(list(catwise_sel.values()))[sort_ids] 73 | cols = list(range(values.shape[1])) 74 | df = pd.DataFrame(values, index=category, columns=cols) 75 | 76 | plt.pcolor(df, vmin=0.0, vmax=1.0) 77 | plt.gca().set_aspect('equal') 78 | plt.yticks(np.arange(0.5, len(df.index), 1), df.index) 79 | plt.xticks(np.arange(0.5, len(df.columns), 5), df.columns[::5]) 80 | plt.tight_layout() 81 | 82 | plt.savefig('%s/selected_layers.jpg' % cls.logpath) 83 | 84 | 85 | class AverageMeter: 86 | r"""Stores loss, evaluation results, selected layers""" 87 | def __init__(self, benchmark): 88 | r"""Constructor of AverageMeter""" 89 | if benchmark == 'caltech': 90 | self.buffer_keys = ['ltacc', 'iou'] 91 | else: 92 | self.buffer_keys = ['pck'] 93 | 94 | self.buffer = {} 95 | for key in self.buffer_keys: 96 | self.buffer[key] = [] 97 | self.sel_buffer = {} 98 | 99 | self.loss_buffer = [] 100 | 101 | def update(self, eval_result, layer_sel, category, loss=None): 102 | for key in self.buffer_keys: 103 | self.buffer[key] += eval_result[key] 104 | 105 | for sel, cls in zip(layer_sel, category): 106 | if self.sel_buffer.get(cls) is None: 107 | self.sel_buffer[cls] = [] 108 | self.sel_buffer[cls] += [sel] 109 | 110 | if loss is not None: 111 | self.loss_buffer.append(loss) 112 | 113 | def write_result(self, split, epoch=-1): 114 | msg = '\n*** %s ' % split 115 | msg += '[@Epoch %02d] ' % epoch if epoch > -1 else '' 116 | 117 | if len(self.loss_buffer) > 0: 118 | msg += 'Loss: %5.2f ' % (sum(self.loss_buffer) / len(self.loss_buffer)) 119 | 120 | for key in self.buffer_keys: 121 | msg += '%s: %6.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key])) 122 | msg += '***\n' 123 | Logger.info(msg) 124 | 125 | def write_process(self, batch_idx, datalen, epoch=-1): 126 | msg = '[Epoch: %02d] ' % epoch if epoch > -1 else '' 127 | msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen) 128 | if len(self.loss_buffer) > 0: 129 | msg += 'Loss: %6.2f ' % self.loss_buffer[-1] 130 | msg += 'Avg Loss: %6.5f ' % (sum(self.loss_buffer) / len(self.loss_buffer)) 131 | 132 | for key in self.buffer_keys:
133 | msg += 'Avg %s: %6.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key])) 134 | Logger.info(msg) 135 | -------------------------------------------------------------------------------- /common/supervision.py: -------------------------------------------------------------------------------- 1 | r"""Two different strategies of weak/strong supervision""" 2 | from abc import ABC, abstractmethod 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from model.objective import Objective 8 | 9 | 10 | class SupervisionStrategy(ABC): 11 | r"""Different strategies for training supervision""" 12 | @abstractmethod 13 | def get_image_pair(self, batch, *args): 14 | pass 15 | 16 | @abstractmethod 17 | def get_correlation(self, correlation_matrix): 18 | pass 19 | 20 | @abstractmethod 21 | def compute_loss(self, correlation_matrix, *args): 22 | pass 23 | 24 | 25 | class StrongSupStrategy(SupervisionStrategy): 26 | def get_image_pair(self, batch, *args): 27 | r"""Returns (semantically related) pairs for strongly-supervised training""" 28 | return batch['src_img'], batch['trg_img'] 29 | 30 | def get_correlation(self, correlation_matrix): 31 | r"""Returns correlation matrices of 'ALL PAIRS' in a batch""" 32 | return correlation_matrix.clone().detach() 33 | 34 | def compute_loss(self, correlation_matrix, *args): 35 | r"""Strongly-supervised matching loss (L_{match})""" 36 | easy_match = args[0]['easy_match'] 37 | hard_match = args[0]['hard_match'] 38 | layer_sel = args[1] 39 | batch = args[2] 40 | 41 | loss_cre = Objective.weighted_cross_entropy(correlation_matrix, easy_match, hard_match, batch) 42 | loss_sel = Objective.layer_selection_loss(layer_sel) 43 | loss_net = loss_cre + loss_sel 44 | 45 | return loss_net 46 | 47 | 48 | class WeakSupStrategy(SupervisionStrategy): 49 | def get_image_pair(self, batch, *args): 50 | r"""Forms positive/negative image pairs for weakly-supervised training""" 51 | training = args[0] 52 | self.bsz = len(batch['src_img']) 53 | 54 | if training: 55 | shifted_idx = np.roll(np.arange(self.bsz), -1) 56 | trg_img_neg = batch['trg_img'][shifted_idx].clone() 57 | trg_cls_neg = batch['category_id'][shifted_idx].clone() 58 | neg_subidx = (batch['category_id'] - trg_cls_neg) != 0 59 | 60 | src_img = torch.cat([batch['src_img'], batch['src_img'][neg_subidx]], dim=0) 61 | trg_img = torch.cat([batch['trg_img'], trg_img_neg[neg_subidx]], dim=0) 62 | self.num_negatives = neg_subidx.sum() 63 | else: 64 | src_img, trg_img = batch['src_img'], batch['trg_img'] 65 | self.num_negatives = 0 66 | 67 | return src_img, trg_img 68 | 69 | def get_correlation(self, correlation_matrix): 70 | r"""Returns correlation matrices of 'POSITIVE PAIRS' in a batch""" 71 | return correlation_matrix[:self.bsz].clone().detach() 72 | 73 | def compute_loss(self, correlation_matrix, *args): 74 | r"""Weakly-supervised matching loss (L_{match})""" 75 | layer_sel = args[1] 76 | loss_pos = Objective.information_entropy(correlation_matrix[:self.bsz]) 77 | loss_neg = Objective.information_entropy(correlation_matrix[self.bsz:]) if self.num_negatives > 0 else 1.0 78 | loss_sel = Objective.layer_selection_loss(layer_sel) 79 | loss_net = (loss_pos / loss_neg) + loss_sel 80 | 81 | return loss_net 82 | --------------------------------------------------------------------------------
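The rolled-batch negative sampling in `WeakSupStrategy.get_image_pair` is easiest to see on a toy batch; a minimal sketch of the same scheme (the category ids below are made up for illustration):

```python
import numpy as np
import torch

category_id = torch.tensor([3, 3, 7, 2])   # toy per-pair class labels
bsz = len(category_id)

# Pair each source with the *next* sample's target by rolling the batch.
shifted_idx = np.roll(np.arange(bsz), -1)  # [1, 2, 3, 0]
trg_cls_neg = category_id[shifted_idx]     # [3, 7, 2, 3]

# Keep only rolled pairs whose classes actually differ: valid negatives.
neg_subidx = (category_id - trg_cls_neg) != 0   # [False, True, True, True]
print(neg_subidx.sum().item(), 'negative pairs from a batch of', bsz)  # 3 of 4
```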
/common/utils.py: -------------------------------------------------------------------------------- 1 | r"""Some helper functions""" 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def fix_randseed(seed): 9 | r"""Fixes random seed for reproducibility""" 10 | random.seed(seed) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | torch.backends.cudnn.benchmark = False 16 | torch.backends.cudnn.deterministic = True 17 | 18 | 19 | def mean(x): 20 | r"""Computes average of a list""" 21 | return sum(x) / len(x) if len(x) > 0 else 0.0 22 | 23 | 24 | def where(predicate): 25 | r"""Predicate must be a condition on nd-tensor""" 26 | matching_indices = predicate.nonzero() 27 | if len(matching_indices) != 0: 28 | matching_indices = matching_indices.t().squeeze(0) 29 | return matching_indices 30 | -------------------------------------------------------------------------------- /data/caltech.py: -------------------------------------------------------------------------------- 1 | r"""Caltech-101 dataset""" 2 | import os 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import torch 7 | 8 | from .dataset import CorrespondenceDataset 9 | 10 | 11 | class CaltechDataset(CorrespondenceDataset): 12 | r"""Inherits CorrespondenceDataset""" 13 | def __init__(self, benchmark, datapath, thres, device, split): 14 | r"""Caltech-101 dataset constructor""" 15 | super(CaltechDataset, self).__init__(benchmark, datapath, thres, device, split) 16 | 17 | self.train_data = pd.read_csv(self.spt_path) 18 | self.src_imnames = np.array(self.train_data.iloc[:, 0]) 19 | self.trg_imnames = np.array(self.train_data.iloc[:, 1]) 20 | self.src_kps = self.train_data.iloc[:, 3:5] 21 | self.trg_kps = self.train_data.iloc[:, 5:] 22 | self.cls = ['Faces', 'Faces_easy', 'Leopards', 'Motorbikes', 'accordion', 'airplanes', 23 | 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 24 | 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 25 | 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 26 | 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 27 | 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 28 | 'elephant', 'emu', 'euphonium', 'ewer', 'ferry', 'flamingo', 'flamingo_head', 29 | 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 30 | 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 31 | 'ketch', 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 32 | 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', 'pagoda', 33 | 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 34 | 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 35 | 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 36 | 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 37 | 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang'] 38 | self.cls_ids = self.train_data.iloc[:, 2].values.astype('int') - 1 39 | self.src_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.src_imnames)) 40 | self.trg_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.trg_imnames)) 41 | 42 | def __getitem__(self, idx): 43 | r"""Constructs and returns a batch for Caltech-101 dataset""" 44 | return super(CaltechDataset, self).__getitem__(idx) 45 | 46 | def get_pckthres(self, batch): 47 | r"""No PCK measure for Caltech-101 dataset""" 48 | return None 49 | 50 | def get_points(self, pts, idx, org_imsize): 51 | r"""Returns mask-points of an image""" 52 | x_pts = torch.tensor(list(map(lambda pt: float(pt), pts[pts.columns[0]][idx].split(',')))) 53 | y_pts
= torch.tensor(list(map(lambda pt: float(pt), pts[pts.columns[1]][idx].split(',')))) 54 | 55 | x_pts *= (self.imside / org_imsize[0]) 56 | y_pts *= (self.imside / org_imsize[1]) 57 | 58 | n_pts = x_pts.size(0) 59 | if n_pts > self.max_pts: 60 | raise Exception('The number of keypoints is above threshold: %d' % n_pts) 61 | pad_pts = torch.zeros((2, self.max_pts - n_pts)) - 1 62 | 63 | kps = torch.cat([torch.stack([x_pts, y_pts]), pad_pts], dim=1).to(self.device) 64 | 65 | return kps, n_pts 66 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | r"""Superclass for semantic correspondence datasets""" 2 | import os 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from PIL import Image 7 | import torch 8 | 9 | from model.base.geometry import Geometry 10 | 11 | 12 | class CorrespondenceDataset(Dataset): 13 | r"""Parent class of PFPascal, PFWillow, Caltech, and SPair""" 14 | def __init__(self, benchmark, datapath, thres, device, split): 15 | r"""CorrespondenceDataset constructor""" 16 | super(CorrespondenceDataset, self).__init__() 17 | 18 | # {Directory name, Layout path, Image path, Annotation path, PCK threshold} 19 | self.metadata = { 20 | 'pfwillow': ('PF-WILLOW', 21 | 'test_pairs.csv', 22 | '', 23 | '', 24 | 'bbox'), 25 | 'pfpascal': ('PF-PASCAL', 26 | '_pairs.csv', 27 | 'JPEGImages', 28 | 'Annotations', 29 | 'img'), 30 | 'caltech': ('Caltech-101', 31 | 'test_pairs_caltech_with_category.csv', 32 | '101_ObjectCategories', 33 | '', 34 | ''), 35 | 'spair': ('SPair-71k', 36 | 'Layout/large', 37 | 'JPEGImages', 38 | 'PairAnnotation', 39 | 'bbox') 40 | } 41 | 42 | # Directory path for train, val, or test splits 43 | base_path = os.path.join(os.path.abspath(datapath), self.metadata[benchmark][0]) 44 | if benchmark == 'pfpascal': 45 | self.spt_path = os.path.join(base_path, split+'_pairs.csv') 46 | elif benchmark == 'spair': 47 | self.spt_path = os.path.join(base_path, self.metadata[benchmark][1], split+'.txt') 48 | else: 49 | self.spt_path = os.path.join(base_path, self.metadata[benchmark][1]) 50 | 51 | # Directory path for images 52 | self.img_path = os.path.join(base_path, self.metadata[benchmark][2]) 53 | 54 | # Directory path for annotations 55 | if benchmark == 'spair': 56 | self.ann_path = os.path.join(base_path, self.metadata[benchmark][3], split) 57 | else: 58 | self.ann_path = os.path.join(base_path, self.metadata[benchmark][3]) 59 | 60 | # Miscellaneous 61 | if benchmark == 'caltech': 62 | self.max_pts = 400 63 | else: 64 | self.max_pts = 40 65 | self.split = split 66 | self.device = device 67 | self.imside = 240 68 | self.benchmark = benchmark 69 | self.range_ts = torch.arange(self.max_pts) 70 | self.thres = self.metadata[benchmark][4] if thres == 'auto' else thres 71 | self.transform = transforms.Compose([transforms.Resize((self.imside, self.imside)), 72 | transforms.ToTensor(), 73 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 74 | std=[0.229, 0.224, 0.225])]) 75 | 76 | # To get initialized in subclass constructors 77 | self.train_data = [] 78 | self.src_imnames = [] 79 | self.trg_imnames = [] 80 | self.cls = [] 81 | self.cls_ids = [] 82 | self.src_kps = [] 83 | self.trg_kps = [] 84 | 85 | def __len__(self): 86 | r"""Returns the number of pairs""" 87 | return len(self.train_data) 88 | 89 | def __getitem__(self, idx): 90 | r"""Constructs and return a batch""" 91 | 92 | # Image names 93 | batch = dict() 94 | 
batch['src_imname'] = self.src_imnames[idx] 95 | batch['trg_imname'] = self.trg_imnames[idx] 96 | 97 | # Class of instances in the images 98 | batch['category_id'] = self.cls_ids[idx] 99 | batch['category'] = self.cls[batch['category_id']] 100 | 101 | # Image as numpy (original width, original height) 102 | src_pil = self.get_image(self.src_imnames, idx) 103 | trg_pil = self.get_image(self.trg_imnames, idx) 104 | batch['src_imsize'] = src_pil.size 105 | batch['trg_imsize'] = trg_pil.size 106 | 107 | # Image as tensor 108 | batch['src_img'] = self.transform(src_pil).to(self.device) 109 | batch['trg_img'] = self.transform(trg_pil).to(self.device) 110 | 111 | # Key-points (re-scaled) 112 | batch['src_kps'], num_pts = self.get_points(self.src_kps, idx, src_pil.size) 113 | batch['trg_kps'], _ = self.get_points(self.trg_kps, idx, trg_pil.size) 114 | batch['n_pts'] = torch.tensor(num_pts) 115 | 116 | # The number of pairs in training split 117 | batch['datalen'] = len(self.train_data) 118 | 119 | return batch 120 | 121 | def get_image(self, imnames, idx): 122 | r"""Reads PIL image from path""" 123 | path = os.path.join(self.img_path, imnames[idx]) 124 | return Image.open(path).convert('RGB') 125 | 126 | def get_pckthres(self, batch, imsize): 127 | r"""Computes PCK threshold""" 128 | if self.thres == 'bbox': 129 | bbox = batch['trg_bbox'].clone() 130 | bbox_w = (bbox[2] - bbox[0]) 131 | bbox_h = (bbox[3] - bbox[1]) 132 | pckthres = torch.max(bbox_w, bbox_h) 133 | elif self.thres == 'img': 134 | imsize_t = batch['trg_img'].size() 135 | pckthres = torch.tensor(max(imsize_t[1], imsize_t[2])) 136 | else: 137 | raise Exception('Invalid pck threshold type: %s' % self.thres) 138 | return pckthres.float().to(self.device) 139 | 140 | def get_points(self, pts_list, idx, org_imsize): 141 | r"""Returns key-points of an image resized to (240, 240)""" 142 | xy, n_pts = pts_list[idx].size() 143 | pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 1 144 | x_crds = pts_list[idx][0] * (self.imside / org_imsize[0]) 145 | y_crds = pts_list[idx][1] * (self.imside / org_imsize[1]) 146 | kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1).to(self.device) 147 | 148 | return kps, n_pts 149 | 150 | def match_idx(self, kps, n_pts): 151 | r"""Samples the nearest feature (receptive field) indices""" 152 | nearest_idx = find_knn(Geometry.rf_center, kps.t()) 153 | nearest_idx -= (self.range_ts >= n_pts).to(self.device).long() 154 | 155 | return nearest_idx 156 | 157 | 158 | def find_knn(db_vectors, qr_vectors): 159 | r"""Finds K-nearest neighbors (Euclidean distance)""" 160 | db = db_vectors.unsqueeze(1).repeat(1, qr_vectors.size(0), 1) 161 | qr = qr_vectors.unsqueeze(0).repeat(db_vectors.size(0), 1, 1) 162 | dist = (db - qr).pow(2).sum(2).pow(0.5).t() 163 | _, nearest_idx = dist.min(dim=1) 164 | 165 | return nearest_idx 166 | --------------------------------------------------------------------------------
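`find_knn` materializes all pairwise Euclidean distances with explicit `repeat`, which keeps it compatible with the PyTorch 1.0.1 pin above; on newer PyTorch versions the same nearest-neighbour lookup can be written with `torch.cdist`, as in this sketch (the shapes below are illustrative only):

```python
import torch

db_vectors = torch.rand(16, 2)  # e.g. receptive-field centers, shape (N, 2)
qr_vectors = torch.rand(5, 2)   # e.g. query keypoints, shape (M, 2)

# Pairwise Euclidean distances of shape (M, N), then the
# index of the nearest database vector for each query.
dist = torch.cdist(qr_vectors, db_vectors)
nearest_idx = dist.argmin(dim=1)
print(nearest_idx.shape)  # torch.Size([5])
```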
/data/download.py: -------------------------------------------------------------------------------- 1 | r"""Functions to download semantic correspondence datasets""" 2 | import tarfile 3 | import os 4 | 5 | import requests 6 | 7 | from . import pfpascal 8 | from . import pfwillow 9 | from . import caltech 10 | from . import spair 11 | 12 | 13 | def load_dataset(benchmark, datapath, thres, device, split='test'): 14 | r"""Instantiates desired correspondence dataset""" 15 | correspondence_benchmark = { 16 | 'pfpascal': pfpascal.PFPascalDataset, 17 | 'pfwillow': pfwillow.PFWillowDataset, 18 | 'caltech': caltech.CaltechDataset, 19 | 'spair': spair.SPairDataset, 20 | } 21 | 22 | dataset = correspondence_benchmark.get(benchmark) 23 | if dataset is None: 24 | raise Exception('Invalid benchmark dataset %s.' % benchmark) 25 | 26 | return dataset(benchmark, datapath, thres, device, split) 27 | 28 | 29 | def download_from_google(token_id, filename): 30 | r"""Downloads desired filename from Google drive""" 31 | print('Downloading %s ...' % os.path.basename(filename)) 32 | 33 | url = 'https://docs.google.com/uc?export=download' 34 | destination = filename + '.tar.gz' 35 | session = requests.Session() 36 | 37 | response = session.get(url, params={'id': token_id, 'confirm':'t'}, stream=True) 38 | token = get_confirm_token(response) 39 | 40 | if token: 41 | params = {'id': token_id, 'confirm': token} 42 | response = session.get(url, params=params, stream=True) 43 | save_response_content(response, destination) 44 | file = tarfile.open(destination, 'r:gz') 45 | 46 | print("Extracting %s ..." % destination) 47 | file.extractall(filename) 48 | file.close() 49 | 50 | os.remove(destination) 51 | os.rename(filename, filename + '_tmp') 52 | os.rename(os.path.join(filename + '_tmp', os.path.basename(filename)), filename) 53 | os.rmdir(filename+'_tmp') 54 | 55 | 56 | def get_confirm_token(response): 57 | r"""Retrieves confirm token""" 58 | for key, value in response.cookies.items(): 59 | if key.startswith('download_warning'): 60 | return value 61 | 62 | return None 63 | 64 | 65 | def save_response_content(response, destination): 66 | r"""Saves the response to the destination""" 67 | chunk_size = 32768 68 | 69 | with open(destination, "wb") as file: 70 | for chunk in response.iter_content(chunk_size): 71 | if chunk: 72 | file.write(chunk) 73 | 74 | 75 | def download_dataset(datapath, benchmark): 76 | r"""Downloads semantic correspondence benchmark dataset from Google drive""" 77 | if not os.path.isdir(datapath): 78 | os.mkdir(datapath) 79 | 80 | file_data = { 81 | 'pfwillow': ('1tDP0y8RO5s45L-vqnortRaieiWENQco_', 'PF-WILLOW'), 82 | 'pfpascal': ('1OOwpGzJnTsFXYh-YffMQ9XKM_Kl_zdzg', 'PF-PASCAL'), 83 | 'caltech': ('1IV0E5sJ6xSdDyIvVSTdZjPHELMwGzsMn', 'Caltech-101'), 84 | 'spair': ('1KSvB0k2zXA06ojWNvFjBv0Ake426Y76k', 'SPair-71k') 85 | # 'spair': ('1s73NVEFPro260H1tXxCh1ain7oApR8of', 'SPair-71k') old version 86 | } 87 | 88 | file_id, filename = file_data[benchmark] 89 | abs_filepath = os.path.join(datapath, filename) 90 | 91 | if not os.path.isdir(abs_filepath): 92 | download_from_google(file_id, abs_filepath) 93 | -------------------------------------------------------------------------------- /data/pfpascal.py: -------------------------------------------------------------------------------- 1 | r"""PF-PASCAL dataset""" 2 | import os 3 | 4 | import scipy.io as sio 5 | import pandas as pd 6 | import numpy as np 7 | import torch 8 | 9 | from .dataset import CorrespondenceDataset 10 | 11 | 12 | class PFPascalDataset(CorrespondenceDataset): 13 | r"""Inherits CorrespondenceDataset""" 14 | def __init__(self, benchmark, datapath, thres, device, split): 15 | r"""PF-PASCAL dataset constructor""" 16 | super(PFPascalDataset, self).__init__(benchmark, datapath, thres, device, split) 17 | 18 | self.train_data = pd.read_csv(self.spt_path) 19 |
self.src_imnames = np.array(self.train_data.iloc[:, 0]) 20 | self.trg_imnames = np.array(self.train_data.iloc[:, 1]) 21 | self.cls = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 22 | 'bus', 'car', 'cat', 'chair', 'cow', 23 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 24 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 25 | self.cls_ids = self.train_data.iloc[:, 2].values.astype('int') - 1 26 | 27 | if split == 'trn': 28 | self.flip = self.train_data.iloc[:, 3].values.astype('int') 29 | self.src_kps = [] 30 | self.trg_kps = [] 31 | self.src_bbox = [] 32 | self.trg_bbox = [] 33 | for src_imname, trg_imname, cls in zip(self.src_imnames, self.trg_imnames, self.cls_ids): 34 | src_anns = os.path.join(self.ann_path, self.cls[cls], 35 | os.path.basename(src_imname))[:-4] + '.mat' 36 | trg_anns = os.path.join(self.ann_path, self.cls[cls], 37 | os.path.basename(trg_imname))[:-4] + '.mat' 38 | 39 | src_kp = torch.tensor(read_mat(src_anns, 'kps')).float() 40 | trg_kp = torch.tensor(read_mat(trg_anns, 'kps')).float() 41 | src_box = torch.tensor(read_mat(src_anns, 'bbox')[0].astype(float)) 42 | trg_box = torch.tensor(read_mat(trg_anns, 'bbox')[0].astype(float)) 43 | 44 | src_kps = [] 45 | trg_kps = [] 46 | for src_kk, trg_kk in zip(src_kp, trg_kp): 47 | if len(torch.isnan(src_kk).nonzero()) != 0 or \ 48 | len(torch.isnan(trg_kk).nonzero()) != 0: 49 | continue 50 | else: 51 | src_kps.append(src_kk) 52 | trg_kps.append(trg_kk) 53 | self.src_kps.append(torch.stack(src_kps).t()) 54 | self.trg_kps.append(torch.stack(trg_kps).t()) 55 | self.src_bbox.append(src_box) 56 | self.trg_bbox.append(trg_box) 57 | 58 | self.src_imnames = list(map(lambda x: os.path.basename(x), self.src_imnames)) 59 | self.trg_imnames = list(map(lambda x: os.path.basename(x), self.trg_imnames)) 60 | 61 | def __getitem__(self, idx): 62 | r"""Constructs and return a batch for PF-PASCAL dataset""" 63 | batch = super(PFPascalDataset, self).__getitem__(idx) 64 | 65 | # Object bounding-box (resized following self.imside) 66 | batch['src_bbox'] = self.get_bbox(self.src_bbox, idx, batch['src_imsize']) 67 | batch['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, batch['trg_imsize']) 68 | batch['pckthres'] = self.get_pckthres(batch, batch['trg_imsize']) 69 | 70 | # Horizontal flipping key-points during training 71 | if self.split == 'trn' and self.flip[idx]: 72 | self.horizontal_flip(batch) 73 | batch['flip'] = 1 74 | else: 75 | batch['flip'] = 0 76 | 77 | batch['src_kpidx'] = self.match_idx(batch['src_kps'], batch['n_pts']) 78 | batch['trg_kpidx'] = self.match_idx(batch['trg_kps'], batch['n_pts']) 79 | 80 | return batch 81 | 82 | def get_bbox(self, bbox_list, idx, imsize): 83 | r"""Returns object bounding-box""" 84 | bbox = bbox_list[idx].clone() 85 | bbox[0::2] *= (self.imside / imsize[0]) 86 | bbox[1::2] *= (self.imside / imsize[1]) 87 | return bbox.to(self.device) 88 | 89 | def horizontal_flip(self, batch): 90 | tmp = batch['src_bbox'][0].clone() 91 | batch['src_bbox'][0] = batch['src_img'].size(2) - batch['src_bbox'][2] 92 | batch['src_bbox'][2] = batch['src_img'].size(2) - tmp 93 | 94 | tmp = batch['trg_bbox'][0].clone() 95 | batch['trg_bbox'][0] = batch['trg_img'].size(2) - batch['trg_bbox'][2] 96 | batch['trg_bbox'][2] = batch['trg_img'].size(2) - tmp 97 | 98 | batch['src_kps'][0][:batch['n_pts']] = batch['src_img'].size(2) - batch['src_kps'][0][:batch['n_pts']] 99 | batch['trg_kps'][0][:batch['n_pts']] = batch['trg_img'].size(2) - batch['trg_kps'][0][:batch['n_pts']] 100 | 101 | batch['src_img'] = 
torch.flip(batch['src_img'], dims=(2,)) 102 | batch['trg_img'] = torch.flip(batch['trg_img'], dims=(2,)) 103 | 104 | 105 | def read_mat(path, obj_name): 106 | r"""Reads specified objects from Matlab data file, (.mat)""" 107 | mat_contents = sio.loadmat(path) 108 | mat_obj = mat_contents[obj_name] 109 | 110 | return mat_obj 111 | -------------------------------------------------------------------------------- /data/pfwillow.py: -------------------------------------------------------------------------------- 1 | r"""PF-WILLOW dataset""" 2 | import os 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import torch 7 | 8 | from .dataset import CorrespondenceDataset 9 | 10 | 11 | class PFWillowDataset(CorrespondenceDataset): 12 | r"""Inherits CorrespondenceDataset""" 13 | def __init__(self, benchmark, datapath, thres, device, split): 14 | r"""PF-WILLOW dataset constructor""" 15 | super(PFWillowDataset, self).__init__(benchmark, datapath, thres, device, split) 16 | 17 | self.train_data = pd.read_csv(self.spt_path) 18 | self.src_imnames = np.array(self.train_data.iloc[:, 0]) 19 | self.trg_imnames = np.array(self.train_data.iloc[:, 1]) 20 | self.src_kps = self.train_data.iloc[:, 2:22].values 21 | self.trg_kps = self.train_data.iloc[:, 22:].values 22 | self.cls = ['car(G)', 'car(M)', 'car(S)', 'duck(S)', 23 | 'motorbike(G)', 'motorbike(M)', 'motorbike(S)', 24 | 'winebottle(M)', 'winebottle(wC)', 'winebottle(woC)'] 25 | self.cls_ids = list(map(lambda names: self.cls.index(names.split('/')[1]), self.src_imnames)) 26 | self.src_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.src_imnames)) 27 | self.trg_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.trg_imnames)) 28 | 29 | def __getitem__(self, idx): 30 | r"""Constructs and return a batch for PF-WILLOW dataset""" 31 | batch = super(PFWillowDataset, self).__getitem__(idx) 32 | batch['pckthres'] = self.get_pckthres(batch).to(self.device) 33 | 34 | batch['src_kpidx'] = self.match_idx(batch['src_kps'], batch['n_pts']) 35 | batch['trg_kpidx'] = self.match_idx(batch['trg_kps'], batch['n_pts']) 36 | 37 | return batch 38 | 39 | def get_pckthres(self, batch): 40 | r"""Computes PCK threshold""" 41 | if self.thres == 'bbox': 42 | return max(batch['trg_kps'].max(1)[0] - batch['trg_kps'].min(1)[0]).clone() 43 | elif self.thres == 'img': 44 | return torch.tensor(max(batch['trg_img'].size()[1], batch['trg_img'].size()[2])) 45 | else: 46 | raise Exception('Invalid pck evaluation level: %s' % self.thres) 47 | 48 | def get_points(self, pts_list, idx, org_imsize): 49 | r"""Returns key-points of an image""" 50 | point_coords = pts_list[idx, :].reshape(2, 10) 51 | point_coords = torch.tensor(point_coords.astype(np.float32)) 52 | xy, n_pts = point_coords.size() 53 | pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 1 54 | x_crds = point_coords[0] * (self.imside / org_imsize[0]) 55 | y_crds = point_coords[1] * (self.imside / org_imsize[1]) 56 | kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1).to(self.device) 57 | 58 | return kps, n_pts 59 | -------------------------------------------------------------------------------- /data/spair.py: -------------------------------------------------------------------------------- 1 | r"""SPair-71k dataset""" 2 | import json 3 | import glob 4 | import os 5 | 6 | from PIL import Image 7 | import torch 8 | 9 | from .dataset import CorrespondenceDataset 10 | 11 | 12 | class SPairDataset(CorrespondenceDataset): 13 | r"""Inherits CorrespondenceDataset""" 14 | def __init__(self, 
benchmark, datapath, thres, device, split): 15 | r"""SPair-71k dataset constructor""" 16 | super(SPairDataset, self).__init__(benchmark, datapath, thres, device, split) 17 | 18 | self.train_data = open(self.spt_path).read().split('\n') 19 | self.train_data = self.train_data[:len(self.train_data) - 1] 20 | self.src_imnames = list(map(lambda x: x.split('-')[1] + '.jpg', self.train_data)) 21 | self.trg_imnames = list(map(lambda x: x.split('-')[2].split(':')[0] + '.jpg', self.train_data)) 22 | self.cls = os.listdir(self.img_path) 23 | self.cls.sort() 24 | 25 | anntn_files = [] 26 | for data_name in self.train_data: 27 | anntn_files.append(glob.glob('%s/%s.json' % (self.ann_path, data_name))[0]) 28 | anntn_files = list(map(lambda x: json.load(open(x)), anntn_files)) 29 | self.src_kps = list(map(lambda x: torch.tensor(x['src_kps']).t().float(), anntn_files)) 30 | self.trg_kps = list(map(lambda x: torch.tensor(x['trg_kps']).t().float(), anntn_files)) 31 | self.src_bbox = list(map(lambda x: torch.tensor(x['src_bndbox']).float(), anntn_files)) 32 | self.trg_bbox = list(map(lambda x: torch.tensor(x['trg_bndbox']).float(), anntn_files)) 33 | self.cls_ids = list(map(lambda x: self.cls.index(x['category']), anntn_files)) 34 | 35 | self.vpvar = list(map(lambda x: torch.tensor(x['viewpoint_variation']), anntn_files)) 36 | self.scvar = list(map(lambda x: torch.tensor(x['scale_variation']), anntn_files)) 37 | self.trncn = list(map(lambda x: torch.tensor(x['truncation']), anntn_files)) 38 | self.occln = list(map(lambda x: torch.tensor(x['occlusion']), anntn_files)) 39 | 40 | def __getitem__(self, idx): 41 | r"""Constructs and return a batch for SPair-71k dataset""" 42 | batch = super(SPairDataset, self).__getitem__(idx) 43 | 44 | batch['src_bbox'] = self.get_bbox(self.src_bbox, idx, batch['src_imsize']) 45 | batch['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, batch['trg_imsize']) 46 | batch['pckthres'] = self.get_pckthres(batch, batch['trg_imsize']) 47 | 48 | batch['src_kpidx'] = self.match_idx(batch['src_kps'], batch['n_pts']) 49 | batch['trg_kpidx'] = self.match_idx(batch['trg_kps'], batch['n_pts']) 50 | 51 | batch['vpvar'] = self.vpvar[idx] 52 | batch['scvar'] = self.scvar[idx] 53 | batch['trncn'] = self.trncn[idx] 54 | batch['occln'] = self.occln[idx] 55 | 56 | return batch 57 | 58 | def get_image(self, img_names, idx): 59 | r"""Returns image tensor""" 60 | path = os.path.join(self.img_path, self.cls[self.cls_ids[idx]], img_names[idx]) 61 | 62 | return Image.open(path).convert('RGB') 63 | 64 | def get_bbox(self, bbox_list, idx, imsize): 65 | r"""Returns object bounding-box""" 66 | bbox = bbox_list[idx].clone() 67 | bbox[0::2] *= (self.imside / imsize[0]) 68 | bbox[1::2] *= (self.imside / imsize[1]) 69 | return bbox.to(self.device) 70 | -------------------------------------------------------------------------------- /model/base/correlation.py: -------------------------------------------------------------------------------- 1 | r"""Provides functions that creates/manipulates correlation matrices""" 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | class Correlation: 7 | @classmethod 8 | def bmm_interp(cls, src_feat, trg_feat, interp_size): 9 | r"""Performs batch-wise matrix-multiplication after interpolation""" 10 | src_feat = F.interpolate(src_feat, interp_size, mode='bilinear', align_corners=True) 11 | trg_feat = F.interpolate(trg_feat, interp_size, mode='bilinear', align_corners=True) 12 | 13 | src_feat = src_feat.view(src_feat.size(0), src_feat.size(1), -1).transpose(1, 2) 14 | 
trg_feat = trg_feat.view(trg_feat.size(0), trg_feat.size(1), -1) 15 | 16 | return torch.bmm(src_feat, trg_feat) 17 | 18 | @classmethod 19 | def mutual_nn_filter(cls, correlation_matrix): 20 | r"""Mutual nearest neighbor filtering (Rocco et al. NeurIPS'18)""" 21 | corr_src_max = torch.max(correlation_matrix, dim=2, keepdim=True)[0] 22 | corr_trg_max = torch.max(correlation_matrix, dim=1, keepdim=True)[0] 23 | corr_src_max[corr_src_max == 0] += 1e-30 24 | corr_trg_max[corr_trg_max == 0] += 1e-30 25 | 26 | corr_src = correlation_matrix / corr_src_max 27 | corr_trg = correlation_matrix / corr_trg_max 28 | 29 | return correlation_matrix * (corr_src * corr_trg) 30 | --------------------------------------------------------------------------------
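A toy run clarifies what `mutual_nn_filter` computes: each score is rescaled by its ratio to the row-wise and column-wise maxima, so mutually-best matches keep their magnitude while one-sided matches are suppressed (the matrix below is arbitrary, for illustration only):

```python
import torch

# Arbitrary 1 x 3 x 3 correlation between 3 source and 3 target positions.
c = torch.tensor([[[0.9, 0.2, 0.1],
                   [0.8, 0.5, 0.3],
                   [0.1, 0.2, 0.7]]])

corr_src_max = c.max(dim=2, keepdim=True)[0]  # best target per source (rows)
corr_trg_max = c.max(dim=1, keepdim=True)[0]  # best source per target (cols)
filtered = c * ((c / corr_src_max) * (c / corr_trg_max))

# (0, 0) is a mutual best match: weight preserved (0.9 * 1 * 1 = 0.9).
# (1, 0) is source 1's best but only target 0's second-best: downweighted
# to 0.8 * 1 * (0.8 / 0.9) ~= 0.71.
print(filtered[0])
```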
--------------------------------------------------------------------------------
/model/base/geometry.py:
--------------------------------------------------------------------------------
1 | r"""Provides functions that manipulate boxes and points"""
2 | import torch
3 | 
4 | from .correlation import Correlation
5 | 
6 | 
7 | class Geometry:
8 |     @classmethod
9 |     def initialize(cls, feat_size, device):
10 |         cls.max_pts = 400
11 |         cls.eps = 1e-30
12 |         cls.rfs = cls.receptive_fields(11, 4, feat_size).to(device)
13 |         cls.rf_center = cls.center(cls.rfs)
14 | 
15 |     @classmethod
16 |     def center(cls, box):
17 |         r"""Computes centers, (x, y), of box (N, 4)"""
18 |         x_center = box[:, 0] + (box[:, 2] - box[:, 0]) // 2
19 |         y_center = box[:, 1] + (box[:, 3] - box[:, 1]) // 2
20 |         return torch.stack((x_center, y_center)).t().to(box.device)
21 | 
22 |     @classmethod
23 |     def receptive_fields(cls, rfsz, jsz, feat_size):
24 |         r"""Returns a set of receptive fields (N, 4)"""
25 |         width = feat_size[1]
26 |         height = feat_size[0]
27 | 
28 |         feat_ids = torch.tensor(list(range(width))).repeat(1, height).t().repeat(1, 2)
29 |         feat_ids[:, 0] = torch.tensor(list(range(height))).unsqueeze(1).repeat(1, width).view(-1)
30 | 
31 |         box = torch.zeros(feat_ids.size()[0], 4)
32 |         box[:, 0] = feat_ids[:, 1] * jsz - rfsz // 2
33 |         box[:, 1] = feat_ids[:, 0] * jsz - rfsz // 2
34 |         box[:, 2] = feat_ids[:, 1] * jsz + rfsz // 2
35 |         box[:, 3] = feat_ids[:, 0] * jsz + rfsz // 2
36 | 
37 |         return box
38 | 
39 |     @classmethod
40 |     def gaussian2d(cls, side=7):
41 |         r"""Returns a 2-dimensional Gaussian filter"""
42 |         dim = [side, side]
43 | 
44 |         siz = torch.LongTensor(dim)
45 |         sig_sq = (siz.float()/2/2.354).pow(2)
46 |         siz2 = (siz-1)/2
47 | 
48 |         x_axis = torch.arange(-siz2[0], siz2[0] + 1).unsqueeze(0).expand(dim).float()
49 |         y_axis = torch.arange(-siz2[1], siz2[1] + 1).unsqueeze(1).expand(dim).float()
50 | 
51 |         gaussian = torch.exp(-(x_axis.pow(2)/2/sig_sq[0] + y_axis.pow(2)/2/sig_sq[1]))
52 |         gaussian = gaussian / gaussian.sum()
53 | 
54 |         return gaussian
55 | 
56 |     @classmethod
57 |     def neighbours(cls, box, kps):
58 |         r"""Returns boxes, in one-hot format, that cover the given keypoints"""
59 |         box_duplicate = box.unsqueeze(2).repeat(1, 1, len(kps.t())).transpose(0, 1)
60 |         kps_duplicate = kps.unsqueeze(1).repeat(1, len(box), 1)
61 | 
62 |         xmin = kps_duplicate[0].ge(box_duplicate[0])
63 |         ymin = kps_duplicate[1].ge(box_duplicate[1])
64 |         xmax = kps_duplicate[0].le(box_duplicate[2])
65 |         ymax = kps_duplicate[1].le(box_duplicate[3])
66 | 
67 |         nbr_onehot = torch.mul(torch.mul(xmin, ymin), torch.mul(xmax, ymax)).t()
68 |         n_neighbours = nbr_onehot.sum(dim=1)
69 | 
70 |         return nbr_onehot, n_neighbours
71 | 
72 |     @classmethod
73 |     def transfer_kps(cls, correlation_matrix, kps, n_pts):
74 |         r"""Transfers keypoints by nearest-neighbour assignment"""
75 |         correlation_matrix = Correlation.mutual_nn_filter(correlation_matrix)
76 | 
77 |         prd_kps = []
78 |         for ct, kpss, np in zip(correlation_matrix, kps, n_pts):
79 | 
80 |             # 1. Prepare geometries & argmax target indices
81 |             kp = kpss.narrow_copy(1, 0, np)
82 |             _, trg_argmax_idx = torch.max(ct, dim=1)
83 |             geomet = cls.rfs[:, :2].unsqueeze(0).repeat(len(kp.t()), 1, 1)
84 | 
85 |             # 2. Retrieve neighbouring source boxes that cover source keypoints
86 |             src_nbr_onehot, n_neighbours = cls.neighbours(cls.rfs, kp)
87 | 
88 |             # 3. Get displacements from source neighbouring box centers to each keypoint
89 |             src_displacements = kp.t().unsqueeze(1).repeat(1, len(cls.rfs), 1) - geomet
90 |             src_displacements = src_displacements * src_nbr_onehot.unsqueeze(2).repeat(1, 1, 2).float()
91 | 
92 |             # 4. Transfer the neighbours based on the given correlation matrix
93 |             vector_summator = torch.zeros_like(geomet)
94 |             src_idx = src_nbr_onehot.nonzero()
95 | 
96 |             trg_idx = trg_argmax_idx.index_select(dim=0, index=src_idx[:, 1])
97 |             vector_summator[src_idx[:, 0], src_idx[:, 1]] = geomet[src_idx[:, 0], trg_idx]
98 |             vector_summator += src_displacements
99 |             prd = (vector_summator.sum(dim=1) / n_neighbours.unsqueeze(1).repeat(1, 2).float()).t()
100 | 
101 |             # 5. Concatenate pad-points
102 |             pads = (torch.zeros((2, cls.max_pts - np)).to(prd.device) - 1)
103 |             prd = torch.cat([prd, pads], dim=1)
104 |             prd_kps.append(prd)
105 | 
106 |         return torch.stack(prd_kps)
107 |
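A small worked example of the receptive-field grid built above, using the rfsz=11, jsz=4 values that Geometry.initialize passes in (a toy 2x2 feature map for readability; repository root assumed on PYTHONPATH):

import torch
from model.base.geometry import Geometry

boxes = Geometry.receptive_fields(11, 4, torch.Size([2, 2]))
print(boxes)
# Rows are (xmin, ymin, xmax, ymax): the feature at grid (0, 0) maps to the
# image-space box (-5, -5, 5, 5), grid (1, 0) to (-1, -5, 9, 5), and so on.
print(Geometry.center(boxes))
# Box centers form the stride-4 grid (0, 0), (4, 0), (0, 4), (4, 4).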
"""3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, groups=2, bias=False) 23 | 24 | 25 | def conv1x1(in_planes, out_planes, stride=1): 26 | """1x1 convolution""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=2, bias=False) 28 | 29 | 30 | class Bottleneck(nn.Module): 31 | expansion = 4 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(Bottleneck, self).__init__() 35 | self.conv1 = conv1x1(inplanes, planes) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.conv2 = conv3x3(planes, planes, stride) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.conv3 = conv1x1(planes, planes * self.expansion) 40 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | identity = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | if self.downsample is not None: 60 | identity = self.downsample(x) 61 | 62 | out += identity 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Backbone(nn.Module): 69 | def __init__(self, block, layers, zero_init_residual=False): 70 | super(Backbone, self).__init__() 71 | 72 | self.inplanes = 128 73 | self.conv1 = nn.Conv2d(6, 128, kernel_size=7, stride=2, padding=3, groups=2, 74 | bias=False) 75 | self.bn1 = nn.BatchNorm2d(128) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 78 | self.layer1 = self._make_layer(block, 128, layers[0]) 79 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2) 80 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2) 81 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2) 82 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 83 | self.fc = nn.Linear(512 * block.expansion, 1000) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 88 | elif isinstance(m, nn.BatchNorm2d): 89 | nn.init.constant_(m.weight, 1) 90 | nn.init.constant_(m.bias, 0) 91 | 92 | # Zero-initialize the last BN in each residual branch, 93 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 94 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 95 | if zero_init_residual: 96 | for m in self.modules(): 97 | if isinstance(m, Bottleneck): 98 | nn.init.constant_(m.bn3.weight, 0) 99 | # elif isinstance(m, BasicBlock): 100 | # nn.init.constant_(m.bn2.weight, 0) 101 | 102 | def _make_layer(self, block, planes, blocks, stride=1): 103 | downsample = None 104 | if stride != 1 or self.inplanes != planes * block.expansion: 105 | downsample = nn.Sequential( 106 | conv1x1(self.inplanes, planes * block.expansion, stride), 107 | nn.BatchNorm2d(planes * block.expansion), 108 | ) 109 | 110 | layers = [] 111 | layers.append(block(self.inplanes, planes, stride, downsample)) 112 | self.inplanes = planes * block.expansion 113 | for _ in range(1, blocks): 114 | layers.append(block(self.inplanes, planes)) 115 | 116 | return nn.Sequential(*layers) 117 | 118 | 119 | def resnet50(pretrained=False, **kwargs): 120 | """Constructs a ResNet-50 model. 
--------------------------------------------------------------------------------
/model/base/resnet.py:
--------------------------------------------------------------------------------
1 | r"""ResNet code from PyTorch library"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 | 
6 | 
7 | __all__ = ['Backbone', 'resnet50', 'resnet101']
8 | 
9 | 
10 | model_urls = {
11 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | }
17 | 
18 | 
19 | def conv3x3(in_planes, out_planes, stride=1):
20 |     """3x3 convolution with padding"""
21 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 |                      padding=1, groups=2, bias=False)
23 | 
24 | 
25 | def conv1x1(in_planes, out_planes, stride=1):
26 |     """1x1 convolution"""
27 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=2, bias=False)
28 | 
29 | 
30 | class Bottleneck(nn.Module):
31 |     expansion = 4
32 | 
33 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
34 |         super(Bottleneck, self).__init__()
35 |         self.conv1 = conv1x1(inplanes, planes)
36 |         self.bn1 = nn.BatchNorm2d(planes)
37 |         self.conv2 = conv3x3(planes, planes, stride)
38 |         self.bn2 = nn.BatchNorm2d(planes)
39 |         self.conv3 = conv1x1(planes, planes * self.expansion)
40 |         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
41 |         self.relu = nn.ReLU(inplace=True)
42 |         self.downsample = downsample
43 |         self.stride = stride
44 | 
45 |     def forward(self, x):
46 |         identity = x
47 | 
48 |         out = self.conv1(x)
49 |         out = self.bn1(out)
50 |         out = self.relu(out)
51 | 
52 |         out = self.conv2(out)
53 |         out = self.bn2(out)
54 |         out = self.relu(out)
55 | 
56 |         out = self.conv3(out)
57 |         out = self.bn3(out)
58 | 
59 |         if self.downsample is not None:
60 |             identity = self.downsample(x)
61 | 
62 |         out += identity
63 |         out = self.relu(out)
64 | 
65 |         return out
66 | 
67 | 
68 | class Backbone(nn.Module):
69 |     def __init__(self, block, layers, zero_init_residual=False):
70 |         super(Backbone, self).__init__()
71 | 
72 |         self.inplanes = 128
73 |         self.conv1 = nn.Conv2d(6, 128, kernel_size=7, stride=2, padding=3, groups=2,
74 |                                bias=False)
75 |         self.bn1 = nn.BatchNorm2d(128)
76 |         self.relu = nn.ReLU(inplace=True)
77 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
78 |         self.layer1 = self._make_layer(block, 128, layers[0])
79 |         self.layer2 = self._make_layer(block, 256, layers[1], stride=2)
80 |         self.layer3 = self._make_layer(block, 512, layers[2], stride=2)
81 |         self.layer4 = self._make_layer(block, 1024, layers[3], stride=2)
82 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
83 |         self.fc = nn.Linear(512 * block.expansion, 1000)
84 | 
85 |         for m in self.modules():
86 |             if isinstance(m, nn.Conv2d):
87 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
88 |             elif isinstance(m, nn.BatchNorm2d):
89 |                 nn.init.constant_(m.weight, 1)
90 |                 nn.init.constant_(m.bias, 0)
91 | 
92 |         # Zero-initialize the last BN in each residual branch,
93 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
94 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
95 |         if zero_init_residual:
96 |             for m in self.modules():
97 |                 if isinstance(m, Bottleneck):
98 |                     nn.init.constant_(m.bn3.weight, 0)
99 |                 # elif isinstance(m, BasicBlock):
100 |                 #     nn.init.constant_(m.bn2.weight, 0)
101 | 
102 |     def _make_layer(self, block, planes, blocks, stride=1):
103 |         downsample = None
104 |         if stride != 1 or self.inplanes != planes * block.expansion:
105 |             downsample = nn.Sequential(
106 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
107 |                 nn.BatchNorm2d(planes * block.expansion),
108 |             )
109 | 
110 |         layers = []
111 |         layers.append(block(self.inplanes, planes, stride, downsample))
112 |         self.inplanes = planes * block.expansion
113 |         for _ in range(1, blocks):
114 |             layers.append(block(self.inplanes, planes))
115 | 
116 |         return nn.Sequential(*layers)
117 | 
118 | 
119 | def resnet50(pretrained=False, **kwargs):
120 |     """Constructs a ResNet-50 model.
121 | 
122 |     Args:
123 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
124 |     """
125 |     model = Backbone(Bottleneck, [3, 4, 6, 3], **kwargs)
126 |     if pretrained:
127 |         weights = model_zoo.load_url(model_urls['resnet50'])
128 | 
129 |         for key in weights:
130 |             if key.split('.')[0] == 'fc':
131 |                 weights[key] = weights[key].clone()
132 |                 continue
133 |             weights[key] = torch.cat([weights[key].clone(), weights[key].clone()], dim=0)
134 | 
135 |         model.load_state_dict(weights)
136 |     return model
137 | 
138 | 
139 | def resnet101(pretrained=False, **kwargs):
140 |     """Constructs a ResNet-101 model.
141 | 
142 |     Args:
143 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
144 |     """
145 |     model = Backbone(Bottleneck, [3, 4, 23, 3], **kwargs)
146 |     if pretrained:
147 |         weights = model_zoo.load_url(model_urls['resnet101'])
148 | 
149 |         for key in weights:
150 |             if key.split('.')[0] == 'fc':
151 |                 weights[key] = weights[key].clone()
152 |                 continue
153 |             weights[key] = torch.cat([weights[key].clone(), weights[key].clone()], dim=0)
154 | 
155 |         model.load_state_dict(weights)
156 |     return model
157 |
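One note on the pretrained-weight loading above: every convolution uses groups=2 so that a concatenated source/target pair is processed in a single pass, and each ImageNet tensor (except the unused fc head, which is kept as-is) is therefore duplicated along dim 0 so both groups start from identical filters. A hedged sketch of the shape bookkeeping, with hypothetical values:

import torch

w = torch.randn(64, 3, 7, 7)                   # torchvision's conv1 weight
w2 = torch.cat([w.clone(), w.clone()], dim=0)  # (128, 3, 7, 7)
# Backbone.conv1 is Conv2d(6, 128, ..., groups=2): each group sees 3 input
# channels and produces 64 outputs, so the duplicated tensor fits exactly.
print(w2.shape)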
--------------------------------------------------------------------------------
/model/dhpf.py:
--------------------------------------------------------------------------------
1 | r"""Implementation of Dynamic Hyperpixel Flow"""
2 | from functools import reduce
3 | from operator import add
4 | 
5 | import torch.nn as nn
6 | import torch
7 | 
8 | from .base.correlation import Correlation
9 | from .base.geometry import Geometry
10 | from .base.norm import Norm
11 | from .base import resnet
12 | from . import gating
13 | from . import rhm
14 | 
15 | 
16 | class DynamicHPF:
17 |     r"""Dynamic Hyperpixel Flow (DHPF)"""
18 |     def __init__(self, backbone, device, img_side=240):
19 |         r"""Constructor for DHPF"""
20 |         super(DynamicHPF, self).__init__()
21 | 
22 |         # 1. Backbone network initialization
23 |         if backbone == 'resnet50':
24 |             self.backbone = resnet.resnet50(pretrained=True).to(device)
25 |             self.in_channels = [64, 256, 256, 256, 512, 512, 512, 512, 1024,
26 |                                 1024, 1024, 1024, 1024, 1024, 2048, 2048, 2048]
27 |             nbottlenecks = [3, 4, 6, 3]
28 |         elif backbone == 'resnet101':
29 |             self.backbone = resnet.resnet101(pretrained=True).to(device)
30 |             self.in_channels = [64, 256, 256, 256, 512, 512, 512, 512,
31 |                                 1024, 1024, 1024, 1024, 1024, 1024, 1024,
32 |                                 1024, 1024, 1024, 1024, 1024, 1024, 1024,
33 |                                 1024, 1024, 1024, 1024, 1024, 1024, 1024,
34 |                                 1024, 1024, 2048, 2048, 2048]
35 |             nbottlenecks = [3, 4, 23, 3]
36 |         else:
37 |             raise ValueError('Unavailable backbone: %s' % backbone)
38 |         self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks)))
39 |         self.layer_ids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)])
40 |         self.backbone.eval()
41 | 
42 |         # 2. Dynamic layer gating initialization
43 |         self.learner = gating.GumbelFeatureSelection(self.in_channels).to(device)
44 | 
45 |         # 3. Miscellaneous
46 |         self.relu = nn.ReLU()
47 |         self.upsample_size = [int(img_side / 4)] * 2
48 |         Geometry.initialize(self.upsample_size, device)
49 |         self.rhm = rhm.HoughMatching(Geometry.rfs, torch.tensor([img_side, img_side]).to(device))
50 | 
51 |     # Forward pass
52 |     def __call__(self, *args, **kwargs):
53 |         # 1. Compute correlations between hyperimages
54 |         src_img = args[0]
55 |         trg_img = args[1]
56 |         correlation_matrix, layer_sel = self.hyperimage_correlation(src_img, trg_img)
57 | 
58 |         # 2. Compute geometric matching scores to re-weight appearance matching scores (RHM)
59 |         with torch.no_grad():  # no backprop through RHM due to memory constraints
60 |             geometric_scores = torch.stack([self.rhm.run(c.clone().detach()) for c in correlation_matrix], dim=0)
61 |         correlation_matrix *= geometric_scores
62 | 
63 |         return correlation_matrix, layer_sel
64 | 
65 |     def hyperimage_correlation(self, src_img, trg_img):
66 |         r"""Dynamically constructs hyperimages and computes their correlations"""
67 |         layer_sel = []
68 |         correlation, src_norm, trg_norm = 0, 0, 0
69 | 
70 |         # Concatenate source & target images (B,6,H,W)
71 |         # Perform group convolution (groups=2) for faster inference
72 |         pair_img = torch.cat([src_img, trg_img], dim=1)
73 | 
74 |         # Layer 0
75 |         with torch.no_grad():
76 |             feat = self.backbone.conv1.forward(pair_img)
77 |             feat = self.backbone.bn1.forward(feat)
78 |             feat = self.backbone.relu.forward(feat)
79 |             feat = self.backbone.maxpool.forward(feat)
80 | 
81 |         src_feat = feat.narrow(1, 0, feat.size(1) // 2).clone()
82 |         trg_feat = feat.narrow(1, feat.size(1) // 2, feat.size(1) // 2).clone()
83 | 
84 |         # Save base maps
85 |         base_src_feat = self.learner.reduction_ffns[0](src_feat)
86 |         base_trg_feat = self.learner.reduction_ffns[0](trg_feat)
87 |         base_correlation = Correlation.bmm_interp(base_src_feat, base_trg_feat, self.upsample_size)
88 |         base_src_norm = Norm.feat_normalize(base_src_feat, self.upsample_size)
89 |         base_trg_norm = Norm.feat_normalize(base_trg_feat, self.upsample_size)
90 | 
91 |         src_feat, trg_feat, lsel = self.learner(0, src_feat, trg_feat)
92 |         if src_feat is not None and trg_feat is not None:
93 |             correlation += Correlation.bmm_interp(src_feat, trg_feat, self.upsample_size)
94 |             src_norm += Norm.feat_normalize(src_feat, self.upsample_size)
95 |             trg_norm += Norm.feat_normalize(trg_feat, self.upsample_size)
96 |         layer_sel.append(lsel)
97 | 
98 |         # Layers 1-4
99 |         for hid, (bid, lid) in enumerate(zip(self.bottleneck_ids, self.layer_ids)):
100 |             with torch.no_grad():
101 |                 res = feat
102 |                 feat = self.backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
103 |                 feat = self.backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
104 |                 feat = self.backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
105 |                 feat = self.backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
106 |                 feat = self.backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
107 |                 feat = self.backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
108 |                 feat = self.backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
109 |                 feat = self.backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)
110 |                 if bid == 0:
111 |                     res = self.backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res)
112 |                 feat += res
113 | 
114 |             src_feat = feat.narrow(1, 0, feat.size(1) // 2).clone()
115 |             trg_feat = feat.narrow(1, feat.size(1) // 2, feat.size(1) // 2).clone()
116 | 
117 |             src_feat, trg_feat, lsel = self.learner(hid + 1, src_feat, trg_feat)
118 |             if src_feat is not None and trg_feat is not None:
119 |                 correlation += Correlation.bmm_interp(src_feat, trg_feat, self.upsample_size)
120 |                 src_norm += Norm.feat_normalize(src_feat, self.upsample_size)
121 |                 trg_norm += Norm.feat_normalize(trg_feat, self.upsample_size)
122 |             layer_sel.append(lsel)
123 | 
124 |             with torch.no_grad():
125 |                 feat = self.backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
126 | 
127 |         layer_sel = torch.stack(layer_sel).t()
128 | 
129 |         # If no layers are selected, fall back to the base map
130 |         if (layer_sel.sum(dim=1) == 0).sum() > 0:
131 |             empty_sel = (layer_sel.sum(dim=1) == 0).nonzero().view(-1).long()
132 |             if src_img.size(0) == 1:
133 |                 correlation = base_correlation
134 |                 src_norm = base_src_norm
135 |                 trg_norm = base_trg_norm
136 |             else:
137 |                 correlation[empty_sel] += base_correlation[empty_sel]
138 |                 src_norm[empty_sel] += base_src_norm[empty_sel]
139 |                 trg_norm[empty_sel] += base_trg_norm[empty_sel]
140 | 
141 |         if self.learner.training:
142 |             src_norm[src_norm == 0.0] += 0.0001
143 |             trg_norm[trg_norm == 0.0] += 0.0001
144 |         src_norm = src_norm.pow(0.5).unsqueeze(2)
145 |         trg_norm = trg_norm.pow(0.5).unsqueeze(1)
146 | 
147 |         # Appearance matching confidence (p(m_a)): cosine similarity between hyperimages
148 |         correlation_ts = self.relu(correlation / (torch.bmm(src_norm, trg_norm) + 0.001)).pow(2)
149 | 
150 |         return correlation_ts, layer_sel
151 | 
152 |     def parameters(self):
153 |         return self.learner.parameters()
154 | 
155 |     def state_dict(self):
156 |         return self.learner.state_dict()
157 | 
158 |     def load_state_dict(self, state_dict):
159 |         self.learner.load_state_dict(state_dict)
160 | 
161 |     def eval(self):
162 |         self.learner.eval()
163 | 
164 |     def train(self):
165 |         self.learner.train()
166 |
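A hedged end-to-end sketch of the forward pass above (random inputs on CPU; constructing DynamicHPF downloads ImageNet weights on first use; repository root assumed on PYTHONPATH):

import torch
from model import dhpf

model = dhpf.DynamicHPF('resnet50', torch.device('cpu'))
model.eval()
src = torch.randn(1, 3, 240, 240)   # img_side defaults to 240
trg = torch.randn(1, 3, 240, 240)
with torch.no_grad():
    corr, layer_sel = model(src, trg)
print(corr.shape)        # (1, 3600, 3600): a 60x60 hyperpixel grid on each side
print(layer_sel.shape)   # (1, 17): one gate per candidate layer for resnet50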
--------------------------------------------------------------------------------
/model/gating.py:
--------------------------------------------------------------------------------
1 | r"""Implementation of Dynamic Layer Gating (DLG)"""
2 | import torch.nn as nn
3 | import torch
4 | 
5 | 
6 | class GumbelFeatureSelection(nn.Module):
7 |     r"""Dynamic layer gating with the Gumbel-max trick"""
8 |     def __init__(self, in_channels, reduction=8, hidden_size=32):
9 |         r"""Constructor for DLG"""
10 |         super(GumbelFeatureSelection, self).__init__()
11 | 
12 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
13 |         self.softmax = nn.Softmax(dim=1)
14 |         self.reduction = reduction
15 | 
16 |         # Learnable modules in Dynamic Hyperpixel Flow
17 |         self.reduction_ffns = []  # Convolutional Feature Transformation (CFT)
18 |         self.gumbel_ffns = []  # Gumbel Layer Gating (GLG)
19 |         for in_channel in in_channels:
20 |             out_channel = in_channel // self.reduction
21 |             reduction_ffn = nn.Sequential(
22 |                 nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False),
23 |                 nn.BatchNorm2d(out_channel),
24 |                 nn.ReLU(inplace=True)
25 |             )
26 |             gumbel_ffn = nn.Sequential(
27 |                 nn.Conv2d(in_channel, hidden_size, kernel_size=1, bias=False),
28 |                 nn.BatchNorm2d(hidden_size),
29 |                 nn.ReLU(inplace=True),
30 |                 nn.Conv2d(hidden_size, 2, kernel_size=1, bias=False)
31 |             )
32 |             self.reduction_ffns.append(reduction_ffn)
33 |             self.gumbel_ffns.append(gumbel_ffn)
34 |         self.reduction_ffns = nn.ModuleList(self.reduction_ffns)
35 |         self.gumbel_ffns = nn.ModuleList(self.gumbel_ffns)
36 | 
37 |     def forward(self, lid, src_feat, trg_feat):
38 |         r"""DLG forward pass"""
39 |         relevance = self.gumbel_ffns[lid](self.avgpool(src_feat) + self.avgpool(trg_feat))
40 | 
41 |         # For measuring per-pair inference time on the test set
42 |         if not self.training and len(relevance) == 1:
43 |             selected = relevance.max(dim=1)[1].squeeze()
44 | 
45 |             # Perform CFT iff the layer is selected
46 |             if selected:
47 |                 src_x = self.reduction_ffns[lid](src_feat)
48 |                 trg_x = self.reduction_ffns[lid](trg_feat)
49 |             else:
50 |                 src_x = None
51 |                 trg_x = None
52 |             layer_sel = relevance.view(-1).max(dim=0)[1].unsqueeze(0).float()
53 |         else:
54 |             # Hard selection during the forward pass (layer_sel)
55 |             # Soft gradients during the backward pass (dL/dy)
56 |             y = self.gumbel_softmax(relevance.squeeze())
57 |             _y = self._softmax(y)
58 |             layer_sel = y[:, 1] + _y[:, 1]
59 | 
60 |             src_x = self.reduction_ffns[lid](src_feat)
61 |             trg_x = self.reduction_ffns[lid](trg_feat)
62 |             src_x = src_x * y[:, 1].view(-1, 1, 1, 1) + src_x * _y[:, 1].view(-1, 1, 1, 1)
63 |             trg_x = trg_x * y[:, 1].view(-1, 1, 1, 1) + trg_x * _y[:, 1].view(-1, 1, 1, 1)
64 | 
65 |         return src_x, trg_x, layer_sel
66 | 
67 |     def _softmax(self, soft_sample):
68 |         r"""Straight-through trick: soft_sample + _soft_sample equals the hard (argmax) one-hot sample in the forward pass, while gradients flow only through soft_sample"""
69 |         hard_sample_idx = torch.max(soft_sample, dim=1)[1].unsqueeze(1)
70 |         hard_sample = soft_sample.detach().clone().zero_().scatter(dim=1, index=hard_sample_idx, value=1)
71 | 
72 |         _soft_sample = (hard_sample - soft_sample.detach().clone())
73 | 
74 |         return _soft_sample
75 | 
76 |     def gumbel_softmax(self, logits, temperature=1, eps=1e-10):
77 |         r"""Draws a soft sample from the Gumbel-softmax distribution (noise is added only during training)"""
78 |         if self.training:
79 |             gumbel_noise = -torch.log(eps - torch.log(logits.detach().clone().uniform_() + eps))
80 |             gumbel_input = logits + gumbel_noise
81 |         else:
82 |             gumbel_input = logits
83 | 
84 |         if gumbel_input.dim() == 1:
85 |             gumbel_input = gumbel_input.unsqueeze(0)
86 | 
87 |         soft_sample = self.softmax(gumbel_input / temperature)
88 | 
89 |         return soft_sample
90 |
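The else-branch above is the straight-through trick: y + _y equals the hard one-hot sample in the forward pass, while gradients flow only through the soft sample y. A standalone illustration (not part of the repository):

import torch

y = torch.tensor([[0.3, 0.7]], requires_grad=True)                 # soft sample
hard = torch.zeros_like(y).scatter(1, y.argmax(dim=1, keepdim=True), 1.0)
st = y + (hard - y.detach())                                       # straight-through
print(st)                                                          # tensor([[0., 1.]])
st[:, 1].sum().backward()
print(y.grad)                                                      # tensor([[0., 1.]]): gradient w.r.t. the soft sample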
--------------------------------------------------------------------------------
/model/objective.py:
--------------------------------------------------------------------------------
1 | r"""Training objectives of DHPF"""
2 | import math
3 | 
4 | import torch.nn.functional as F
5 | import torch
6 | 
7 | from .base.correlation import Correlation
8 | from .base.norm import Norm
9 | 
10 | 
11 | class Objective:
12 |     r"""Provides training objectives of DHPF"""
13 |     @classmethod
14 |     def initialize(cls, target_rate, alpha):
15 |         cls.softmax = torch.nn.Softmax(dim=1)
16 |         cls.target_rate = target_rate
17 |         cls.alpha = alpha
18 |         cls.eps = 1e-30
19 | 
20 |     @classmethod
21 |     def weighted_cross_entropy(cls, correlation_matrix, easy_match, hard_match, batch):
22 |         r"""Computes the sum of weighted cross-entropy values between ground truth and prediction"""
23 |         loss_buf = correlation_matrix.new_zeros(correlation_matrix.size(0))
24 |         correlation_matrix = Norm.unit_gaussian_normalize(correlation_matrix)
25 | 
26 |         for idx, (ct, thres, npt) in enumerate(zip(correlation_matrix, batch['pckthres'], batch['n_pts'])):
27 | 
28 |             # Hard (incorrect) match
29 |             if len(hard_match['src'][idx]) > 0:
30 |                 cross_ent = cls.cross_entropy(ct, hard_match['src'][idx], hard_match['trg'][idx])
31 |                 loss_buf[idx] += cross_ent.sum()
32 | 
33 |             # Easy (correct) match
34 |             if len(easy_match['src'][idx]) > 0:
35 |                 cross_ent = cls.cross_entropy(ct, easy_match['src'][idx], easy_match['trg'][idx])
36 |                 smooth_weight = (easy_match['dist'][idx] / (thres * cls.alpha)).pow(2)
37 |                 loss_buf[idx] += (smooth_weight * cross_ent).sum()
38 | 
39 |             loss_buf[idx] /= npt
40 | 
41 |         return torch.mean(loss_buf)
42 | 
43 |     @classmethod
44 |     def cross_entropy(cls, correlation_matrix, src_match, trg_match):
45 |         r"""Cross-entropy between predicted pdf and ground-truth pdf (one-hot vector)"""
46 |         pdf = cls.softmax(correlation_matrix.index_select(0, src_match))
47 |         prob = pdf[range(len(trg_match)), trg_match]
48 |         cross_ent = -torch.log(prob + cls.eps)
49 | 
50 |         return cross_ent
51 | 
52 |     @classmethod
53 |     def information_entropy(cls, correlation_matrix, rescale_factor=4):
54 |         r"""Computes the information entropy of all candidate matches"""
55 |         bsz = correlation_matrix.size(0)
56 | 
57 |         correlation_matrix = Correlation.mutual_nn_filter(correlation_matrix)
58 | 
59 |         side = int(math.sqrt(correlation_matrix.size(1)))
60 |         new_side = side // rescale_factor
61 | 
62 |         trg2src_dist = correlation_matrix.view(bsz, -1, side, side)
63 |         src2trg_dist = correlation_matrix.view(bsz, side, side, -1).permute(0, 3, 1, 2)
64 | 
65 |         # Squeeze distributions for reliable entropy computation
66 |         trg2src_dist = F.interpolate(trg2src_dist, [new_side, new_side], mode='bilinear', align_corners=True)
67 |         src2trg_dist = F.interpolate(src2trg_dist, [new_side, new_side], mode='bilinear', align_corners=True)
68 | 
69 |         src_pdf = Norm.l1normalize(trg2src_dist.view(bsz, -1, (new_side * new_side)))
70 |         trg_pdf = Norm.l1normalize(src2trg_dist.view(bsz, -1, (new_side * new_side)))
71 | 
72 |         src_pdf[src_pdf == 0.0] = cls.eps
73 |         trg_pdf[trg_pdf == 0.0] = cls.eps
74 | 
75 |         src_ent = (-(src_pdf * torch.log2(src_pdf)).sum(dim=2)).view(bsz, -1)
76 |         trg_ent = (-(trg_pdf * torch.log2(trg_pdf)).sum(dim=2)).view(bsz, -1)
77 |         score_net = (src_ent + trg_ent).mean(dim=1) / 2
78 | 
79 |         return score_net.mean()
80 | 
81 |     @classmethod
82 |     def layer_selection_loss(cls, layer_sel):
83 |         r"""Encourages the model to select each layer at a target rate"""
84 |         return (layer_sel.mean(dim=0) - cls.target_rate).pow(2).sum()
85 |
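A sanity check of the entropy used above: a uniform match distribution over N candidates carries log2(N) bits, while a perfectly peaked one carries ~0 bits; the weak-supervision objective pushes matching pairs toward the low-entropy end. Illustrative only:

import torch

pdf = torch.full((1, 1, 16), 1 / 16)   # uniform over 16 candidate matches
print(-(pdf * torch.log2(pdf)).sum(dim=2))   # tensor([[4.]]) = log2(16) bits

peaked = torch.zeros(1, 1, 16)
peaked[..., 0] = 1.
peaked[peaked == 0] = 1e-30            # same epsilon guard as the code above
print(-(peaked * torch.log2(peaked)).sum(dim=2))   # ~0 bits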
--------------------------------------------------------------------------------
/model/rhm.py:
--------------------------------------------------------------------------------
1 | r"""Implementation of the regularized Hough matching algorithm (RHM)"""
2 | import math
3 | 
4 | import torch.nn.functional as F
5 | import torch
6 | 
7 | from .base.geometry import Geometry
8 | 
9 | 
10 | class HoughMatching:
11 |     r"""Regularized Hough matching algorithm"""
12 |     def __init__(self, rf, img_side, ncells=8192):
13 |         r"""Constructor of HoughMatching"""
14 |         super(HoughMatching, self).__init__()
15 | 
16 |         device = rf.device
17 |         self.nbins_x, self.nbins_y, hs_cellsize = self.build_hspace(img_side, img_side, ncells)
18 |         self.bin_ids = self.compute_bin_id(img_side, rf, rf, hs_cellsize, self.nbins_x)
19 |         self.hspace = rf.new_zeros((len(rf), self.nbins_y * self.nbins_x))
20 |         self.hbin_ids = self.bin_ids.add(torch.arange(0, len(rf)).to(device).
21 |                                          mul(self.hspace.size(1)).unsqueeze(1).expand_as(self.bin_ids))
22 |         self.hsfilter = Geometry.gaussian2d(7).to(device)
23 | 
24 |     def run(self, votes):
25 |         r"""Regularized Hough matching"""
26 |         hspace = self.hspace.view(-1).index_add(0, self.hbin_ids.view(-1), votes.view(-1)).view_as(self.hspace)
27 |         hspace = torch.sum(hspace, dim=0)
28 |         hspace = F.conv2d(hspace.view(1, 1, self.nbins_y, self.nbins_x),
29 |                           self.hsfilter.unsqueeze(0).unsqueeze(0), padding=3).view(-1)
30 | 
31 |         return torch.index_select(hspace, dim=0, index=self.bin_ids.view(-1)).view_as(votes)
32 | 
33 |     def compute_bin_id(self, src_imsize, src_box, trg_box, hs_cellsize, nbins_x):
34 |         r"""Computes Hough-space bin ids for voting"""
35 |         src_ptref = src_imsize.float()
36 |         src_trans = Geometry.center(src_box)
37 |         trg_trans = Geometry.center(trg_box)
38 |         xy_vote = (src_ptref.unsqueeze(0).expand_as(src_trans) - src_trans).unsqueeze(2).\
39 |             repeat(1, 1, len(trg_box)) + trg_trans.t().unsqueeze(0).repeat(len(src_box), 1, 1)
40 | 
41 |         bin_ids = (xy_vote / hs_cellsize).long()
42 | 
43 |         return bin_ids[:, 0, :] + bin_ids[:, 1, :] * nbins_x
44 | 
45 |     def build_hspace(self, src_imsize, trg_imsize, ncells):
46 |         r"""Builds the Hough space"""
47 |         hs_width = src_imsize[0] + trg_imsize[0]
48 |         hs_height = src_imsize[1] + trg_imsize[1]
49 |         hs_cellsize = math.sqrt((hs_width * hs_height) / ncells)
50 |         nbins_x = int(hs_width / hs_cellsize) + 1
51 |         nbins_y = int(hs_height / hs_cellsize) + 1
52 | 
53 |         return nbins_x, nbins_y, hs_cellsize
54 |
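A scaled-down sketch of how this class is wired up: every candidate match votes for its translational offset bin, the vote space is smoothed with the Gaussian filter, and each match is then rescored by its bin's total support. DHPF itself uses a 60x60 grid with img_side=240; smaller numbers below keep the Hough tables light (repository root assumed on PYTHONPATH):

import torch
from model.base.geometry import Geometry
from model.rhm import HoughMatching

Geometry.initialize([30, 30], torch.device('cpu'))   # 30x30 hyperpixel grid
hm = HoughMatching(Geometry.rfs, torch.tensor([120, 120]))
votes = torch.rand(900, 900)    # appearance score for every candidate match
rescored = hm.run(votes)        # each match re-weighted by its offset bin's support
print(rescored.shape)           # torch.Size([900, 900])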
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | r"""Dynamic Hyperpixel Flow testing code"""
2 | import argparse
3 | 
4 | from torch.utils.data import DataLoader
5 | import torch
6 | 
7 | from common.evaluation import Evaluator
8 | from common.logger import AverageMeter
9 | from common.logger import Logger
10 | from common import utils
11 | from model.base.geometry import Geometry
12 | from model import dhpf
13 | from data import download
14 | 
15 | 
16 | def test(model, dataloader):
17 |     r"""Code for testing DHPF"""
18 |     average_meter = AverageMeter(dataloader.dataset.benchmark)
19 | 
20 |     for idx, batch in enumerate(dataloader):
21 | 
22 |         # 1. DHPF forward pass
23 |         correlation_matrix, layer_sel = model(batch['src_img'], batch['trg_img'])
24 | 
25 |         # 2. Transfer keypoints (nearest-neighbour assignment)
26 |         prd_kps = Geometry.transfer_kps(correlation_matrix, batch['src_kps'], batch['n_pts'])
27 | 
28 |         # 3. Evaluate predictions
29 |         eval_result = Evaluator.evaluate(prd_kps, batch)
30 |         average_meter.update(eval_result, layer_sel.detach(), batch['category'])
31 |         average_meter.write_process(idx, len(dataloader))
32 | 
33 |     # Write evaluation results
34 |     Logger.visualize_selection(average_meter.sel_buffer)
35 |     average_meter.write_result('Test')
36 | 
37 | 
38 | if __name__ == '__main__':
39 | 
40 |     # Argument parsing
41 |     parser = argparse.ArgumentParser(description='Dynamic Hyperpixel Flow PyTorch implementation')
42 |     parser.add_argument('--datapath', type=str, default='../Datasets_DHPF')
43 |     parser.add_argument('--backbone', type=str, default='resnet101', choices=['resnet50', 'resnet101'])
44 |     parser.add_argument('--benchmark', type=str, default='pfpascal', choices=['pfpascal', 'pfwillow', 'caltech', 'spair'])
45 |     parser.add_argument('--thres', type=str, default='auto', choices=['auto', 'img', 'bbox'])
46 |     parser.add_argument('--alpha', type=float, default=0.1)
47 |     parser.add_argument('--logpath', type=str, default='')
48 |     parser.add_argument('--bsz', type=int, default=16)
49 |     parser.add_argument('--load', type=str, default='')
50 |     args = parser.parse_args()
51 |     Logger.initialize(args)
52 |     utils.fix_randseed(seed=0)
53 | 
54 |     # Model initialization
55 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
56 |     model = dhpf.DynamicHPF(args.backbone, device)
57 |     model.load_state_dict(torch.load(args.load))
58 |     model.eval()
59 | 
60 |     # Dataset download & initialization
61 |     download.download_dataset(args.datapath, args.benchmark)
62 |     test_ds = download.load_dataset(args.benchmark, args.datapath, args.thres, device, 'test')
63 |     test_dl = DataLoader(test_ds, batch_size=args.bsz, shuffle=False)
64 |     Evaluator.initialize(args.benchmark, args.alpha)
65 | 
66 |     # Test DHPF
67 |     with torch.no_grad():
68 |         test(model, test_dl)
69 |
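A typical evaluation run, assuming a trained checkpoint is available (the checkpoint path is a placeholder):

python test.py --benchmark pfpascal --backbone resnet101 --thres auto --load path/to/best_model.pt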
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | r"""Dynamic Hyperpixel Flow training (validation) code"""
2 | import argparse
3 | 
4 | from torch.utils.data import DataLoader
5 | import torch.optim as optim
6 | import torch
7 | 
8 | from common.evaluation import Evaluator
9 | from common.logger import AverageMeter
10 | from common.logger import Logger
11 | from common import supervision as sup
12 | from common import utils
13 | from model.base.geometry import Geometry
14 | from model.objective import Objective
15 | from model import dhpf
16 | from data import download
17 | 
18 | 
19 | def train(epoch, model, dataloader, strategy, optimizer, training):
20 |     r"""Code for training (or validating) DHPF"""
21 |     model.train() if training else model.eval()
22 |     average_meter = AverageMeter(dataloader.dataset.benchmark)
23 | 
24 |     for idx, batch in enumerate(dataloader):
25 | 
26 |         # 1. DHPF forward pass
27 |         src_img, trg_img = strategy.get_image_pair(batch, training)
28 |         correlation_matrix, layer_sel = model(src_img, trg_img)
29 | 
30 |         # 2. Transfer keypoints (nearest-neighbour assignment)
31 |         prd_kps = Geometry.transfer_kps(strategy.get_correlation(correlation_matrix), batch['src_kps'], batch['n_pts'])
32 | 
33 |         # 3. Evaluate predictions
34 |         eval_result = Evaluator.evaluate(prd_kps, batch)
35 | 
36 |         # 4. Compute loss to update weights
37 |         loss = strategy.compute_loss(correlation_matrix, eval_result, layer_sel, batch)
38 |         if training:
39 |             optimizer.zero_grad()
40 |             loss.backward()
41 |             optimizer.step()
42 |         average_meter.update(eval_result, layer_sel.detach(), batch['category'], loss.item())
43 |         average_meter.write_process(idx, len(dataloader), epoch)
44 | 
45 |     # Write evaluation results
46 |     average_meter.write_result('Training' if training else 'Validation', epoch)
47 | 
48 |     avg_loss = utils.mean(average_meter.loss_buffer)
49 |     avg_pck = utils.mean(average_meter.buffer['pck'])
50 |     return avg_loss, avg_pck
51 | 
52 | 
53 | if __name__ == '__main__':
54 | 
55 |     # Argument parsing
56 |     parser = argparse.ArgumentParser(description='Dynamic Hyperpixel Flow PyTorch implementation')
57 |     parser.add_argument('--datapath', type=str, default='../Datasets_DHPF')
58 |     parser.add_argument('--backbone', type=str, default='resnet101', choices=['resnet50', 'resnet101'])
59 |     parser.add_argument('--benchmark', type=str, default='pfpascal', choices=['pfpascal', 'spair'])
60 |     parser.add_argument('--thres', type=str, default='auto', choices=['auto', 'img', 'bbox'])
61 |     parser.add_argument('--supervision', type=str, default='strong', choices=['weak', 'strong'])
62 |     parser.add_argument('--selection', type=float, default=0.5)
63 |     parser.add_argument('--alpha', type=float, default=0.1)
64 |     parser.add_argument('--logpath', type=str, default='')
65 |     parser.add_argument('--lr', type=float, default=0.03)
66 |     parser.add_argument('--niter', type=int, default=100)
67 |     parser.add_argument('--bsz', type=int, default=8)
68 |     args = parser.parse_args()
69 |     Logger.initialize(args)
70 |     utils.fix_randseed(seed=0)
71 | 
72 |     # Model initialization
73 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
74 |     model = dhpf.DynamicHPF(args.backbone, device)
75 |     Objective.initialize(args.selection, args.alpha)
76 |     strategy = sup.WeakSupStrategy() if args.supervision == 'weak' else sup.StrongSupStrategy()
77 |     optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.5)
78 | 
79 |     # Dataset download & initialization
80 |     download.download_dataset(args.datapath, args.benchmark)
81 |     trn_ds = download.load_dataset(args.benchmark, args.datapath, args.thres, device, 'trn')
82 |     val_ds = download.load_dataset(args.benchmark, args.datapath, args.thres, device, 'val')
83 |     trn_dl = DataLoader(trn_ds, batch_size=args.bsz, shuffle=True)
84 |     val_dl = DataLoader(val_ds, batch_size=args.bsz, shuffle=False)
85 |     Evaluator.initialize(args.benchmark, args.alpha)
86 | 
87 |     # Train DHPF
88 |     best_val_pck = float('-inf')
89 |     for epoch in range(args.niter):
90 | 
91 |         trn_loss, trn_pck = train(epoch, model, trn_dl, strategy, optimizer, training=True)
92 |         with torch.no_grad():
93 |             val_loss, val_pck = train(epoch, model, val_dl, strategy, optimizer, training=False)
94 | 
95 |         # Save the best model
96 |         if val_pck > best_val_pck:
97 |             best_val_pck = val_pck
98 |             Logger.save_model(model, epoch, val_pck)
99 |         Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss}, epoch)
100 |         Logger.tbd_writer.add_scalars('data/pck', {'trn_pck': trn_pck, 'val_pck': val_pck}, epoch)
101 | 
102 |     Logger.tbd_writer.close()
103 |     Logger.info('==================== Finished training ====================')
104 |
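And a matching training invocation with the defaults spelled out (illustrative):

python train.py --benchmark pfpascal --supervision strong --selection 0.5 --lr 0.03 --bsz 8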
--------------------------------------------------------------------------------