├── LICENSE
├── README.md
├── common
│   ├── evaluation.py
│   ├── logger.py
│   ├── supervision.py
│   └── utils.py
├── data
│   ├── caltech.py
│   ├── dataset.py
│   ├── download.py
│   ├── pfpascal.py
│   ├── pfwillow.py
│   └── spair.py
├── model
│   ├── base
│   │   ├── correlation.py
│   │   ├── geometry.py
│   │   ├── norm.py
│   │   └── resnet.py
│   ├── dhpf.py
│   ├── gating.py
│   ├── objective.py
│   └── rhm.py
├── test.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 - Juhong Min
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://paperswithcode.com/sota/semantic-correspondence-on-spair-71k?p=learning-to-compose-hypercolumns-for-visual)
2 | [](https://paperswithcode.com/sota/semantic-correspondence-on-pf-pascal?p=learning-to-compose-hypercolumns-for-visual)
3 | [](https://paperswithcode.com/sota/semantic-correspondence-on-pf-willow?p=learning-to-compose-hypercolumns-for-visual)
4 |
5 | # Learning to Compose Hypercolumns for Visual Correspondence
6 | This is the implementation of the paper "Learning to Compose Hypercolumns for Visual Correspondence" by J. Min, J. Lee, J. Ponce and M. Cho. Implemented in Python 3.7 and PyTorch 1.0.1.
7 |
8 | 
9 |
10 | For more information, check out project [[website](http://cvlab.postech.ac.kr/research/DHPF/)] and the paper on [[arXiv](https://arxiv.org/abs/2007.10587)].
11 |
12 |
13 | ## Requirements
14 |
15 | - Python 3.7
16 | - PyTorch 1.0.1
17 | - tensorboard
18 | - scipy
19 | - pandas
20 | - requests
21 | - scikit-image
22 |
23 | Conda environment settings:
24 | ```bash
25 | conda create -n dhpf python=3.7
26 | conda activate dhpf
27 |
28 | conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 -c pytorch
29 | pip install tensorboardX
30 | conda install -c anaconda scipy
31 | conda install -c anaconda pandas
32 | conda install -c anaconda requests
33 | conda install -c anaconda scikit-image
34 | conda install -c anaconda "pillow<7"
35 | ```
36 |
37 | ## Training
38 |
39 | Training DHPF with strong supervision (keypoint annotations) on PF-PASCAL and SPair-71k
40 | (reproducing strongly-supervised results in Tab. 1 and 2):
41 | ```bash
42 | python train.py --supervision strong \
43 | --lr 0.03 \
44 | --bsz 8 \
45 | --niter 100 \
46 | --selection 0.5 \
47 | --benchmark pfpascal \
48 | --backbone {resnet50, resnet101}
49 |
50 | python train.py --supervision strong \
51 | --lr 0.03 \
52 | --bsz 8 \
53 | --niter 5 \
54 | --selection 0.5 \
55 | --benchmark spair \
56 | --backbone {resnet50, resnet101}
57 | ```
58 | Training DHPF with weak supervision (image-level labels) on PF-PASCAL
59 | (reproducing weakly-supervised results in Tab. 1):
60 | ```bash
61 | python train.py --supervision weak \
62 | --lr 0.1 \
63 | --bsz 4 \
64 | --niter 30 \
65 | --selection 0.5 \
66 | --benchmark pfpascal \
67 | --backbone {resnet50, resnet101}
68 | ```
69 |
70 | ## Testing
71 |
72 | We provide trained models on [[Google drive](https://drive.google.com/drive/folders/1aoKQlvHOb7vZIFK8pDJsQnC7SOyEjXVF?usp=sharing)].
73 |
74 | PCK @ αimg=0.1 on PF-PASCAL at different μ (a short sketch of the PCK computation follows the tables below):
75 |
76 | | Trained models<br>at different μ | 0.3 | 0.4 | 0.5 | 0.6 | 0.7 | 0.8 | 0.9 | 1 |
77 | |:--------------------------------------------------:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
78 | | [weak (res50)](https://drive.google.com/drive/folders/1WykysKyy9PAsX-DpC5UuZILokCMToJWH?usp=sharing) | 77.3 | 79 | 79 | 79.3 | 79.6 | 80.7 | 81.1 | 80.7 |
79 | | [weak (res101)](https://drive.google.com/drive/folders/1IjjoFgrIZzys2YDEGhLQrOg0bTG29-Pl?usp=sharing) | 80.3 | 81.2 | 82.1 | 80.1 | 81.7 | 80.9 | 81.3 | 81.3 |
80 | | [strong (res50)](https://drive.google.com/drive/folders/1RC9EbVhk8QOjpF3NIO-tidIsKcY399S8?usp=sharing) | 87.7 | 89.1 | 88.9 | 88.5 | 89.4 | 89.1 | 89 | 89.5 |
81 | | [strong (res101)](https://drive.google.com/drive/folders/1QDYOxqF-BsWKjKbwLKfbcfxaS5OHlbVT?usp=sharing) | 88.7 | 90 | 90.7 | 90.2 | 90.1 | 90.6 | 90.6 | 90.4 |
82 |
83 | PCK @ αimg=0.1 on SPair-71k at μ=0.5:
84 |
85 | | Trained models<br>at μ=0.5 | PCK |
86 | |:---------------------------------------------:|:----:|
87 | | [weak (res101)](https://drive.google.com/file/d/1uDfONwSiAzDsxW9wbhdlYKf8auqAVXoM/view?usp=sharing) | 27.7 |
88 | | [strong (res101)](https://drive.google.com/file/d/1DnsDhttMIImAcupdjuANowlgZqVSx_5E/view?usp=sharing) | 37.3 |
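
PCK@α counts a transferred keypoint as correct when its L2 distance to the ground-truth target keypoint is at most α times the PCK threshold (the larger side of the target bounding box or of the image, per `get_pckthres`). A minimal sketch with hypothetical toy values, mirroring `Evaluator.classify_prd` in `common/evaluation.py`:

```python
import torch

# Keypoints are stored as (2, N) tensors: row 0 holds x, row 1 holds y.
prd_kps = torch.tensor([[10., 50.], [20., 60.]])  # predicted
trg_kps = torch.tensor([[12., 80.], [21., 90.]])  # ground truth
pckthres = torch.tensor(100.)                     # e.g. max side of trg bbox
alpha = 0.1

l2dist = (prd_kps - trg_kps).pow(2).sum(dim=0).pow(0.5)
correct = torch.le(l2dist, pckthres * alpha)
pck = correct.float().mean().item() * 100         # -> 50.0 for this toy pair
```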
89 |
90 | Reproducing results in Tab. 1, 2 and 3:
91 | ```bash
92 | python test.py --backbone {resnet50, resnet101} \
93 | --benchmark {pfpascal, pfwillow, caltech, spair} \
94 | --load "path_to_trained_model"
95 | ```
96 |
97 |
98 | ## BibTeX
99 | If you use this code for your research, please consider citing:
100 | ````BibTeX
101 | @InProceedings{min2020dhpf,
102 | title={Learning to Compose Hypercolumns for Visual Correspondence},
103 | author={Juhong Min and Jongmin Lee and Jean Ponce and Minsu Cho},
104 | booktitle={ECCV},
105 | year={2020}
106 | }
107 | ````
108 |
--------------------------------------------------------------------------------
/common/evaluation.py:
--------------------------------------------------------------------------------
1 | """For quantitative evaluation of DHPF"""
2 | from skimage import draw
3 | import numpy as np
4 | import torch
5 |
6 | from . import utils
7 |
8 |
9 | class Evaluator:
10 | r"""Computes evaluation metrics of PCK, LT-ACC, IoU"""
11 | @classmethod
12 | def initialize(cls, benchmark, alpha=0.1):
13 | if benchmark == 'caltech':
14 | cls.eval_func = cls.eval_mask_transfer
15 | else:
16 | cls.eval_func = cls.eval_kps_transfer
17 | cls.alpha = alpha
18 |
19 | @classmethod
20 | def evaluate(cls, prd_kps, batch):
21 | r"""Compute evaluation metric"""
22 | return cls.eval_func(prd_kps, batch)
23 |
24 | @classmethod
25 | def eval_kps_transfer(cls, prd_kps, batch):
26 | r"""Compute percentage of correct key-points (PCK) based on prediction"""
27 |
28 | easy_match = {'src': [], 'trg': [], 'dist': []}
29 | hard_match = {'src': [], 'trg': []}
30 |
31 | pck = []
32 | for idx, (pk, tk) in enumerate(zip(prd_kps, batch['trg_kps'])):
33 | thres = batch['pckthres'][idx]
34 | npt = batch['n_pts'][idx]
35 | correct_dist, correct_ids, incorrect_ids = cls.classify_prd(pk[:, :npt], tk[:, :npt], thres)
36 |
37 | # Collect easy and hard match feature index & store pck to buffer
38 | easy_match['dist'].append(correct_dist)
39 | easy_match['src'].append(batch['src_kpidx'][idx][:npt][correct_ids])
40 | easy_match['trg'].append(batch['trg_kpidx'][idx][:npt][correct_ids])
41 | hard_match['src'].append(batch['src_kpidx'][idx][:npt][incorrect_ids])
42 | hard_match['trg'].append(batch['trg_kpidx'][idx][:npt][incorrect_ids])
43 | pck.append((len(correct_ids) / npt.item()) * 100)
44 |
45 | eval_result = {'easy_match': easy_match,
46 | 'hard_match': hard_match,
47 | 'pck': pck}
48 |
49 | return eval_result
50 |
51 | @classmethod
52 | def eval_mask_transfer(cls, prd_kps, batch):
53 | r"""Compute LT-ACC and IoU based on transferred points"""
54 |
55 | ltacc = []
56 | iou = []
57 |
58 | for idx, prd in enumerate(prd_kps):
59 | trg_n_pts = (batch['trg_kps'][idx] > 0)[0].sum()
60 | prd_kp = prd[:, :batch['n_pts'][idx]]
61 | trg_kp = batch['trg_kps'][idx][:, :trg_n_pts]
62 |
63 | imsize = list(batch['trg_img'].size())[2:]
64 | trg_xstr, trg_ystr = cls.pts2ptstr(trg_kp)
65 | trg_mask = cls.ptstr2mask(trg_xstr, trg_ystr, imsize[0], imsize[1])
66 |             prd_xstr, prd_ystr = cls.pts2ptstr(prd_kp)
67 |             prd_mask = cls.ptstr2mask(prd_xstr, prd_ystr, imsize[0], imsize[1])
68 |
69 | ltacc.append(cls.label_transfer_accuracy(prd_mask, trg_mask))
70 | iou.append(cls.intersection_over_union(prd_mask, trg_mask))
71 |
72 | eval_result = {'ltacc': ltacc,
73 | 'iou': iou}
74 |
75 | return eval_result
76 |
77 | @classmethod
78 | def classify_prd(cls, prd_kps, trg_kps, pckthres):
79 | r"""Compute the number of correctly transferred key-points"""
80 | l2dist = (prd_kps - trg_kps).pow(2).sum(dim=0).pow(0.5)
81 | thres = pckthres.expand_as(l2dist).float() * cls.alpha
82 | correct_pts = torch.le(l2dist, thres)
83 |
84 | correct_ids = utils.where(correct_pts == 1)
85 | incorrect_ids = utils.where(correct_pts == 0)
86 | correct_dist = l2dist[correct_pts]
87 |
88 | return correct_dist, correct_ids, incorrect_ids
89 |
90 | @classmethod
91 | def intersection_over_union(cls, mask1, mask2):
92 | r"""Computes IoU between two masks"""
93 | rel_part_weight = torch.sum(torch.sum(mask2.gt(0.5).float(), 2, True), 3, True) / \
94 | torch.sum(mask2.gt(0.5).float())
95 | part_iou = torch.sum(torch.sum((mask1.gt(0.5) & mask2.gt(0.5)).float(), 2, True), 3, True) / \
96 | torch.sum(torch.sum((mask1.gt(0.5) | mask2.gt(0.5)).float(), 2, True), 3, True)
97 | weighted_iou = torch.sum(torch.mul(rel_part_weight, part_iou)).item()
98 |
99 | return weighted_iou
100 |
101 | @classmethod
102 | def label_transfer_accuracy(cls, mask1, mask2):
103 | r"""LT-ACC measures the overlap with emphasis on the background class"""
104 | return torch.mean((mask1.gt(0.5) == mask2.gt(0.5)).double()).item()
105 |
106 | @classmethod
107 | def pts2ptstr(cls, pts):
108 | r"""Convert tensor of points to string"""
109 | x_str = str(list(pts[0].cpu().numpy()))
110 | x_str = x_str[1:len(x_str)-1]
111 | y_str = str(list(pts[1].cpu().numpy()))
112 | y_str = y_str[1:len(y_str)-1]
113 |
114 | return x_str, y_str
115 |
116 | @classmethod
117 | def pts2mask(cls, x_pts, y_pts, shape):
118 |         r"""Builds a binary mask array based on given xy-points"""
119 | x_idx, y_idx = draw.polygon(x_pts, y_pts, shape)
120 |         mask = np.zeros(shape, dtype=bool)
121 | mask[x_idx, y_idx] = True
122 |
123 | return mask
124 |
125 | @classmethod
126 | def ptstr2mask(cls, x_str, y_str, out_h, out_w):
127 | r"""Convert xy-point mask (string) to tensor mask"""
128 | x_pts = np.fromstring(x_str, sep=',')
129 | y_pts = np.fromstring(y_str, sep=',')
130 | mask_np = cls.pts2mask(y_pts, x_pts, [out_h, out_w])
131 | mask = torch.tensor(mask_np.astype(np.float32)).unsqueeze(0).unsqueeze(0).float()
132 |
133 | return mask
134 |
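# Usage sketch (comments only; this module's relative import keeps it from
# running standalone). Toy masks for the two Caltech-101 metrics above, in
# the (B, C, H, W) layout the methods expect; values are hypothetical:
#   prd = torch.zeros(1, 1, 4, 4); prd[..., :2, :2] = 1.   # 2x2 region
#   trg = torch.zeros(1, 1, 4, 4); trg[..., :2, :3] = 1.   # 2x3 region
#   Evaluator.intersection_over_union(prd, trg)    # -> 4/6 ~ 0.667
#   Evaluator.label_transfer_accuracy(prd, trg)    # -> 14/16 = 0.875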
--------------------------------------------------------------------------------
/common/logger.py:
--------------------------------------------------------------------------------
1 | r"""Logging"""
2 | import datetime
3 | import logging
4 | import os
5 |
6 | from tensorboardX import SummaryWriter
7 | import matplotlib.pyplot as plt
8 | import pandas as pd
9 | import numpy as np
10 | import torch
11 |
12 |
13 | class Logger:
14 | r"""Writes results of training/testing"""
15 | @classmethod
16 | def initialize(cls, args):
17 | logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
18 | logpath = args.logpath
19 |
20 | cls.logpath = os.path.join('logs', logpath + logtime + '.log')
21 | cls.benchmark = args.benchmark
22 | os.makedirs(cls.logpath)
23 |
24 | logging.basicConfig(filemode='w',
25 | filename=os.path.join(cls.logpath, 'log.txt'),
26 | level=logging.INFO,
27 | format='%(message)s',
28 | datefmt='%m-%d %H:%M:%S')
29 |
30 | # Console log config
31 | console = logging.StreamHandler()
32 | console.setLevel(logging.INFO)
33 | formatter = logging.Formatter('%(message)s')
34 | console.setFormatter(formatter)
35 | logging.getLogger('').addHandler(console)
36 |
37 | # Tensorboard writer
38 | cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))
39 |
40 | # Log arguments
41 | logging.info('\n+=========== Dynamic Hyperpixel Flow ============+')
42 | for arg_key in args.__dict__:
43 | logging.info('| %20s: %-24s |' % (arg_key, str(args.__dict__[arg_key])))
44 | logging.info('+================================================+\n')
45 |
46 | @classmethod
47 | def info(cls, msg):
48 | r"""Writes message to .txt"""
49 | logging.info(msg)
50 |
51 | @classmethod
52 | def save_model(cls, model, epoch, val_pck):
53 | torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
54 | cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck))
55 |
56 | @classmethod
57 | def visualize_selection(cls, catwise_sel):
58 | r"""Visualize (class-wise) layer selection frequency"""
59 | if cls.benchmark == 'pfpascal':
60 | sort_ids = [17, 8, 10, 19, 4, 15, 0, 3, 6, 5, 18, 13, 1, 14, 12, 2, 11, 7, 16, 9]
61 | elif cls.benchmark == 'pfwillow':
62 | sort_ids = np.arange(10)
63 | elif cls.benchmark == 'caltech':
64 | sort_ids = np.arange(101)
65 | elif cls.benchmark == 'spair':
66 | sort_ids = np.arange(18)
67 |
68 | for key in catwise_sel:
69 | catwise_sel[key] = torch.stack(catwise_sel[key]).mean(dim=0).cpu().numpy()
70 |
71 | category = np.array(list(catwise_sel.keys()))[sort_ids]
72 | values = np.array(list(catwise_sel.values()))[sort_ids]
73 | cols = list(range(values.shape[1]))
74 | df = pd.DataFrame(values, index=category, columns=cols)
75 |
76 | plt.pcolor(df, vmin=0.0, vmax=1.0)
77 | plt.gca().set_aspect('equal')
78 | plt.yticks(np.arange(0.5, len(df.index), 1), df.index)
79 | plt.xticks(np.arange(0.5, len(df.columns), 5), df.columns[::5])
80 | plt.tight_layout()
81 |
82 | plt.savefig('%s/selected_layers.jpg' % cls.logpath)
83 |
84 |
85 | class AverageMeter:
86 | r"""Stores loss, evaluation results, selected layers"""
87 |     def __init__(self, benchmark):
88 |         r"""Constructor of AverageMeter"""
89 |         if benchmark == 'caltech':
90 | self.buffer_keys = ['ltacc', 'iou']
91 | else:
92 | self.buffer_keys = ['pck']
93 |
94 | self.buffer = {}
95 | for key in self.buffer_keys:
96 | self.buffer[key] = []
97 | self.sel_buffer = {}
98 |
99 | self.loss_buffer = []
100 |
101 | def update(self, eval_result, layer_sel, category, loss=None):
102 | for key in self.buffer_keys:
103 | self.buffer[key] += eval_result[key]
104 |
105 | for sel, cls in zip(layer_sel, category):
106 | if self.sel_buffer.get(cls) is None:
107 | self.sel_buffer[cls] = []
108 | self.sel_buffer[cls] += [sel]
109 |
110 | if loss is not None:
111 | self.loss_buffer.append(loss)
112 |
113 | def write_result(self, split, epoch=-1):
114 | msg = '\n*** %s ' % split
115 | msg += '[@Epoch %02d] ' % epoch if epoch > -1 else ''
116 |
117 | if len(self.loss_buffer) > 0:
118 | msg += 'Loss: %5.2f ' % (sum(self.loss_buffer) / len(self.loss_buffer))
119 |
120 | for key in self.buffer_keys:
121 | msg += '%s: %6.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
122 | msg += '***\n'
123 | Logger.info(msg)
124 |
125 | def write_process(self, batch_idx, datalen, epoch=-1):
126 | msg = '[Epoch: %02d] ' % epoch if epoch > -1 else ''
127 | msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
128 | if len(self.loss_buffer) > 0:
129 | msg += 'Loss: %6.2f ' % self.loss_buffer[-1]
130 | msg += 'Avg Loss: %6.5f ' % (sum(self.loss_buffer) / len(self.loss_buffer))
131 |
132 | for key in self.buffer_keys:
133 | msg += 'Avg %s: %6.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
134 | Logger.info(msg)
135 |
--------------------------------------------------------------------------------
/common/supervision.py:
--------------------------------------------------------------------------------
1 | r"""Two different strategies of weak/strong supervision"""
2 | from abc import ABC, abstractmethod
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from model.objective import Objective
8 |
9 |
10 | class SupervisionStrategy(ABC):
11 |     r"""Common interface for strong/weak supervision strategies"""
12 | @abstractmethod
13 | def get_image_pair(self, batch, *args):
14 | pass
15 |
16 | @abstractmethod
17 | def get_correlation(self, correlation_matrix):
18 | pass
19 |
20 | @abstractmethod
21 | def compute_loss(self, correlation_matrix, *args):
22 | pass
23 |
24 |
25 | class StrongSupStrategy(SupervisionStrategy):
26 | def get_image_pair(self, batch, *args):
27 | r"""Returns (semantically related) pairs for strongly-supervised training"""
28 | return batch['src_img'], batch['trg_img']
29 |
30 | def get_correlation(self, correlation_matrix):
31 | r"""Returns correlation matrices of 'ALL PAIRS' in a batch"""
32 | return correlation_matrix.clone().detach()
33 |
34 | def compute_loss(self, correlation_matrix, *args):
35 | r"""Strongly-supervised matching loss (L_{match})"""
36 | easy_match = args[0]['easy_match']
37 | hard_match = args[0]['hard_match']
38 | layer_sel = args[1]
39 | batch = args[2]
40 |
41 | loss_cre = Objective.weighted_cross_entropy(correlation_matrix, easy_match, hard_match, batch)
42 | loss_sel = Objective.layer_selection_loss(layer_sel)
43 | loss_net = loss_cre + loss_sel
44 |
45 | return loss_net
46 |
47 |
48 | class WeakSupStrategy(SupervisionStrategy):
49 | def get_image_pair(self, batch, *args):
50 |         r"""Forms positive/negative image pairs for weakly-supervised training"""
51 | training = args[0]
52 | self.bsz = len(batch['src_img'])
53 |
54 | if training:
55 | shifted_idx = np.roll(np.arange(self.bsz), -1)
56 | trg_img_neg = batch['trg_img'][shifted_idx].clone()
57 | trg_cls_neg = batch['category_id'][shifted_idx].clone()
58 | neg_subidx = (batch['category_id'] - trg_cls_neg) != 0
59 |
60 | src_img = torch.cat([batch['src_img'], batch['src_img'][neg_subidx]], dim=0)
61 | trg_img = torch.cat([batch['trg_img'], trg_img_neg[neg_subidx]], dim=0)
62 | self.num_negatives = neg_subidx.sum()
63 | else:
64 | src_img, trg_img = batch['src_img'], batch['trg_img']
65 | self.num_negatives = 0
66 |
67 | return src_img, trg_img
68 |
69 | def get_correlation(self, correlation_matrix):
70 | r"""Returns correlation matrices of 'POSITIVE PAIRS' in a batch"""
71 | return correlation_matrix[:self.bsz].clone().detach()
72 |
73 | def compute_loss(self, correlation_matrix, *args):
74 | r"""Weakly-supervised matching loss (L_{match})"""
75 | layer_sel = args[1]
76 | loss_pos = Objective.information_entropy(correlation_matrix[:self.bsz])
77 | loss_neg = Objective.information_entropy(correlation_matrix[self.bsz:]) if self.num_negatives > 0 else 1.0
78 | loss_sel = Objective.layer_selection_loss(layer_sel)
79 | loss_net = (loss_pos / loss_neg) + loss_sel
80 |
81 | return loss_net
82 |
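# Worked example (toy values, not part of the original file): in
# get_image_pair() with category_id = [3, 3, 7, 7] and bsz = 4,
#   shifted_idx = np.roll(np.arange(4), -1)   # -> [1, 2, 3, 0]
#   trg_cls_neg = category_id[shifted_idx]    # -> [3, 7, 7, 3]
#   neg_subidx                                # -> [False, True, False, True]
# so two negative pairs are appended to the batch, (src 1, trg 2) and
# (src 3, trg 0): each source is paired with the next sample's target,
# but only when the two samples come from different categories.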
--------------------------------------------------------------------------------
/common/utils.py:
--------------------------------------------------------------------------------
1 | r"""Some helper functions"""
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def fix_randseed(seed):
9 | r"""Fixes random seed for reproducibility"""
10 | random.seed(seed)
11 | np.random.seed(seed)
12 | torch.manual_seed(seed)
13 | torch.cuda.manual_seed(seed)
14 | torch.cuda.manual_seed_all(seed)
15 | torch.backends.cudnn.benchmark = False
16 | torch.backends.cudnn.deterministic = True
17 |
18 |
19 | def mean(x):
20 | r"""Computes average of a list"""
21 | return sum(x) / len(x) if len(x) > 0 else 0.0
22 |
23 |
24 | def where(predicate):
25 |     r"""Returns indices of nonzero (True) entries of a predicate tensor"""
26 | matching_indices = predicate.nonzero()
27 | if len(matching_indices) != 0:
28 | matching_indices = matching_indices.t().squeeze(0)
29 | return matching_indices
30 |
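# Usage sketch (appended; toy values). For a 1-D predicate, where() returns
# a flat index vector rather than nonzero()'s (K, 1) column. Runnable
# standalone:
if __name__ == '__main__':
    t = torch.tensor([3, 0, 5, 0])
    print(where(t > 0))  # -> tensor([0, 2])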
--------------------------------------------------------------------------------
/data/caltech.py:
--------------------------------------------------------------------------------
1 | r"""Caltech-101 dataset"""
2 | import os
3 |
4 | import pandas as pd
5 | import numpy as np
6 | import torch
7 |
8 | from .dataset import CorrespondenceDataset
9 |
10 |
11 | class CaltechDataset(CorrespondenceDataset):
12 | r"""Inherits CorrespondenceDataset"""
13 | def __init__(self, benchmark, datapath, thres, device, split):
14 | r"""Caltech-101 dataset constructor"""
15 | super(CaltechDataset, self).__init__(benchmark, datapath, thres, device, split)
16 |
17 | self.train_data = pd.read_csv(self.spt_path)
18 | self.src_imnames = np.array(self.train_data.iloc[:, 0])
19 | self.trg_imnames = np.array(self.train_data.iloc[:, 1])
20 | self.src_kps = self.train_data.iloc[:, 3:5]
21 | self.trg_kps = self.train_data.iloc[:, 5:]
22 | self.cls = ['Faces', 'Faces_easy', 'Leopards', 'Motorbikes', 'accordion', 'airplanes',
23 | 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain',
24 | 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side',
25 | 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body',
26 | 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup',
27 | 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar',
28 | 'elephant', 'emu', 'euphonium', 'ewer', 'ferry', 'flamingo', 'flamingo_head',
29 | 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone',
30 | 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo',
31 | 'ketch', 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly',
32 | 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', 'pagoda',
33 | 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino',
34 | 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse',
35 | 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign',
36 | 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch',
37 | 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang']
38 | self.cls_ids = self.train_data.iloc[:, 2].values.astype('int') - 1
39 | self.src_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.src_imnames))
40 | self.trg_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.trg_imnames))
41 |
42 | def __getitem__(self, idx):
43 | r"""Constructs and returns a batch for Caltech-101 dataset"""
44 | return super(CaltechDataset, self).__getitem__(idx)
45 |
46 | def get_pckthres(self, batch):
47 | r"""No PCK measure for Caltech-101 dataset"""
48 | return None
49 |
50 | def get_points(self, pts, idx, org_imsize):
51 |         r"""Returns mask-points of an image"""
52 | x_pts = torch.tensor(list(map(lambda pt: float(pt), pts[pts.columns[0]][idx].split(','))))
53 | y_pts = torch.tensor(list(map(lambda pt: float(pt), pts[pts.columns[1]][idx].split(','))))
54 |
55 | x_pts *= (self.imside / org_imsize[0])
56 | y_pts *= (self.imside / org_imsize[1])
57 |
58 | n_pts = x_pts.size(0)
59 | if n_pts > self.max_pts:
60 | raise Exception('The number of keypoints is above threshold: %d' % n_pts)
61 | pad_pts = torch.zeros((2, self.max_pts - n_pts)) - 1
62 |
63 | kps = torch.cat([torch.stack([x_pts, y_pts]), pad_pts], dim=1).to(self.device)
64 |
65 | return kps, n_pts
66 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | r"""Superclass for semantic correspondence datasets"""
2 | import os
3 |
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from PIL import Image
7 | import torch
8 |
9 | from model.base.geometry import Geometry
10 |
11 |
12 | class CorrespondenceDataset(Dataset):
13 | r"""Parent class of PFPascal, PFWillow, Caltech, and SPair"""
14 | def __init__(self, benchmark, datapath, thres, device, split):
15 | r"""CorrespondenceDataset constructor"""
16 | super(CorrespondenceDataset, self).__init__()
17 |
18 | # {Directory name, Layout path, Image path, Annotation path, PCK threshold}
19 | self.metadata = {
20 | 'pfwillow': ('PF-WILLOW',
21 | 'test_pairs.csv',
22 | '',
23 | '',
24 | 'bbox'),
25 | 'pfpascal': ('PF-PASCAL',
26 | '_pairs.csv',
27 | 'JPEGImages',
28 | 'Annotations',
29 | 'img'),
30 | 'caltech': ('Caltech-101',
31 | 'test_pairs_caltech_with_category.csv',
32 | '101_ObjectCategories',
33 | '',
34 | ''),
35 | 'spair': ('SPair-71k',
36 | 'Layout/large',
37 | 'JPEGImages',
38 | 'PairAnnotation',
39 | 'bbox')
40 | }
41 |
42 | # Directory path for train, val, or test splits
43 | base_path = os.path.join(os.path.abspath(datapath), self.metadata[benchmark][0])
44 | if benchmark == 'pfpascal':
45 | self.spt_path = os.path.join(base_path, split+'_pairs.csv')
46 | elif benchmark == 'spair':
47 | self.spt_path = os.path.join(base_path, self.metadata[benchmark][1], split+'.txt')
48 | else:
49 | self.spt_path = os.path.join(base_path, self.metadata[benchmark][1])
50 |
51 | # Directory path for images
52 | self.img_path = os.path.join(base_path, self.metadata[benchmark][2])
53 |
54 | # Directory path for annotations
55 | if benchmark == 'spair':
56 | self.ann_path = os.path.join(base_path, self.metadata[benchmark][3], split)
57 | else:
58 | self.ann_path = os.path.join(base_path, self.metadata[benchmark][3])
59 |
60 | # Miscellaneous
61 | if benchmark == 'caltech':
62 | self.max_pts = 400
63 | else:
64 | self.max_pts = 40
65 | self.split = split
66 | self.device = device
67 | self.imside = 240
68 | self.benchmark = benchmark
69 | self.range_ts = torch.arange(self.max_pts)
70 | self.thres = self.metadata[benchmark][4] if thres == 'auto' else thres
71 | self.transform = transforms.Compose([transforms.Resize((self.imside, self.imside)),
72 | transforms.ToTensor(),
73 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
74 | std=[0.229, 0.224, 0.225])])
75 |
76 | # To get initialized in subclass constructors
77 | self.train_data = []
78 | self.src_imnames = []
79 | self.trg_imnames = []
80 | self.cls = []
81 | self.cls_ids = []
82 | self.src_kps = []
83 | self.trg_kps = []
84 |
85 | def __len__(self):
86 | r"""Returns the number of pairs"""
87 | return len(self.train_data)
88 |
89 | def __getitem__(self, idx):
90 |         r"""Constructs and returns a batch"""
91 |
92 | # Image names
93 | batch = dict()
94 | batch['src_imname'] = self.src_imnames[idx]
95 | batch['trg_imname'] = self.trg_imnames[idx]
96 |
97 | # Class of instances in the images
98 | batch['category_id'] = self.cls_ids[idx]
99 | batch['category'] = self.cls[batch['category_id']]
100 |
101 | # Image as numpy (original width, original height)
102 | src_pil = self.get_image(self.src_imnames, idx)
103 | trg_pil = self.get_image(self.trg_imnames, idx)
104 | batch['src_imsize'] = src_pil.size
105 | batch['trg_imsize'] = trg_pil.size
106 |
107 | # Image as tensor
108 | batch['src_img'] = self.transform(src_pil).to(self.device)
109 | batch['trg_img'] = self.transform(trg_pil).to(self.device)
110 |
111 | # Key-points (re-scaled)
112 | batch['src_kps'], num_pts = self.get_points(self.src_kps, idx, src_pil.size)
113 | batch['trg_kps'], _ = self.get_points(self.trg_kps, idx, trg_pil.size)
114 | batch['n_pts'] = torch.tensor(num_pts)
115 |
116 | # The number of pairs in training split
117 | batch['datalen'] = len(self.train_data)
118 |
119 | return batch
120 |
121 | def get_image(self, imnames, idx):
122 | r"""Reads PIL image from path"""
123 | path = os.path.join(self.img_path, imnames[idx])
124 | return Image.open(path).convert('RGB')
125 |
126 | def get_pckthres(self, batch, imsize):
127 | r"""Computes PCK threshold"""
128 | if self.thres == 'bbox':
129 | bbox = batch['trg_bbox'].clone()
130 | bbox_w = (bbox[2] - bbox[0])
131 | bbox_h = (bbox[3] - bbox[1])
132 | pckthres = torch.max(bbox_w, bbox_h)
133 | elif self.thres == 'img':
134 | imsize_t = batch['trg_img'].size()
135 | pckthres = torch.tensor(max(imsize_t[1], imsize_t[2]))
136 | else:
137 | raise Exception('Invalid pck threshold type: %s' % self.thres)
138 | return pckthres.float().to(self.device)
139 |
140 | def get_points(self, pts_list, idx, org_imsize):
141 |         r"""Returns key-points of an image resized to (240, 240)"""
142 | xy, n_pts = pts_list[idx].size()
143 | pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 1
144 | x_crds = pts_list[idx][0] * (self.imside / org_imsize[0])
145 | y_crds = pts_list[idx][1] * (self.imside / org_imsize[1])
146 | kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1).to(self.device)
147 |
148 | return kps, n_pts
149 |
150 | def match_idx(self, kps, n_pts):
151 |         r"""Samples the nearest feature (receptive field) indices"""
152 | nearest_idx = find_knn(Geometry.rf_center, kps.t())
153 | nearest_idx -= (self.range_ts >= n_pts).to(self.device).long()
154 |
155 | return nearest_idx
156 |
157 |
158 | def find_knn(db_vectors, qr_vectors):
159 | r"""Finds K-nearest neighbors (Euclidean distance)"""
160 | db = db_vectors.unsqueeze(1).repeat(1, qr_vectors.size(0), 1)
161 | qr = qr_vectors.unsqueeze(0).repeat(db_vectors.size(0), 1, 1)
162 | dist = (db - qr).pow(2).sum(2).pow(0.5).t()
163 | _, nearest_idx = dist.min(dim=1)
164 |
165 | return nearest_idx
166 |
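# Worked example for find_knn (toy values, not part of the original file):
#   db = torch.tensor([[0., 0.], [10., 0.], [0., 10.]])   # 3 centers
#   qr = torch.tensor([[1., 1.], [9., 1.]])               # 2 query points
#   find_knn(db, qr)                                      # -> tensor([0, 1])
# match_idx() uses this to assign each keypoint to its nearest
# receptive-field center; indices in pad columns (positions >= n_pts) are
# then decremented by one.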
--------------------------------------------------------------------------------
/data/download.py:
--------------------------------------------------------------------------------
1 | r"""Functions to download semantic correspondence datasets"""
2 | import tarfile
3 | import os
4 |
5 | import requests
6 |
7 | from . import pfpascal
8 | from . import pfwillow
9 | from . import caltech
10 | from . import spair
11 |
12 |
13 | def load_dataset(benchmark, datapath, thres, device, split='test'):
14 | r"""Instantiates desired correspondence dataset"""
15 | correspondence_benchmark = {
16 | 'pfpascal': pfpascal.PFPascalDataset,
17 | 'pfwillow': pfwillow.PFWillowDataset,
18 | 'caltech': caltech.CaltechDataset,
19 | 'spair': spair.SPairDataset,
20 | }
21 |
22 | dataset = correspondence_benchmark.get(benchmark)
23 | if dataset is None:
24 | raise Exception('Invalid benchmark dataset %s.' % benchmark)
25 |
26 | return dataset(benchmark, datapath, thres, device, split)
27 |
28 |
29 | def download_from_google(token_id, filename):
30 | r"""Downloads desired filename from Google drive"""
31 | print('Downloading %s ...' % os.path.basename(filename))
32 |
33 | url = 'https://docs.google.com/uc?export=download'
34 | destination = filename + '.tar.gz'
35 | session = requests.Session()
36 |
37 | response = session.get(url, params={'id': token_id, 'confirm':'t'}, stream=True)
38 | token = get_confirm_token(response)
39 |
40 | if token:
41 | params = {'id': token_id, 'confirm': token}
42 | response = session.get(url, params=params, stream=True)
43 | save_response_content(response, destination)
44 | file = tarfile.open(destination, 'r:gz')
45 |
46 | print("Extracting %s ..." % destination)
47 | file.extractall(filename)
48 | file.close()
49 |
50 | os.remove(destination)
51 | os.rename(filename, filename + '_tmp')
52 | os.rename(os.path.join(filename + '_tmp', os.path.basename(filename)), filename)
53 | os.rmdir(filename+'_tmp')
54 |
55 |
56 | def get_confirm_token(response):
57 | r"""Retrieves confirm token"""
58 | for key, value in response.cookies.items():
59 | if key.startswith('download_warning'):
60 | return value
61 |
62 | return None
63 |
64 |
65 | def save_response_content(response, destination):
66 | r"""Saves the response to the destination"""
67 | chunk_size = 32768
68 |
69 | with open(destination, "wb") as file:
70 | for chunk in response.iter_content(chunk_size):
71 | if chunk:
72 | file.write(chunk)
73 |
74 |
75 | def download_dataset(datapath, benchmark):
76 | r"""Downloads semantic correspondence benchmark dataset from Google drive"""
77 | if not os.path.isdir(datapath):
78 | os.mkdir(datapath)
79 |
80 | file_data = {
81 | 'pfwillow': ('1tDP0y8RO5s45L-vqnortRaieiWENQco_', 'PF-WILLOW'),
82 | 'pfpascal': ('1OOwpGzJnTsFXYh-YffMQ9XKM_Kl_zdzg', 'PF-PASCAL'),
83 | 'caltech': ('1IV0E5sJ6xSdDyIvVSTdZjPHELMwGzsMn', 'Caltech-101'),
84 | 'spair': ('1KSvB0k2zXA06ojWNvFjBv0Ake426Y76k', 'SPair-71k')
85 | # 'spair': ('1s73NVEFPro260H1tXxCh1ain7oApR8of', 'SPair-71k') old version
86 | }
87 |
88 | file_id, filename = file_data[benchmark]
89 | abs_filepath = os.path.join(datapath, filename)
90 |
91 | if not os.path.isdir(abs_filepath):
92 | download_from_google(file_id, abs_filepath)
93 |
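# Usage sketch (hypothetical paths; mirrors how the training/testing scripts
# are expected to call into this module):
#   from data import download
#   download.download_dataset('datasets', 'pfpascal')  # fetch if missing
#   dataset = download.load_dataset('pfpascal', 'datasets', 'auto',
#                                   device, split='test')
# where device is e.g. torch.device('cuda:0') and thres='auto' picks the
# benchmark's default PCK threshold type ('img' for PF-PASCAL).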
--------------------------------------------------------------------------------
/data/pfpascal.py:
--------------------------------------------------------------------------------
1 | r"""PF-PASCAL dataset"""
2 | import os
3 |
4 | import scipy.io as sio
5 | import pandas as pd
6 | import numpy as np
7 | import torch
8 |
9 | from .dataset import CorrespondenceDataset
10 |
11 |
12 | class PFPascalDataset(CorrespondenceDataset):
13 | r"""Inherits CorrespondenceDataset"""
14 | def __init__(self, benchmark, datapath, thres, device, split):
15 | r"""PF-PASCAL dataset constructor"""
16 | super(PFPascalDataset, self).__init__(benchmark, datapath, thres, device, split)
17 |
18 | self.train_data = pd.read_csv(self.spt_path)
19 | self.src_imnames = np.array(self.train_data.iloc[:, 0])
20 | self.trg_imnames = np.array(self.train_data.iloc[:, 1])
21 | self.cls = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
22 | 'bus', 'car', 'cat', 'chair', 'cow',
23 | 'diningtable', 'dog', 'horse', 'motorbike', 'person',
24 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
25 | self.cls_ids = self.train_data.iloc[:, 2].values.astype('int') - 1
26 |
27 | if split == 'trn':
28 | self.flip = self.train_data.iloc[:, 3].values.astype('int')
29 | self.src_kps = []
30 | self.trg_kps = []
31 | self.src_bbox = []
32 | self.trg_bbox = []
33 | for src_imname, trg_imname, cls in zip(self.src_imnames, self.trg_imnames, self.cls_ids):
34 | src_anns = os.path.join(self.ann_path, self.cls[cls],
35 | os.path.basename(src_imname))[:-4] + '.mat'
36 | trg_anns = os.path.join(self.ann_path, self.cls[cls],
37 | os.path.basename(trg_imname))[:-4] + '.mat'
38 |
39 | src_kp = torch.tensor(read_mat(src_anns, 'kps')).float()
40 | trg_kp = torch.tensor(read_mat(trg_anns, 'kps')).float()
41 | src_box = torch.tensor(read_mat(src_anns, 'bbox')[0].astype(float))
42 | trg_box = torch.tensor(read_mat(trg_anns, 'bbox')[0].astype(float))
43 |
44 | src_kps = []
45 | trg_kps = []
46 | for src_kk, trg_kk in zip(src_kp, trg_kp):
47 | if len(torch.isnan(src_kk).nonzero()) != 0 or \
48 | len(torch.isnan(trg_kk).nonzero()) != 0:
49 | continue
50 | else:
51 | src_kps.append(src_kk)
52 | trg_kps.append(trg_kk)
53 | self.src_kps.append(torch.stack(src_kps).t())
54 | self.trg_kps.append(torch.stack(trg_kps).t())
55 | self.src_bbox.append(src_box)
56 | self.trg_bbox.append(trg_box)
57 |
58 | self.src_imnames = list(map(lambda x: os.path.basename(x), self.src_imnames))
59 | self.trg_imnames = list(map(lambda x: os.path.basename(x), self.trg_imnames))
60 |
61 | def __getitem__(self, idx):
62 |         r"""Constructs and returns a batch for PF-PASCAL dataset"""
63 | batch = super(PFPascalDataset, self).__getitem__(idx)
64 |
65 | # Object bounding-box (resized following self.imside)
66 | batch['src_bbox'] = self.get_bbox(self.src_bbox, idx, batch['src_imsize'])
67 | batch['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, batch['trg_imsize'])
68 | batch['pckthres'] = self.get_pckthres(batch, batch['trg_imsize'])
69 |
70 | # Horizontal flipping key-points during training
71 | if self.split == 'trn' and self.flip[idx]:
72 | self.horizontal_flip(batch)
73 | batch['flip'] = 1
74 | else:
75 | batch['flip'] = 0
76 |
77 | batch['src_kpidx'] = self.match_idx(batch['src_kps'], batch['n_pts'])
78 | batch['trg_kpidx'] = self.match_idx(batch['trg_kps'], batch['n_pts'])
79 |
80 | return batch
81 |
82 | def get_bbox(self, bbox_list, idx, imsize):
83 | r"""Returns object bounding-box"""
84 | bbox = bbox_list[idx].clone()
85 | bbox[0::2] *= (self.imside / imsize[0])
86 | bbox[1::2] *= (self.imside / imsize[1])
87 | return bbox.to(self.device)
88 |
89 | def horizontal_flip(self, batch):
90 | tmp = batch['src_bbox'][0].clone()
91 | batch['src_bbox'][0] = batch['src_img'].size(2) - batch['src_bbox'][2]
92 | batch['src_bbox'][2] = batch['src_img'].size(2) - tmp
93 |
94 | tmp = batch['trg_bbox'][0].clone()
95 | batch['trg_bbox'][0] = batch['trg_img'].size(2) - batch['trg_bbox'][2]
96 | batch['trg_bbox'][2] = batch['trg_img'].size(2) - tmp
97 |
98 | batch['src_kps'][0][:batch['n_pts']] = batch['src_img'].size(2) - batch['src_kps'][0][:batch['n_pts']]
99 | batch['trg_kps'][0][:batch['n_pts']] = batch['trg_img'].size(2) - batch['trg_kps'][0][:batch['n_pts']]
100 |
101 | batch['src_img'] = torch.flip(batch['src_img'], dims=(2,))
102 | batch['trg_img'] = torch.flip(batch['trg_img'], dims=(2,))
103 |
104 |
105 | def read_mat(path, obj_name):
106 |     r"""Reads the specified object from a Matlab (.mat) data file"""
107 | mat_contents = sio.loadmat(path)
108 | mat_obj = mat_contents[obj_name]
109 |
110 | return mat_obj
111 |
--------------------------------------------------------------------------------
/data/pfwillow.py:
--------------------------------------------------------------------------------
1 | r"""PF-WILLOW dataset"""
2 | import os
3 |
4 | import pandas as pd
5 | import numpy as np
6 | import torch
7 |
8 | from .dataset import CorrespondenceDataset
9 |
10 |
11 | class PFWillowDataset(CorrespondenceDataset):
12 | r"""Inherits CorrespondenceDataset"""
13 | def __init__(self, benchmark, datapath, thres, device, split):
14 | r"""PF-WILLOW dataset constructor"""
15 | super(PFWillowDataset, self).__init__(benchmark, datapath, thres, device, split)
16 |
17 | self.train_data = pd.read_csv(self.spt_path)
18 | self.src_imnames = np.array(self.train_data.iloc[:, 0])
19 | self.trg_imnames = np.array(self.train_data.iloc[:, 1])
20 | self.src_kps = self.train_data.iloc[:, 2:22].values
21 | self.trg_kps = self.train_data.iloc[:, 22:].values
22 | self.cls = ['car(G)', 'car(M)', 'car(S)', 'duck(S)',
23 | 'motorbike(G)', 'motorbike(M)', 'motorbike(S)',
24 | 'winebottle(M)', 'winebottle(wC)', 'winebottle(woC)']
25 | self.cls_ids = list(map(lambda names: self.cls.index(names.split('/')[1]), self.src_imnames))
26 | self.src_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.src_imnames))
27 | self.trg_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.trg_imnames))
28 |
29 | def __getitem__(self, idx):
30 |         r"""Constructs and returns a batch for PF-WILLOW dataset"""
31 | batch = super(PFWillowDataset, self).__getitem__(idx)
32 | batch['pckthres'] = self.get_pckthres(batch).to(self.device)
33 |
34 | batch['src_kpidx'] = self.match_idx(batch['src_kps'], batch['n_pts'])
35 | batch['trg_kpidx'] = self.match_idx(batch['trg_kps'], batch['n_pts'])
36 |
37 | return batch
38 |
39 | def get_pckthres(self, batch):
40 | r"""Computes PCK threshold"""
41 | if self.thres == 'bbox':
42 | return max(batch['trg_kps'].max(1)[0] - batch['trg_kps'].min(1)[0]).clone()
43 | elif self.thres == 'img':
44 | return torch.tensor(max(batch['trg_img'].size()[1], batch['trg_img'].size()[2]))
45 | else:
46 | raise Exception('Invalid pck evaluation level: %s' % self.thres)
47 |
48 | def get_points(self, pts_list, idx, org_imsize):
49 | r"""Returns key-points of an image"""
50 | point_coords = pts_list[idx, :].reshape(2, 10)
51 | point_coords = torch.tensor(point_coords.astype(np.float32))
52 | xy, n_pts = point_coords.size()
53 | pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 1
54 | x_crds = point_coords[0] * (self.imside / org_imsize[0])
55 | y_crds = point_coords[1] * (self.imside / org_imsize[1])
56 | kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1).to(self.device)
57 |
58 | return kps, n_pts
59 |
--------------------------------------------------------------------------------
/data/spair.py:
--------------------------------------------------------------------------------
1 | r"""SPair-71k dataset"""
2 | import json
3 | import glob
4 | import os
5 |
6 | from PIL import Image
7 | import torch
8 |
9 | from .dataset import CorrespondenceDataset
10 |
11 |
12 | class SPairDataset(CorrespondenceDataset):
13 | r"""Inherits CorrespondenceDataset"""
14 | def __init__(self, benchmark, datapath, thres, device, split):
15 | r"""SPair-71k dataset constructor"""
16 | super(SPairDataset, self).__init__(benchmark, datapath, thres, device, split)
17 |
18 | self.train_data = open(self.spt_path).read().split('\n')
19 | self.train_data = self.train_data[:len(self.train_data) - 1]
20 | self.src_imnames = list(map(lambda x: x.split('-')[1] + '.jpg', self.train_data))
21 | self.trg_imnames = list(map(lambda x: x.split('-')[2].split(':')[0] + '.jpg', self.train_data))
22 | self.cls = os.listdir(self.img_path)
23 | self.cls.sort()
24 |
25 | anntn_files = []
26 | for data_name in self.train_data:
27 | anntn_files.append(glob.glob('%s/%s.json' % (self.ann_path, data_name))[0])
28 | anntn_files = list(map(lambda x: json.load(open(x)), anntn_files))
29 | self.src_kps = list(map(lambda x: torch.tensor(x['src_kps']).t().float(), anntn_files))
30 | self.trg_kps = list(map(lambda x: torch.tensor(x['trg_kps']).t().float(), anntn_files))
31 | self.src_bbox = list(map(lambda x: torch.tensor(x['src_bndbox']).float(), anntn_files))
32 | self.trg_bbox = list(map(lambda x: torch.tensor(x['trg_bndbox']).float(), anntn_files))
33 | self.cls_ids = list(map(lambda x: self.cls.index(x['category']), anntn_files))
34 |
35 | self.vpvar = list(map(lambda x: torch.tensor(x['viewpoint_variation']), anntn_files))
36 | self.scvar = list(map(lambda x: torch.tensor(x['scale_variation']), anntn_files))
37 | self.trncn = list(map(lambda x: torch.tensor(x['truncation']), anntn_files))
38 | self.occln = list(map(lambda x: torch.tensor(x['occlusion']), anntn_files))
39 |
40 | def __getitem__(self, idx):
41 |         r"""Constructs and returns a batch for SPair-71k dataset"""
42 | batch = super(SPairDataset, self).__getitem__(idx)
43 |
44 | batch['src_bbox'] = self.get_bbox(self.src_bbox, idx, batch['src_imsize'])
45 | batch['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, batch['trg_imsize'])
46 | batch['pckthres'] = self.get_pckthres(batch, batch['trg_imsize'])
47 |
48 | batch['src_kpidx'] = self.match_idx(batch['src_kps'], batch['n_pts'])
49 | batch['trg_kpidx'] = self.match_idx(batch['trg_kps'], batch['n_pts'])
50 |
51 | batch['vpvar'] = self.vpvar[idx]
52 | batch['scvar'] = self.scvar[idx]
53 | batch['trncn'] = self.trncn[idx]
54 | batch['occln'] = self.occln[idx]
55 |
56 | return batch
57 |
58 | def get_image(self, img_names, idx):
59 | r"""Returns image tensor"""
60 | path = os.path.join(self.img_path, self.cls[self.cls_ids[idx]], img_names[idx])
61 |
62 | return Image.open(path).convert('RGB')
63 |
64 | def get_bbox(self, bbox_list, idx, imsize):
65 | r"""Returns object bounding-box"""
66 | bbox = bbox_list[idx].clone()
67 | bbox[0::2] *= (self.imside / imsize[0])
68 | bbox[1::2] *= (self.imside / imsize[1])
69 | return bbox.to(self.device)
70 |
--------------------------------------------------------------------------------
/model/base/correlation.py:
--------------------------------------------------------------------------------
1 | r"""Provides functions that create/manipulate correlation matrices"""
2 | import torch.nn.functional as F
3 | import torch
4 |
5 |
6 | class Correlation:
7 | @classmethod
8 | def bmm_interp(cls, src_feat, trg_feat, interp_size):
9 | r"""Performs batch-wise matrix-multiplication after interpolation"""
10 | src_feat = F.interpolate(src_feat, interp_size, mode='bilinear', align_corners=True)
11 | trg_feat = F.interpolate(trg_feat, interp_size, mode='bilinear', align_corners=True)
12 |
13 | src_feat = src_feat.view(src_feat.size(0), src_feat.size(1), -1).transpose(1, 2)
14 | trg_feat = trg_feat.view(trg_feat.size(0), trg_feat.size(1), -1)
15 |
16 | return torch.bmm(src_feat, trg_feat)
17 |
18 | @classmethod
19 | def mutual_nn_filter(cls, correlation_matrix):
20 | r"""Mutual nearest neighbor filtering (Rocco et al. NeurIPS'18)"""
21 | corr_src_max = torch.max(correlation_matrix, dim=2, keepdim=True)[0]
22 | corr_trg_max = torch.max(correlation_matrix, dim=1, keepdim=True)[0]
23 | corr_src_max[corr_src_max == 0] += 1e-30
24 | corr_trg_max[corr_trg_max == 0] += 1e-30
25 |
26 | corr_src = correlation_matrix / corr_src_max
27 | corr_trg = correlation_matrix / corr_trg_max
28 |
29 | return correlation_matrix * (corr_src * corr_trg)
30 |
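# Usage sketch (appended; toy values). Each score is rescaled by its ratio
# to the row maximum and to the column maximum, so mutual nearest
# neighbours, here entries (0, 0) and (1, 1), keep full weight while
# one-sided matches are suppressed. Runnable standalone:
if __name__ == '__main__':
    c = torch.tensor([[[4., 2.],
                       [1., 3.]]])
    print(Correlation.mutual_nn_filter(c))
    # -> tensor([[[4.0000, 0.6667], [0.0833, 3.0000]]])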
--------------------------------------------------------------------------------
/model/base/geometry.py:
--------------------------------------------------------------------------------
1 | """Provides functions that manipulate boxes and points"""
2 | import torch
3 |
4 | from .correlation import Correlation
5 |
6 |
7 | class Geometry:
8 | @classmethod
9 | def initialize(cls, feat_size, device):
10 | cls.max_pts = 400
11 | cls.eps = 1e-30
12 | cls.rfs = cls.receptive_fields(11, 4, feat_size).to(device)
13 | cls.rf_center = Geometry.center(cls.rfs)
14 |
15 | @classmethod
16 | def center(cls, box):
17 | r"""Computes centers, (x, y), of box (N, 4)"""
18 | x_center = box[:, 0] + (box[:, 2] - box[:, 0]) // 2
19 | y_center = box[:, 1] + (box[:, 3] - box[:, 1]) // 2
20 | return torch.stack((x_center, y_center)).t().to(box.device)
21 |
22 | @classmethod
23 | def receptive_fields(cls, rfsz, jsz, feat_size):
24 | r"""Returns a set of receptive fields (N, 4)"""
25 | width = feat_size[1]
26 | height = feat_size[0]
27 |
28 | feat_ids = torch.tensor(list(range(width))).repeat(1, height).t().repeat(1, 2)
29 | feat_ids[:, 0] = torch.tensor(list(range(height))).unsqueeze(1).repeat(1, width).view(-1)
30 |
31 | box = torch.zeros(feat_ids.size()[0], 4)
32 | box[:, 0] = feat_ids[:, 1] * jsz - rfsz // 2
33 | box[:, 1] = feat_ids[:, 0] * jsz - rfsz // 2
34 | box[:, 2] = feat_ids[:, 1] * jsz + rfsz // 2
35 | box[:, 3] = feat_ids[:, 0] * jsz + rfsz // 2
36 |
37 | return box
38 |
39 | @classmethod
40 | def gaussian2d(cls, side=7):
41 | r"""Returns 2-dimensional gaussian filter"""
42 | dim = [side, side]
43 |
44 | siz = torch.LongTensor(dim)
45 | sig_sq = (siz.float()/2/2.354).pow(2)
46 | siz2 = (siz-1)/2
47 |
48 | x_axis = torch.arange(-siz2[0], siz2[0] + 1).unsqueeze(0).expand(dim).float()
49 | y_axis = torch.arange(-siz2[1], siz2[1] + 1).unsqueeze(1).expand(dim).float()
50 |
51 | gaussian = torch.exp(-(x_axis.pow(2)/2/sig_sq[0] + y_axis.pow(2)/2/sig_sq[1]))
52 | gaussian = gaussian / gaussian.sum()
53 |
54 | return gaussian
55 |
56 | @classmethod
57 | def neighbours(cls, box, kps):
58 | r"""Returns boxes in one-hot format that covers given keypoints"""
59 | box_duplicate = box.unsqueeze(2).repeat(1, 1, len(kps.t())).transpose(0, 1)
60 | kps_duplicate = kps.unsqueeze(1).repeat(1, len(box), 1)
61 |
62 | xmin = kps_duplicate[0].ge(box_duplicate[0])
63 | ymin = kps_duplicate[1].ge(box_duplicate[1])
64 | xmax = kps_duplicate[0].le(box_duplicate[2])
65 | ymax = kps_duplicate[1].le(box_duplicate[3])
66 |
67 | nbr_onehot = torch.mul(torch.mul(xmin, ymin), torch.mul(xmax, ymax)).t()
68 | n_neighbours = nbr_onehot.sum(dim=1)
69 |
70 | return nbr_onehot, n_neighbours
71 |
72 | @classmethod
73 | def transfer_kps(cls, correlation_matrix, kps, n_pts):
74 | r"""Transfer keypoints by nearest-neighbour assignment"""
75 | correlation_matrix = Correlation.mutual_nn_filter(correlation_matrix)
76 |
77 | prd_kps = []
78 |         for ct, kpss, npt in zip(correlation_matrix, kps, n_pts):
79 |
80 | # 1. Prepare geometries & argmax target indices
81 |             kp = kpss.narrow_copy(1, 0, npt)
82 | _, trg_argmax_idx = torch.max(ct, dim=1)
83 | geomet = cls.rfs[:, :2].unsqueeze(0).repeat(len(kp.t()), 1, 1)
84 |
85 | # 2. Retrieve neighbouring source boxes that cover source key-points
86 | src_nbr_onehot, n_neighbours = cls.neighbours(cls.rfs, kp)
87 |
88 | # 3. Get displacements from source neighbouring box centers to each key-point
89 | src_displacements = kp.t().unsqueeze(1).repeat(1, len(cls.rfs), 1) - geomet
90 | src_displacements = src_displacements * src_nbr_onehot.unsqueeze(2).repeat(1, 1, 2).float()
91 |
92 | # 4. Transfer the neighbours based on given correlation matrix
93 | vector_summator = torch.zeros_like(geomet)
94 | src_idx = src_nbr_onehot.nonzero()
95 |
96 | trg_idx = trg_argmax_idx.index_select(dim=0, index=src_idx[:, 1])
97 | vector_summator[src_idx[:, 0], src_idx[:, 1]] = geomet[src_idx[:, 0], trg_idx]
98 | vector_summator += src_displacements
99 | prd = (vector_summator.sum(dim=1) / n_neighbours.unsqueeze(1).repeat(1, 2).float()).t()
100 |
101 | # 5. Concatenate pad-points
102 |             pads = (torch.zeros((2, cls.max_pts - npt)).to(prd.device) - 1)
103 | prd = torch.cat([prd, pads], dim=1)
104 | prd_kps.append(prd)
105 |
106 | return torch.stack(prd_kps)
107 |
--------------------------------------------------------------------------------
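
To make Geometry.receptive_fields concrete, here is a toy-sized sketch (sizes are illustrative): with rfsz=11 and jsz=4, the feature cell at column x, row y receives the box [4x-5, 4y-5, 4x+5, 4y+5].

    from model.base.geometry import Geometry  # run from the repository root

    boxes = Geometry.receptive_fields(rfsz=11, jsz=4, feat_size=(2, 3))
    print(boxes.shape)  # torch.Size([6, 4]): one (xmin, ymin, xmax, ymax) box per cell
    print(boxes[0])     # cell (0, 0): tensor([-5., -5.,  5.,  5.])
    print(boxes[4])     # cell (row 1, col 1): tensor([-1., -1.,  7.,  7.])
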
/model/base/norm.py:
--------------------------------------------------------------------------------
1 | r"""Normalization functions"""
2 | import torch.nn.functional as F
3 | import torch
4 |
5 |
6 | class Norm:
7 | r"""Vector normalization"""
8 | @classmethod
9 | def feat_normalize(cls, x, interp_size):
10 | r"""L2-normalizes given 2D feature map after interpolation"""
11 | x = F.interpolate(x, interp_size, mode='bilinear', align_corners=True)
12 | return x.pow(2).sum(1).view(x.size(0), -1)
13 |
14 | @classmethod
15 | def l1normalize(cls, x):
16 | r"""L1-normalization"""
17 | vector_sum = torch.sum(x, dim=2, keepdim=True)
18 | vector_sum[vector_sum == 0] = 1.0
19 | return x / vector_sum
20 |
21 | @classmethod
22 | def unit_gaussian_normalize(cls, x):
23 | r"""Make each (row) distribution into unit gaussian"""
24 | correlation_matrix = x - x.mean(dim=2).unsqueeze(2).expand_as(x)
25 |
26 | with torch.no_grad():
27 | standard_deviation = correlation_matrix.std(dim=2)
28 | standard_deviation[standard_deviation == 0] = 1.0
29 | correlation_matrix /= standard_deviation.unsqueeze(2).expand_as(correlation_matrix)
30 |
31 | return correlation_matrix
32 |
--------------------------------------------------------------------------------
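
A sketch of Norm.unit_gaussian_normalize on a toy tensor (illustrative values; run from the repository root). Note the design choice visible above: the standard deviation is computed under no_grad, so gradients flow only through the centred values, not through the normaliser itself.

    import torch
    from model.base.norm import Norm

    x = torch.randn(1, 4, 25) * 3.0 + 2.0
    z = Norm.unit_gaussian_normalize(x)
    print(z.mean(dim=2))  # approximately 0 per row
    print(z.std(dim=2))   # approximately 1 per row
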
/model/base/resnet.py:
--------------------------------------------------------------------------------
1 | r"""ResNet code from PyTorch library"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 |
6 |
7 | __all__ = ['Backbone', 'resnet50', 'resnet101']
8 |
9 |
10 | model_urls = {
11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | }
17 |
18 |
19 | def conv3x3(in_planes, out_planes, stride=1):
20 | """3x3 convolution with padding"""
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, groups=2, bias=False)
23 |
24 |
25 | def conv1x1(in_planes, out_planes, stride=1):
26 | """1x1 convolution"""
27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=2, bias=False)
28 |
29 |
30 | class Bottleneck(nn.Module):
31 | expansion = 4
32 |
33 | def __init__(self, inplanes, planes, stride=1, downsample=None):
34 | super(Bottleneck, self).__init__()
35 | self.conv1 = conv1x1(inplanes, planes)
36 | self.bn1 = nn.BatchNorm2d(planes)
37 | self.conv2 = conv3x3(planes, planes, stride)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 | self.conv3 = conv1x1(planes, planes * self.expansion)
40 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
41 | self.relu = nn.ReLU(inplace=True)
42 | self.downsample = downsample
43 | self.stride = stride
44 |
45 | def forward(self, x):
46 | identity = x
47 |
48 | out = self.conv1(x)
49 | out = self.bn1(out)
50 | out = self.relu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 | out = self.relu(out)
55 |
56 | out = self.conv3(out)
57 | out = self.bn3(out)
58 |
59 | if self.downsample is not None:
60 | identity = self.downsample(x)
61 |
62 | out += identity
63 | out = self.relu(out)
64 |
65 | return out
66 |
67 |
68 | class Backbone(nn.Module):
69 | def __init__(self, block, layers, zero_init_residual=False):
70 | super(Backbone, self).__init__()
71 |
72 | self.inplanes = 128
73 |         self.conv1 = nn.Conv2d(6, 128, kernel_size=7, stride=2, padding=3, groups=2,  # 6 input channels: concatenated source & target images
74 | bias=False)
75 | self.bn1 = nn.BatchNorm2d(128)
76 | self.relu = nn.ReLU(inplace=True)
77 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
78 | self.layer1 = self._make_layer(block, 128, layers[0])
79 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2)
80 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2)
81 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2)
82 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
83 | self.fc = nn.Linear(512 * block.expansion, 1000)
84 |
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
88 | elif isinstance(m, nn.BatchNorm2d):
89 | nn.init.constant_(m.weight, 1)
90 | nn.init.constant_(m.bias, 0)
91 |
92 | # Zero-initialize the last BN in each residual branch,
93 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
94 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
95 | if zero_init_residual:
96 | for m in self.modules():
97 | if isinstance(m, Bottleneck):
98 | nn.init.constant_(m.bn3.weight, 0)
99 | # elif isinstance(m, BasicBlock):
100 | # nn.init.constant_(m.bn2.weight, 0)
101 |
102 | def _make_layer(self, block, planes, blocks, stride=1):
103 | downsample = None
104 | if stride != 1 or self.inplanes != planes * block.expansion:
105 | downsample = nn.Sequential(
106 | conv1x1(self.inplanes, planes * block.expansion, stride),
107 | nn.BatchNorm2d(planes * block.expansion),
108 | )
109 |
110 | layers = []
111 | layers.append(block(self.inplanes, planes, stride, downsample))
112 | self.inplanes = planes * block.expansion
113 | for _ in range(1, blocks):
114 | layers.append(block(self.inplanes, planes))
115 |
116 | return nn.Sequential(*layers)
117 |
118 |
119 | def resnet50(pretrained=False, **kwargs):
120 | """Constructs a ResNet-50 model.
121 |
122 | Args:
123 | pretrained (bool): If True, returns a model pre-trained on ImageNet
124 | """
125 | model = Backbone(Bottleneck, [3, 4, 6, 3], **kwargs)
126 | if pretrained:
127 | weights = model_zoo.load_url(model_urls['resnet50'])
128 |
129 |         for key in weights:
130 |             if key.split('.')[0] == 'fc':  # fc weights are loaded as-is (unused by DHPF's forward pass)
131 |                 weights[key] = weights[key].clone()
132 |                 continue
133 |             weights[key] = torch.cat([weights[key].clone(), weights[key].clone()], dim=0)  # duplicate along dim 0 for the two-stream (groups=2) backbone
134 |
135 | model.load_state_dict(weights)
136 | return model
137 |
138 |
139 | def resnet101(pretrained=False, **kwargs):
140 | """Constructs a ResNet-101 model.
141 |
142 | Args:
143 | pretrained (bool): If True, returns a model pre-trained on ImageNet
144 | """
145 | model = Backbone(Bottleneck, [3, 4, 23, 3], **kwargs)
146 | if pretrained:
147 | weights = model_zoo.load_url(model_urls['resnet101'])
148 |
149 | for key in weights:
150 | if key.split('.')[0] == 'fc':
151 | weights[key] = weights[key].clone()
152 | continue
153 | weights[key] = torch.cat([weights[key].clone(), weights[key].clone()], dim=0)
154 |
155 | model.load_state_dict(weights)
156 | return model
157 |
--------------------------------------------------------------------------------
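
The grouped (groups=2) convolutions above treat the concatenated source/target pair as two independent streams, so duplicating the pretrained weights along dim 0 reproduces the original single-stream ResNet on each half. A standalone sketch with toy layer sizes (not repository code):

    import torch
    import torch.nn as nn

    single = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    paired = nn.Conv2d(6, 16, kernel_size=3, padding=1, groups=2, bias=False)
    paired.weight.data = torch.cat([single.weight.data, single.weight.data], dim=0)

    src, trg = torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)
    out = paired(torch.cat([src, trg], dim=1))
    assert torch.allclose(out[:, :8], single(src), atol=1e-5)  # first group == source stream
    assert torch.allclose(out[:, 8:], single(trg), atol=1e-5)  # second group == target stream
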
/model/dhpf.py:
--------------------------------------------------------------------------------
1 | """Implementation of Dynamic Hyperpixel Flow"""
2 | from functools import reduce
3 | from operator import add
4 |
5 | import torch.nn as nn
6 | import torch
7 |
8 | from .base.correlation import Correlation
9 | from .base.geometry import Geometry
10 | from .base.norm import Norm
11 | from .base import resnet
12 | from . import gating
13 | from . import rhm
14 |
15 |
16 | class DynamicHPF:
17 | r"""Dynamic Hyperpixel Flow (DHPF)"""
18 | def __init__(self, backbone, device, img_side=240):
19 | r"""Constructor for DHPF"""
20 | super(DynamicHPF, self).__init__()
21 |
22 | # 1. Backbone network initialization
23 | if backbone == 'resnet50':
24 | self.backbone = resnet.resnet50(pretrained=True).to(device)
25 | self.in_channels = [64, 256, 256, 256, 512, 512, 512, 512, 1024,
26 | 1024, 1024, 1024, 1024, 1024, 2048, 2048, 2048]
27 | nbottlenecks = [3, 4, 6, 3]
28 | elif backbone == 'resnet101':
29 | self.backbone = resnet.resnet101(pretrained=True).to(device)
30 | self.in_channels = [64, 256, 256, 256, 512, 512, 512, 512,
31 | 1024, 1024, 1024, 1024, 1024, 1024, 1024,
32 | 1024, 1024, 1024, 1024, 1024, 1024, 1024,
33 | 1024, 1024, 1024, 1024, 1024, 1024, 1024,
34 | 1024, 1024, 2048, 2048, 2048]
35 | nbottlenecks = [3, 4, 23, 3]
36 | else:
37 |             raise ValueError('Unavailable backbone: %s' % backbone)
38 | self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks)))
39 | self.layer_ids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)])
40 | self.backbone.eval()
41 |
42 | # 2. Dynamic layer gatings initialization
43 | self.learner = gating.GumbelFeatureSelection(self.in_channels).to(device)
44 |
45 | # 3. Miscellaneous
46 | self.relu = nn.ReLU()
47 | self.upsample_size = [int(img_side / 4)] * 2
48 | Geometry.initialize(self.upsample_size, device)
49 | self.rhm = rhm.HoughMatching(Geometry.rfs, torch.tensor([img_side, img_side]).to(device))
50 |
51 | # Forward pass
52 |     def __call__(self, src_img, trg_img):
53 |         r"""DHPF forward pass"""
54 |
55 |         # 1. Compute correlations between hyperimages
56 |         correlation_matrix, layer_sel = self.hyperimage_correlation(src_img, trg_img)
57 |
58 | # 2. Compute geometric matching scores to re-weight appearance matching scores (RHM)
59 |         with torch.no_grad():  # no back-prop through RHM to save memory
60 | geometric_scores = torch.stack([self.rhm.run(c.clone().detach()) for c in correlation_matrix], dim=0)
61 | correlation_matrix *= geometric_scores
62 |
63 | return correlation_matrix, layer_sel
64 |
65 | def hyperimage_correlation(self, src_img, trg_img):
66 | r"""Dynamically construct hyperimages and compute their correlations"""
67 | layer_sel = []
68 | correlation, src_norm, trg_norm = 0, 0, 0
69 |
70 | # Concatenate source & target images (B,6,H,W)
71 |         # Perform grouped convolution (groups=2) for faster inference
72 | pair_img = torch.cat([src_img, trg_img], dim=1)
73 |
74 | # Layer 0
75 | with torch.no_grad():
76 | feat = self.backbone.conv1.forward(pair_img)
77 | feat = self.backbone.bn1.forward(feat)
78 | feat = self.backbone.relu.forward(feat)
79 | feat = self.backbone.maxpool.forward(feat)
80 |
81 | src_feat = feat.narrow(1, 0, feat.size(1) // 2).clone()
82 | trg_feat = feat.narrow(1, feat.size(1) // 2, feat.size(1) // 2).clone()
83 |
84 | # Save base maps
85 | base_src_feat = self.learner.reduction_ffns[0](src_feat)
86 | base_trg_feat = self.learner.reduction_ffns[0](trg_feat)
87 | base_correlation = Correlation.bmm_interp(base_src_feat, base_trg_feat, self.upsample_size)
88 | base_src_norm = Norm.feat_normalize(base_src_feat, self.upsample_size)
89 | base_trg_norm = Norm.feat_normalize(base_trg_feat, self.upsample_size)
90 |
91 | src_feat, trg_feat, lsel = self.learner(0, src_feat, trg_feat)
92 | if src_feat is not None and trg_feat is not None:
93 | correlation += Correlation.bmm_interp(src_feat, trg_feat, self.upsample_size)
94 | src_norm += Norm.feat_normalize(src_feat, self.upsample_size)
95 | trg_norm += Norm.feat_normalize(trg_feat, self.upsample_size)
96 | layer_sel.append(lsel)
97 |
98 | # Layer 1-4
99 | for hid, (bid, lid) in enumerate(zip(self.bottleneck_ids, self.layer_ids)):
100 | with torch.no_grad():
101 | res = feat
102 | feat = self.backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
103 | feat = self.backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
104 | feat = self.backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
105 | feat = self.backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
106 | feat = self.backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
107 | feat = self.backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
108 | feat = self.backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
109 | feat = self.backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)
110 | if bid == 0:
111 | res = self.backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res)
112 | feat += res
113 |
114 | src_feat = feat.narrow(1, 0, feat.size(1) // 2).clone()
115 | trg_feat = feat.narrow(1, feat.size(1) // 2, feat.size(1) // 2).clone()
116 |
117 | src_feat, trg_feat, lsel = self.learner(hid + 1, src_feat, trg_feat)
118 | if src_feat is not None and trg_feat is not None:
119 | correlation += Correlation.bmm_interp(src_feat, trg_feat, self.upsample_size)
120 | src_norm += Norm.feat_normalize(src_feat, self.upsample_size)
121 | trg_norm += Norm.feat_normalize(trg_feat, self.upsample_size)
122 | layer_sel.append(lsel)
123 |
124 | with torch.no_grad():
125 | feat = self.backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
126 |
127 | layer_sel = torch.stack(layer_sel).t()
128 |
129 | # If no layers are selected, select the base map
130 | if (layer_sel.sum(dim=1) == 0).sum() > 0:
131 | empty_sel = (layer_sel.sum(dim=1) == 0).nonzero().view(-1).long()
132 | if src_img.size(0) == 1:
133 | correlation = base_correlation
134 | src_norm = base_src_norm
135 | trg_norm = base_trg_norm
136 | else:
137 | correlation[empty_sel] += base_correlation[empty_sel]
138 | src_norm[empty_sel] += base_src_norm[empty_sel]
139 | trg_norm[empty_sel] += base_trg_norm[empty_sel]
140 |
141 | if self.learner.training:
142 | src_norm[src_norm == 0.0] += 0.0001
143 | trg_norm[trg_norm == 0.0] += 0.0001
144 | src_norm = src_norm.pow(0.5).unsqueeze(2)
145 | trg_norm = trg_norm.pow(0.5).unsqueeze(1)
146 |
147 |         # Appearance matching confidence (p(m_a)): cosine similarity between hyperimages
148 | correlation_ts = self.relu(correlation / (torch.bmm(src_norm, trg_norm) + 0.001)).pow(2)
149 |
150 | return correlation_ts, layer_sel
151 |
152 | def parameters(self):
153 | return self.learner.parameters()
154 |
155 | def state_dict(self):
156 | return self.learner.state_dict()
157 |
158 | def load_state_dict(self, state_dict):
159 | self.learner.load_state_dict(state_dict)
160 |
161 | def eval(self):
162 | self.learner.eval()
163 |
164 | def train(self):
165 | self.learner.train()
166 |
--------------------------------------------------------------------------------
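
A minimal usage sketch of DynamicHPF (run from the repository root; the random 240x240 inputs stand in for properly normalised images, and pretrained backbone weights are downloaded on first use):

    import torch
    from model.dhpf import DynamicHPF

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = DynamicHPF('resnet50', device)
    model.eval()

    src = torch.randn(1, 3, 240, 240).to(device)
    trg = torch.randn(1, 3, 240, 240).to(device)
    with torch.no_grad():
        correlation, layer_sel = model(src, trg)
    print(correlation.shape)  # torch.Size([1, 3600, 3600]): 60x60 cells on each side
    print(layer_sel.shape)    # torch.Size([1, 17]): one gate per candidate layer
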
/model/gating.py:
--------------------------------------------------------------------------------
1 | r"""Implementation of Dynamic Layer Gating (DLG)"""
2 | import torch.nn as nn
3 | import torch
4 |
5 |
6 | class GumbelFeatureSelection(nn.Module):
7 | r"""Dynamic layer gating with Gumbel-max trick"""
8 | def __init__(self, in_channels, reduction=8, hidden_size=32):
9 | r"""Constructor for DLG"""
10 | super(GumbelFeatureSelection, self).__init__()
11 |
12 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
13 | self.softmax = nn.Softmax(dim=1)
14 | self.reduction = reduction
15 |
16 | # Learnable modules in Dynamic Hyperpixel Flow
17 | self.reduction_ffns = [] # Convolutional Feature Transformation (CFT)
18 | self.gumbel_ffns = [] # Gumbel Layer Gating (GLG)
19 | for in_channel in in_channels:
20 | out_channel = in_channel // self.reduction
21 | reduction_ffn = nn.Sequential(
22 | nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False),
23 | nn.BatchNorm2d(out_channel),
24 | nn.ReLU(inplace=True)
25 | )
26 | gumbel_ffn = nn.Sequential(
27 | nn.Conv2d(in_channel, hidden_size, kernel_size=1, bias=False),
28 | nn.BatchNorm2d(hidden_size),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(hidden_size, 2, kernel_size=1, bias=False)
31 | )
32 | self.reduction_ffns.append(reduction_ffn)
33 | self.gumbel_ffns.append(gumbel_ffn)
34 | self.reduction_ffns = nn.ModuleList(self.reduction_ffns)
35 | self.gumbel_ffns = nn.ModuleList(self.gumbel_ffns)
36 |
37 | def forward(self, lid, src_feat, trg_feat):
38 | r"""DLG forward pass"""
39 | relevance = self.gumbel_ffns[lid](self.avgpool(src_feat) + self.avgpool(trg_feat))
40 |
41 | # For measuring per-pair inference time on test set
42 | if not self.training and len(relevance) == 1:
43 | selected = relevance.max(dim=1)[1].squeeze()
44 |
45 | # Perform CFT iff the layer is selected
46 | if selected:
47 | src_x = self.reduction_ffns[lid](src_feat)
48 | trg_x = self.reduction_ffns[lid](trg_feat)
49 | else:
50 | src_x = None
51 | trg_x = None
52 | layer_sel = relevance.view(-1).max(dim=0)[1].unsqueeze(0).float()
53 | else:
54 | # Hard selection during forward pass (layer_sel)
55 | # Soft gradients during backward pass (dL/dy)
56 | y = self.gumbel_softmax(relevance.squeeze())
57 | _y = self._softmax(y)
58 | layer_sel = y[:, 1] + _y[:, 1]
59 |
60 | src_x = self.reduction_ffns[lid](src_feat)
61 | trg_x = self.reduction_ffns[lid](trg_feat)
62 | src_x = src_x * y[:, 1].view(-1, 1, 1, 1) + src_x * _y[:, 1].view(-1, 1, 1, 1)
63 | trg_x = trg_x * y[:, 1].view(-1, 1, 1, 1) + trg_x * _y[:, 1].view(-1, 1, 1, 1)
64 |
65 | return src_x, trg_x, layer_sel
66 |
67 | def _softmax(self, soft_sample):
68 | r"""Gumbel-max trick: replaces argmax with softmax during backward pass: soft_sample + _soft_sample"""
69 | hard_sample_idx = torch.max(soft_sample, dim=1)[1].unsqueeze(1)
70 | hard_sample = soft_sample.detach().clone().zero_().scatter(dim=1, index=hard_sample_idx, source=1)
71 |
72 | _soft_sample = (hard_sample - soft_sample.detach().clone())
73 |
74 | return _soft_sample
75 |
76 | def gumbel_softmax(self, logits, temperature=1, eps=1e-10):
77 | """Softly draws a sample from the Gumbel distribution"""
78 | if self.training:
79 | gumbel_noise = -torch.log(eps - torch.log(logits.detach().clone().uniform_() + eps))
80 | gumbel_input = logits + gumbel_noise
81 | else:
82 | gumbel_input = logits
83 |
84 | if gumbel_input.dim() == 1:
85 | gumbel_input = gumbel_input.unsqueeze(0)
86 |
87 | soft_sample = self.softmax(gumbel_input / temperature)
88 |
89 | return soft_sample
90 |
--------------------------------------------------------------------------------
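
Taken together, gumbel_softmax and _softmax implement a straight-through estimator: the sum y + _y returned above equals the hard one-hot sample in value while carrying the soft sample's gradients. A self-contained sketch of the trick (toy logits assumed):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 2, requires_grad=True)
    soft = F.softmax(logits, dim=1)
    hard = torch.zeros_like(soft).scatter(1, soft.argmax(dim=1, keepdim=True), 1.0)
    sample = soft + (hard - soft).detach()  # hard in value, soft in gradient
    sample.sum().backward()
    print(sample)       # one-hot rows
    print(logits.grad)  # gradients of the soft sample
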
/model/objective.py:
--------------------------------------------------------------------------------
1 | r"""Training objectives of DHPF"""
2 | import math
3 |
4 | import torch.nn.functional as F
5 | import torch
6 |
7 | from .base.correlation import Correlation
8 | from .base.norm import Norm
9 |
10 |
11 | class Objective:
12 | r"""Provides training objectives of DHPF"""
13 | @classmethod
14 | def initialize(cls, target_rate, alpha):
15 | cls.softmax = torch.nn.Softmax(dim=1)
16 | cls.target_rate = target_rate
17 | cls.alpha = alpha
18 | cls.eps = 1e-30
19 |
20 | @classmethod
21 | def weighted_cross_entropy(cls, correlation_matrix, easy_match, hard_match, batch):
22 | r"""Computes sum of weighted cross-entropy values between ground-truth and prediction"""
23 | loss_buf = correlation_matrix.new_zeros(correlation_matrix.size(0))
24 | correlation_matrix = Norm.unit_gaussian_normalize(correlation_matrix)
25 |
26 | for idx, (ct, thres, npt) in enumerate(zip(correlation_matrix, batch['pckthres'], batch['n_pts'])):
27 |
28 | # Hard (incorrect) match
29 | if len(hard_match['src'][idx]) > 0:
30 | cross_ent = cls.cross_entropy(ct, hard_match['src'][idx], hard_match['trg'][idx])
31 | loss_buf[idx] += cross_ent.sum()
32 |
33 | # Easy (correct) match
34 | if len(easy_match['src'][idx]) > 0:
35 | cross_ent = cls.cross_entropy(ct, easy_match['src'][idx], easy_match['trg'][idx])
36 | smooth_weight = (easy_match['dist'][idx] / (thres * cls.alpha)).pow(2)
37 | loss_buf[idx] += (smooth_weight * cross_ent).sum()
38 |
39 | loss_buf[idx] /= npt
40 |
41 | return torch.mean(loss_buf)
42 |
43 | @classmethod
44 | def cross_entropy(cls, correlation_matrix, src_match, trg_match):
45 | r"""Cross-entropy between predicted pdf and ground-truth pdf (one-hot vector)"""
46 | pdf = cls.softmax(correlation_matrix.index_select(0, src_match))
47 | prob = pdf[range(len(trg_match)), trg_match]
48 | cross_ent = -torch.log(prob + cls.eps)
49 |
50 | return cross_ent
51 |
52 | @classmethod
53 | def information_entropy(cls, correlation_matrix, rescale_factor=4):
54 | r"""Computes information entropy of all candidate matches"""
55 | bsz = correlation_matrix.size(0)
56 |
57 | correlation_matrix = Correlation.mutual_nn_filter(correlation_matrix)
58 |
59 | side = int(math.sqrt(correlation_matrix.size(1)))
60 | new_side = side // rescale_factor
61 |
62 | trg2src_dist = correlation_matrix.view(bsz, -1, side, side)
63 | src2trg_dist = correlation_matrix.view(bsz, side, side, -1).permute(0, 3, 1, 2)
64 |
65 | # Squeeze distributions for reliable entropy computation
66 | trg2src_dist = F.interpolate(trg2src_dist, [new_side, new_side], mode='bilinear', align_corners=True)
67 | src2trg_dist = F.interpolate(src2trg_dist, [new_side, new_side], mode='bilinear', align_corners=True)
68 |
69 | src_pdf = Norm.l1normalize(trg2src_dist.view(bsz, -1, (new_side * new_side)))
70 | trg_pdf = Norm.l1normalize(src2trg_dist.view(bsz, -1, (new_side * new_side)))
71 |
72 | src_pdf[src_pdf == 0.0] = cls.eps
73 | trg_pdf[trg_pdf == 0.0] = cls.eps
74 |
75 | src_ent = (-(src_pdf * torch.log2(src_pdf)).sum(dim=2)).view(bsz, -1)
76 | trg_ent = (-(trg_pdf * torch.log2(trg_pdf)).sum(dim=2)).view(bsz, -1)
77 | score_net = (src_ent + trg_ent).mean(dim=1) / 2
78 |
79 | return score_net.mean()
80 |
81 | @classmethod
82 | def layer_selection_loss(cls, layer_sel):
83 | r"""Encourages model to select each layer at a certain rate"""
84 | return (layer_sel.mean(dim=0) - cls.target_rate).pow(2).sum()
85 |
--------------------------------------------------------------------------------
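
A worked example of layer_selection_loss (toy selections, target rate 0.5; run from the repository root): per-layer selection rates [1.0, 0.0, 0.5] deviate from the target by [0.5, -0.5, 0.0], giving a loss of 0.25 + 0.25 + 0.0 = 0.5.

    import torch
    from model.objective import Objective

    Objective.initialize(target_rate=0.5, alpha=0.1)
    layer_sel = torch.tensor([[1., 0., 1.],
                              [1., 0., 0.]])  # 2 image pairs, 3 candidate layers
    print(Objective.layer_selection_loss(layer_sel))  # tensor(0.5000)
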
/model/rhm.py:
--------------------------------------------------------------------------------
1 | """Implementation of regularized Hough matching algorithm (RHM)"""
2 | import math
3 |
4 | import torch.nn.functional as F
5 | import torch
6 |
7 | from .base.geometry import Geometry
8 |
9 |
10 | class HoughMatching:
11 | r"""Regularized Hough matching algorithm"""
12 | def __init__(self, rf, img_side, ncells=8192):
13 | r"""Constructor of HoughMatching"""
14 | super(HoughMatching, self).__init__()
15 |
16 | device = rf.device
17 | self.nbins_x, self.nbins_y, hs_cellsize = self.build_hspace(img_side, img_side, ncells)
18 | self.bin_ids = self.compute_bin_id(img_side, rf, rf, hs_cellsize, self.nbins_x)
19 | self.hspace = rf.new_zeros((len(rf), self.nbins_y * self.nbins_x))
20 | self.hbin_ids = self.bin_ids.add(torch.arange(0, len(rf)).to(device).
21 | mul(self.hspace.size(1)).unsqueeze(1).expand_as(self.bin_ids))
22 | self.hsfilter = Geometry.gaussian2d(7).to(device)
23 |
24 | def run(self, votes):
25 | r"""Regularized Hough matching"""
26 | hspace = self.hspace.view(-1).index_add(0, self.hbin_ids.view(-1), votes.view(-1)).view_as(self.hspace)
27 | hspace = torch.sum(hspace, dim=0)
28 |         hspace = F.conv2d(hspace.view(1, 1, self.nbins_y, self.nbins_x),
29 |                           self.hsfilter.unsqueeze(0).unsqueeze(0), padding=3).view(-1)  # padding=3 preserves size for the 7x7 Gaussian
30 |
31 | return torch.index_select(hspace, dim=0, index=self.bin_ids.view(-1)).view_as(votes)
32 |
33 | def compute_bin_id(self, src_imsize, src_box, trg_box, hs_cellsize, nbins_x):
34 | r"""Computes Hough space bin ids for voting"""
35 | src_ptref = src_imsize.float()
36 | src_trans = Geometry.center(src_box)
37 | trg_trans = Geometry.center(trg_box)
38 | xy_vote = (src_ptref.unsqueeze(0).expand_as(src_trans) - src_trans).unsqueeze(2).\
39 | repeat(1, 1, len(trg_box)) + trg_trans.t().unsqueeze(0).repeat(len(src_box), 1, 1)
40 |
41 | bin_ids = (xy_vote / hs_cellsize).long()
42 |
43 | return bin_ids[:, 0, :] + bin_ids[:, 1, :] * nbins_x
44 |
45 | def build_hspace(self, src_imsize, trg_imsize, ncells):
46 | r"""Build Hough space"""
47 | hs_width = src_imsize[0] + trg_imsize[0]
48 | hs_height = src_imsize[1] + trg_imsize[1]
49 | hs_cellsize = math.sqrt((hs_width * hs_height) / ncells)
50 | nbins_x = int(hs_width / hs_cellsize) + 1
51 | nbins_y = int(hs_height / hs_cellsize) + 1
52 |
53 | return nbins_x, nbins_y, hs_cellsize
54 |
--------------------------------------------------------------------------------
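
For the default configuration (img_side = 240, ncells = 8192), the Hough-space geometry works out as follows; the arithmetic below simply reproduces build_hspace:

    import math

    hs_width = hs_height = 240 + 240                      # offset space spans both image sides
    hs_cellsize = math.sqrt(hs_width * hs_height / 8192)  # ~5.30 px per bin
    nbins = int(hs_width / hs_cellsize) + 1               # 91, so 91 * 91 = 8281 bins (~ncells)
    print(round(hs_cellsize, 2), nbins)                   # 5.3 91
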
/test.py:
--------------------------------------------------------------------------------
1 | r"""Dynamic Hyperpixel Flow testing code"""
2 | import argparse
3 |
4 | from torch.utils.data import DataLoader
5 | import torch
6 |
7 | from common.evaluation import Evaluator
8 | from common.logger import AverageMeter
9 | from common.logger import Logger
10 | from common import utils
11 | from model.base.geometry import Geometry
12 | from model import dhpf
13 | from data import download
14 |
15 |
16 | def test(model, dataloader):
17 | r"""Code for testing DHPF"""
18 | average_meter = AverageMeter(dataloader.dataset.benchmark)
19 |
20 | for idx, batch in enumerate(dataloader):
21 |
22 | # 1. DHPF forward pass
23 | correlation_matrix, layer_sel = model(batch['src_img'], batch['trg_img'])
24 |
25 | # 2. Transfer key-points (nearest neighbor assignment)
26 | prd_kps = Geometry.transfer_kps(correlation_matrix, batch['src_kps'], batch['n_pts'])
27 |
28 | # 3. Evaluate predictions
29 | eval_result = Evaluator.evaluate(prd_kps, batch)
30 | average_meter.update(eval_result, layer_sel.detach(), batch['category'])
31 | average_meter.write_process(idx, len(dataloader))
32 |
33 | # Write evaluation results
34 | Logger.visualize_selection(average_meter.sel_buffer)
35 | average_meter.write_result('Test')
36 |
37 |
38 | if __name__ == '__main__':
39 |
40 | # Arguments parsing
41 |     parser = argparse.ArgumentParser(description='Dynamic Hyperpixel Flow PyTorch Implementation')
42 | parser.add_argument('--datapath', type=str, default='../Datasets_DHPF')
43 | parser.add_argument('--backbone', type=str, default='resnet101', choices=['resnet50', 'resnet101'])
44 | parser.add_argument('--benchmark', type=str, default='pfpascal', choices=['pfpascal', 'pfwillow', 'caltech', 'spair'])
45 | parser.add_argument('--thres', type=str, default='auto', choices=['auto', 'img', 'bbox'])
46 | parser.add_argument('--alpha', type=float, default=0.1)
47 | parser.add_argument('--logpath', type=str, default='')
48 | parser.add_argument('--bsz', type=int, default=16)
49 | parser.add_argument('--load', type=str, default='')
50 | args = parser.parse_args()
51 | Logger.initialize(args)
52 | utils.fix_randseed(seed=0)
53 |
54 | # Model initialization
55 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
56 | model = dhpf.DynamicHPF(args.backbone, device)
57 | model.load_state_dict(torch.load(args.load))
58 | model.eval()
59 |
60 | # Dataset download & initialization
61 | download.download_dataset(args.datapath, args.benchmark)
62 | test_ds = download.load_dataset(args.benchmark, args.datapath, args.thres, device, 'test')
63 | test_dl = DataLoader(test_ds, batch_size=args.bsz, shuffle=False)
64 | Evaluator.initialize(args.benchmark, args.alpha)
65 |
66 | # Test DHPF
67 |     with torch.no_grad():
68 |         test(model, test_dl)
--------------------------------------------------------------------------------
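
For reference, a typical test invocation (the checkpoint path is a placeholder; --load must point to a trained model since its default is empty):

    python test.py --benchmark pfpascal --backbone resnet101 --load path/to/best_model.pt
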
/train.py:
--------------------------------------------------------------------------------
1 | r"""Dynamic Hyperpixel Flow training (validation) code"""
2 | import argparse
3 |
4 | from torch.utils.data import DataLoader
5 | import torch.optim as optim
6 | import torch
7 |
8 | from common.evaluation import Evaluator
9 | from common.logger import AverageMeter
10 | from common.logger import Logger
11 | from common import supervision as sup
12 | from common import utils
13 | from model.base.geometry import Geometry
14 | from model.objective import Objective
15 | from model import dhpf
16 | from data import download
17 |
18 |
19 | def train(epoch, model, dataloader, strategy, optimizer, training):
20 | r"""Code for training DHPF"""
21 | model.train() if training else model.eval()
22 | average_meter = AverageMeter(dataloader.dataset.benchmark)
23 |
24 | for idx, batch in enumerate(dataloader):
25 |
26 | # 1. DHPF forward pass
27 | src_img, trg_img = strategy.get_image_pair(batch, training)
28 | correlation_matrix, layer_sel = model(src_img, trg_img)
29 |
30 | # 2. Transfer key-points (nearest neighbor assignment)
31 | prd_kps = Geometry.transfer_kps(strategy.get_correlation(correlation_matrix), batch['src_kps'], batch['n_pts'])
32 |
33 | # 3. Evaluate predictions
34 | eval_result = Evaluator.evaluate(prd_kps, batch)
35 |
36 | # 4. Compute loss to update weights
37 | loss = strategy.compute_loss(correlation_matrix, eval_result, layer_sel, batch)
38 | if training:
39 | optimizer.zero_grad()
40 | loss.backward()
41 | optimizer.step()
42 | average_meter.update(eval_result, layer_sel.detach(), batch['category'], loss.item())
43 | average_meter.write_process(idx, len(dataloader), epoch)
44 |
45 | # Write evaluation results
46 | average_meter.write_result('Training' if training else 'Validation', epoch)
47 |
48 | avg_loss = utils.mean(average_meter.loss_buffer)
49 | avg_pck = utils.mean(average_meter.buffer['pck'])
50 | return avg_loss, avg_pck
51 |
52 |
53 | if __name__ == '__main__':
54 |
55 | # Arguments parsing
56 |     parser = argparse.ArgumentParser(description='Dynamic Hyperpixel Flow PyTorch Implementation')
57 | parser.add_argument('--datapath', type=str, default='../Datasets_DHPF')
58 | parser.add_argument('--backbone', type=str, default='resnet101', choices=['resnet50', 'resnet101'])
59 | parser.add_argument('--benchmark', type=str, default='pfpascal', choices=['pfpascal', 'spair'])
60 | parser.add_argument('--thres', type=str, default='auto', choices=['auto', 'img', 'bbox'])
61 | parser.add_argument('--supervision', type=str, default='strong', choices=['weak', 'strong'])
62 | parser.add_argument('--selection', type=float, default=0.5)
63 | parser.add_argument('--alpha', type=float, default=0.1)
64 | parser.add_argument('--logpath', type=str, default='')
65 | parser.add_argument('--lr', type=float, default=0.03)
66 | parser.add_argument('--niter', type=int, default=100)
67 | parser.add_argument('--bsz', type=int, default=8)
68 | args = parser.parse_args()
69 | Logger.initialize(args)
70 | utils.fix_randseed(seed=0)
71 |
72 | # Model initialization
73 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
74 | model = dhpf.DynamicHPF(args.backbone, device)
75 | Objective.initialize(args.selection, args.alpha)
76 | strategy = sup.WeakSupStrategy() if args.supervision == 'weak' else sup.StrongSupStrategy()
77 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.5)
78 |
79 | # Dataset download & initialization
80 | download.download_dataset(args.datapath, args.benchmark)
81 | trn_ds = download.load_dataset(args.benchmark, args.datapath, args.thres, device, 'trn')
82 | val_ds = download.load_dataset(args.benchmark, args.datapath, args.thres, device, 'val')
83 | trn_dl = DataLoader(trn_ds, batch_size=args.bsz, shuffle=True)
84 | val_dl = DataLoader(val_ds, batch_size=args.bsz, shuffle=False)
85 | Evaluator.initialize(args.benchmark, args.alpha)
86 |
87 | # Train DHPF
88 | best_val_pck = float('-inf')
89 | for epoch in range(args.niter):
90 |
91 | trn_loss, trn_pck = train(epoch, model, trn_dl, strategy, optimizer, training=True)
92 | with torch.no_grad():
93 | val_loss, val_pck = train(epoch, model, val_dl, strategy, optimizer, training=False)
94 |
95 | # Save the best model
96 | if val_pck > best_val_pck:
97 | best_val_pck = val_pck
98 | Logger.save_model(model, epoch, val_pck)
99 | Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss}, epoch)
100 | Logger.tbd_writer.add_scalars('data/pck', {'trn_pck': trn_pck, 'val_pck': val_pck}, epoch)
101 |
102 | Logger.tbd_writer.close()
103 | Logger.info('==================== Finished training ====================')
104 |
--------------------------------------------------------------------------------
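
And a typical training invocation (the flags shown simply spell out the defaults parsed above; logs and checkpoints are written via Logger under --logpath):

    python train.py --supervision strong --benchmark pfpascal --backbone resnet101 --selection 0.5 --lr 0.03 --bsz 8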