├── .gitignore
├── QUICKSTART.md
├── README.md
├── code
│   ├── learn
│   │   ├── backend.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── _init_paths.py
│   │   │   ├── building_ae.py
│   │   │   ├── building_corners.py
│   │   │   ├── building_corners_full.py
│   │   │   ├── building_full.py
│   │   │   ├── building_order_enum.py
│   │   │   └── data_util.py
│   │   ├── eval_order.py
│   │   ├── metrics
│   │   │   ├── get_metric.py
│   │   │   └── new_utils.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── _init_paths.py
│   │   │   ├── corner_models.py
│   │   │   ├── corner_to_edge.py
│   │   │   ├── dat_blocks.py
│   │   │   ├── deformable_transformer.py
│   │   │   ├── deformable_transformer_full.py
│   │   │   ├── deformable_transformer_original.py
│   │   │   ├── drn.py
│   │   │   ├── edge_full_models.py
│   │   │   ├── edge_models.py
│   │   │   ├── full_models.py
│   │   │   ├── loss.py
│   │   │   ├── mlp.py
│   │   │   ├── ops
│   │   │   │   ├── .gitignore
│   │   │   │   ├── functions
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── ms_deform_attn_func.py
│   │   │   │   ├── make.sh
│   │   │   │   ├── modules
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── ms_deform_attn.py
│   │   │   │   ├── setup.py
│   │   │   │   ├── src
│   │   │   │   │   ├── cpu
│   │   │   │   │   │   ├── ms_deform_attn_cpu.cpp
│   │   │   │   │   │   └── ms_deform_attn_cpu.h
│   │   │   │   │   ├── cuda
│   │   │   │   │   │   ├── ms_deform_attn_cuda.cu
│   │   │   │   │   │   ├── ms_deform_attn_cuda.h
│   │   │   │   │   │   └── ms_deform_im2col_cuda.cuh
│   │   │   │   │   ├── ms_deform_attn.h
│   │   │   │   │   └── vision.cpp
│   │   │   │   └── test.py
│   │   │   ├── order_class_models.py
│   │   │   ├── order_metric_models.py
│   │   │   ├── stacked_hg.py
│   │   │   ├── tcn.py
│   │   │   ├── unet.py
│   │   │   └── utils.py
│   │   ├── my_utils.py
│   │   ├── timer.py
│   │   ├── train_corner.py
│   │   ├── train_edge.py
│   │   ├── train_metric.py
│   │   ├── train_order_class.py
│   │   ├── train_order_metric.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── _init_paths.py
│   │       ├── geometry_utils.py
│   │       ├── misc.py
│   │       └── nn_utils.py
│   └── preprocess
│       └── data_gen.py
├── requirements.txt
├── resources
│   ├── addin_warning.png
│   ├── crop.gif
│   ├── rotate.gif
│   └── translate.gif
└── setup_env.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | ./data/
2 | ./ckpts/
3 | data
4 | data_small/
5 | ckpts
6 | *_full
7 | __pycache__/
8 | .vscode/
9 | *.swp
10 | *.zip
11 | fid_csv/
12 |
--------------------------------------------------------------------------------
/QUICKSTART.md:
--------------------------------------------------------------------------------
1 | # Quickstart
2 |
3 | ## Plugin installation (Windows machine)
4 |
5 | Download the quickstart package [here](https://www.dropbox.com/scl/fi/jljkehuddx3df6hf6ptau/quickstart.zip?rlkey=bzxi1b13r00s6u29drkziazgv&dl=0), and unzip it somewhere on your computer. The contents should look like this:
6 | ```
7 | .
8 | ├── ckpts/
9 | ├── code/
10 | ├── data/
11 | ├── plugins/
12 | ├── plugins_built/
13 | └── ...
14 | ```
15 |
16 | Install the Revit plugin by copying the content of `plugins_built/` to:
17 | ```
18 | %APPDATA%\Autodesk\Revit\Addins\2022
19 | ```
20 | You can navigate to that folder by copying the path above and pasting it into File Explorer's address bar. Change the year number at the end if you have a different Revit version.
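
If you prefer to script this copy step, here is a minimal sketch in Python (not part of the repository; it assumes the quickstart package was unzipped to the current directory and targets Revit 2022's default add-in folder):

```python
import os
import shutil

# Copy the built plugin files into the Revit 2022 add-ins folder.
src = "plugins_built"
dst = os.path.join(os.environ["APPDATA"], "Autodesk", "Revit", "Addins", "2022")

os.makedirs(dst, exist_ok=True)
for name in os.listdir(src):
    s = os.path.join(src, name)
    d = os.path.join(dst, name)
    if os.path.isdir(s):
        shutil.copytree(s, d, dirs_exist_ok=True)  # requires Python 3.8+
    else:
        shutil.copy2(s, d)
print(f"Copied plugin files to {dst}")
```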
21 |
22 | Start up Revit.
23 | The first time you launch it after installing the plugin, you should see the warning below.
24 | Click "Always Load" if you don't want to see this warning again.
25 | If you do not see the warning, then the plugin has not been installed correctly.
26 |
27 | ![Add-in load warning](resources/addin_warning.png)
28 |
29 | ## Server setup (Windows or Linux)
30 |
31 | We use Miniconda 3 to manage the Python environment, installed in the `$HOME` directory.
32 | Installation instructions for Miniconda can be found [here](https://docs.conda.io/projects/miniconda/en/latest/miniconda-install.html).
33 |
34 | Once you have Miniconda installed, run the provided script: `sh ./setup_env.sh`.
35 |
36 | ## Using our provided example point cloud
37 |
38 | If you would like to use your own point cloud, skip to the section below.
39 |
40 | Open up `data/revit_projects/32_ShortOffice_05_F2.rvt` inside Revit.
41 |
42 | Now you need to link the provided ReCap project, which contains the point cloud.
43 |
44 | From the top ribbon menu, click on `Insert -> Manage Links -> Point Clouds -> 32_ShortOffice... -> Reload From...`.
45 |
46 | Open `data/recap_projects/32_ShortOffice_05_F2/32_ShortOffice_05_F2_s0p01m.rcp`.
47 |
48 | Finally, click "OK" and you should see the point cloud.
49 |
50 | Switch to the server machine and run the following commands to start the server:
51 | ```
52 | cd code/learn
53 | conda activate bim
54 | python backend.py demo-floor
55 | ```
56 |
57 | You are now ready to use the assistive system.
58 |
59 | ## Using your own point cloud
60 |
61 | Using your own point cloud involves a few more steps, as the backend server needs to understand
62 | the relationship between its internal coordinate system and the Revit coordinate system.
63 |
64 | As prerequisites, your point cloud should be in LAZ format and also in a ReCap project.
65 | Please prepare the ReCap project before proceeding to the next step.
66 |
67 | Open up `data/revit_projects/example.rvt` inside Revit.
68 |
69 | Import the point cloud from the top ribbon menu: `Insert -> Point Clouds`.
70 |
71 | Transform the point cloud inside Revit such that the ground plane is at Level 1 and the majority of the walls are axis-aligned. Please see the GIFs below for how to do the transforms.
72 |
73 |
74 | Translate ground plane to Level 1
75 |
76 | ![Translate the point cloud so the ground plane sits at Level 1](resources/translate.gif)
77 |
78 |
79 |
80 |
81 |
82 | Rotate walls so they are axis-aligned
83 |
84 | ![Rotate the point cloud so the walls are axis-aligned](resources/rotate.gif)
85 |
86 |
87 |
88 |
89 |
90 | Use a section box to define the rough bounding box of the point cloud. The section box parameters
91 | are used by the backend server for processing later on.
92 |
93 | ![Use a section box to define the bounding box of the point cloud](resources/crop.gif)
94 |
95 | Finally, save out the transforms from the top ribbon menu: `Add-Ins -> Save Transform`. You should see a file named `transform.txt` inside the same folder as your Revit project. If your server machine is different from your client machine, upload the transform text file to somewhere on your server.
96 |
97 | Switch to the server machine and run the following commands to start the server:
98 | ```
99 | cd code/learn
100 | conda activate bim
101 | python backend.py demo-user \
102 | --floor-name <...> \
103 | --laz-f <...> \
104 | --laz-transform-f <...> \
105 | --corner-ckpt-f ../../ckpts/corner/11/checkpoint.pth \
106 | --edge-ckpt-f ../../ckpts/edge_sample_16/11/checkpoint.pth \
107 | --metric-ckpt-f ../../ckpts/order_metric/11/checkpoint_latest.pth
108 | ```
109 |
110 | Replace `<...>` with your own parameters. Wait until all preprocessing is done; it may take around 10 minutes depending on your computer hardware. Once it says "Server listening", you are ready to use the assistive system.
111 |
112 | ## Using the assistive system
113 |
114 | To make the plugin easier to use, you should set up some keyboard shortcuts.
115 |
116 | From the top-left, click on `File -> Options -> User Interface -> Keyboard Shortcuts: Customize -> Filter: Add-Ins Tab`.
117 | Bind the following commands to the corresponding keys:
118 |
119 | | Command | Shortcuts |
120 | |----------------:|:---|
121 | |Obtain Prediction| F2 |
122 | |Add Corner | F3 |
123 | |Send Corners | F4 |
124 | |Next / Accept Green | 4 |
125 | |Reject / Accept Yellow | 5 |
126 | |Accept Red | 6 |
127 |
128 | To enable assistance, first go to the `Add-Ins` tab and click on `Autocomplete`.
129 | The icon should change into a pause logo.
130 |
131 | Now manually draw one wall, hit the Escape key twice to exit the "Modify Wall" mode, and you should see the next three suggested walls in solid (next) and dashed (subsequent) pink lines.
132 |
133 | Run the `Next / Accept Green` command to accept the solid pink line as the next wall to add.
134 | You may interleave manual drawing and accept commands however you like.
135 |
136 | You may also run `Reject / Accept Yellow` to choose one of three candidate walls to add next.
137 | Run the corresponding command to accept the colored suggestion.
138 |
139 | To simplify wall drawing, you may also provide wall junctions and query the backend to automatically infer the relevant walls.
140 | This has the benefit of adding multiple walls at once, especially around higher-degree junctions.
141 |
142 | To do so:
143 | 1. Hover the mouse over the junction in the point cloud, and hit the `Add Corner` shortcut.
144 | 2. (Optional) Drag the ring to modify its location.
145 | 3. Once the desired junctions are added, hit `Send Corners`.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A-Scan2BIM
2 |
3 | Official implementation of the paper [A-Scan2BIM: Assistive Scan to Building Information Modeling](https://drive.google.com/file/d/1zvGfdlLYbd_oAp7Oc-1vF2Czhl1A7q75/view) (__BMVC 2023, oral__)
4 |
5 | Please also visit our [project website](https://a-scan2bim.github.io/) for a video demo of our assistive system.
6 |
7 | # Updates
8 |
9 | [06/14/2024]
10 | Updated our assistive demo so it accepts user-provided point clouds. Please visit [here](QUICKSTART.md) for instructions.
11 |
12 | [12/08/2023]
13 | We are unable to release data for two floors, as the point clouds are unfortunately not available for download.
14 |
15 | [11/30/2023]
16 | Due to a bug in our heuristic baseline method, we have updated our order evaluations. Please see the updated arXiv paper for the latest results.
17 |
18 |
19 | # Table of contents (also TODOs)
20 |
21 | - [x] [Prerequisites](#prerequisites)
22 | - [x] [Quickstart](#quickstart)
23 | - [x] [Training and evaluation](#training-and-evaluation)
24 | - [ ] [Collecting your own data](#collecting-your-own-data)
25 | - [ ] [Building and developing the plugin](#building-and-developing-the-plugin)
26 | - [x] [Contact](#contact)
27 | - [x] [Bibtex](#bibtex)
28 | - [x] [Acknowledgment](#acknowledgment)
29 |
30 | # Prerequisites
31 |
32 | For more flexibility, our system is designed as a client-server model: a Revit plugin written in C# serves as the front end, and a Python server performs all the neural network computation.
33 |
34 | To run our system, there are two options:
35 |
36 | 1. Run both the plugin and the server on the same machine
37 | 2. Run the server on a separate machine, and port-forward to client via local network or ssh
38 |
39 | The advantage of option 2 over option 1 is that the server can run on either Linux or Windows machines, and can potentially serve multiple clients at the same time.
40 | Of course, the client machine needs to be Windows as Revit is Windows-only.
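
Once the server is running (see the sections below for how to start it), a quick way to confirm the client machine can reach it is a short TCP check like the sketch below. The host and port are placeholders, not values from this repository; use whatever address and port your server actually listens on.

```python
import socket

# Placeholder address/port: replace with your server's actual values.
HOST, PORT = "192.168.1.50", 8000

try:
    with socket.create_connection((HOST, PORT), timeout=5):
        print(f"Backend reachable at {HOST}:{PORT}")
except OSError as exc:
    print(f"Could not reach {HOST}:{PORT}: {exc}")
```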
41 |
42 | The code has been tested with Revit 2022.
43 | Other versions will likely work but are not verified.
44 | If you are a student or an educator, you can get it for free [here](https://www.autodesk.com/education/edu-software/overview?sorting=featured&filters=individual).
45 |
46 | # Quickstart
47 |
48 | To try out our assistive system, please visit [here](QUICKSTART.md) for instructions.
49 |
50 | # Training and evaluation
51 |
52 | At a high level, our system consists of two components:
53 | a candidate wall enumeration network,
54 | and a next wall prediction network.
55 | They are trained stand-alone, and we describe the training process step-by-step below.
56 |
57 | ## Data/environment preparation
58 |
59 | First, fill out the data request form [here](https://forms.gle/Apg86MauTep2KTxx8).
60 | We will be in contact with you to provide the data download links.
61 | Once you have the data, unzip it in the root directory of the repository.
62 |
63 | Since we do not own the point clouds, please download them from the [workshop website](https://cv4aec.github.io/). You will need the data from the [3D challenge](https://codalab.lisn.upsaclay.fr/competitions/12405), both the train and test data. Rename and move all the LAZ files into the data folder as shown below (a small check script follows the listing):
64 |
65 | ```
66 | data/
67 | ├── history/
68 | ├── transforms/
69 | ├── all_floors.txt
70 | └── laz/
71 | ├── 05_MedOffice_01_F2.laz
72 | ├── 06_MedOffice_02_F1.laz
73 | ├── 07_MedOffice_03_F3.laz
74 | ├── 08_ShortOffice_01_F2.laz
75 | ├── 11_MedOffice_05_F2.laz
76 | ├── 11_MedOffice_05_F4.laz
77 | ├── 19_MedOffice_07_F4.laz
78 | ├── 25_Parking_01_F1.laz
79 | ├── 32_ShortOffice_05_F1.laz
80 | ├── 32_ShortOffice_05_F2.laz
81 | ├── 32_ShortOffice_05_F3.laz
82 | ├── 33_SmallBuilding_03_F1.laz
83 | ├── 35_Lab_02_F1.laz
84 | └── 35_Lab_02_F2.laz
85 | ```
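
To catch naming mistakes early, a small check like the sketch below (not part of the repository) can verify that `data/laz/` matches the listing above:

```python
from pathlib import Path

# File names taken from the listing above.
expected = [
    "05_MedOffice_01_F2", "06_MedOffice_02_F1", "07_MedOffice_03_F3",
    "08_ShortOffice_01_F2", "11_MedOffice_05_F2", "11_MedOffice_05_F4",
    "19_MedOffice_07_F4", "25_Parking_01_F1", "32_ShortOffice_05_F1",
    "32_ShortOffice_05_F2", "32_ShortOffice_05_F3", "33_SmallBuilding_03_F1",
    "35_Lab_02_F1", "35_Lab_02_F2",
]

laz_dir = Path("data/laz")
missing = [name for name in expected if not (laz_dir / f"{name}.laz").is_file()]
print("All LAZ files found." if not missing else f"Missing: {missing}")
```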
86 |
87 | To install all Python dependencies, run the provided script: `sh ./setup_env.sh`.
88 |
89 | Then run the following command to preprocess the data:
90 | ```
91 | cd code/preprocess
92 | python data_gen.py
93 | ```
94 |
95 | Please also download the pretrained models from [here](https://www.dropbox.com/scl/fi/cwhgu92a6ndl212nls59i/ckpts_full.zip?rlkey=pabethcn0w0rxqk0k5x1da0dv&dl=0) and extract them to the root directory of the repository.
96 | Some weights are necessary for network initialization (`pretrained/`) or evaluation (`ae/`); you may delete the others if training from scratch.
97 |
98 | ## Training candidate wall enumerator
99 |
100 | We borrow the HEAT architecture, which consists of two components: a corner detector and an edge classifier.
101 | However, in our case we do not train end-to-end due to the large input size.
102 |
103 | ```
104 | cd code/learn
105 |
106 | # Step 1: train corner detector
107 | for i in {0..15}
108 | do
109 | python train_corner.py --test_idx $i
110 | done
111 |
112 | # Step 2: cache detected corners to disk
113 | python backend.py export_corners
114 |
115 | # Step 3: train edge classifier
116 | for i in {0..15}
117 | do
118 | python train_edge.py --test_idx $i
119 | done
120 |
121 | # Step 4: cache detected edges to disk
122 | python backend.py save_edge_preds
123 | ```
124 |
125 | Note that we do leave-one-out (per-floor) cross-validation, hence the for loops in steps 1 and 3.
126 |
127 | ## Training next wall predictor
128 |
129 | Our next wall predictor and the classifier baseline method are both trained on GT walls and order.
130 |
131 | Training is initialized from pretrained weights of the candidate wall enumeration task above (the variant where only one reference point is used).
132 |
133 | We have provided the pretrained weights for your convenience (see above for the link to pretrained checkpoints), though they may not be necessary if you are training for your own task.
134 |
135 | ```
136 | cd code/learn
137 |
138 | # Training our method (metric-learning based)
139 | for i in {0..15}
140 | do
141 | python train_order_metric.py --test_idx $i
142 | done
143 |
144 | # Training classifier baseline
145 | for i in {0..15}
146 | do
147 | python train_order_class.py --test_idx $i
148 | done
149 | ```
150 |
151 | ## Evaluation
152 |
153 | To evaluate reconstruction metrics:
154 | ```
155 | python backend.py compute-metrics
156 | ```
157 |
158 | To evaluate order metrics:
159 | ```
160 | python eval_order.py eval-all-floors
161 | python eval_order.py plot-seq-FID
162 | ```
163 |
164 | To evaluate entropy and accuracy of next wall prediction:
165 | ```
166 | python eval_order.py eval-entropy
167 | python eval_order.py eval-all-acc-wrt-history
168 | ```
169 |
170 | # Collecting your own data
171 |
172 | Coming soon...
173 |
174 | # Building and developing the plugin
175 |
176 | Coming soon...
177 |
178 | # Contact
179 |
180 | Weilian Song, weilians@sfu.ca
181 |
182 | # Bibtex
183 | ```
184 | @article{weilian2023ascan2bim,
185 | author = {Song, Weilian and Luo, Jieliang and Zhao, Dale and Fu, Yan and Cheng, Chin-Yi and Furukawa, Yasutaka},
186 | title = {A-Scan2BIM: Assistive Scan to Building Information Modeling},
187 | journal = {British Machine Vision Conference (BMVC)},
188 | year = {2023},
189 | }
190 | ```
191 |
192 | # Acknowledgment
193 |
194 | This research is partially supported by NSERC Discovery Grants with Accelerator Supplements and the DND/NSERC Discovery Grant Supplement, NSERC Alliance Grants, and the John R. Evans Leaders Fund (JELF).
195 |
196 | We are also grateful to the [CV4AEC CVPR workshop](https://cv4aec.github.io/) for providing the point clouds.
197 |
198 | And finally, much of the plugin code was borrowed from Jeremy Tammik's sample add-in [WinTooltip](https://github.com/jeremytammik/WinTooltip).
199 | Hats off :tophat: to Jeremy for all the wonderful tutorials he has written over the years.
--------------------------------------------------------------------------------
/code/learn/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/code/learn/datasets/__init__.py
--------------------------------------------------------------------------------
/code/learn/datasets/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 |
5 | def add_path(path):
6 | if path not in sys.path: sys.path.insert(0, path)
7 |
8 |
9 | this_dir = osp.dirname(__file__)
10 |
11 | project_path = osp.abspath(osp.join(this_dir, '..'))
12 | add_path(project_path)
--------------------------------------------------------------------------------
/code/learn/datasets/data_util.py:
--------------------------------------------------------------------------------
1 | from PIL import ImageFilter
2 | from torchvision import transforms
3 |
4 |
5 | def RandomBlur(radius=2.):
6 | blur = GaussianBlur(radius=radius)
7 | full_transform = transforms.RandomApply([blur], p=.3)
8 | return full_transform
9 |
10 |
11 | class ImageFilterTransform(object):
12 |
13 | def __init__(self):
14 | raise NotImplementedError
15 |
16 | def __call__(self, img):
17 | return img.filter(self.filter)
18 |
19 |
20 | class GaussianBlur(ImageFilterTransform):
21 |
22 | def __init__(self, radius=2.):
23 | self.filter = ImageFilter.GaussianBlur(radius=radius)
24 |
--------------------------------------------------------------------------------
/code/learn/metrics/get_metric.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 | import cv2
5 | from metrics.new_utils import *
6 |
7 |
8 | class Metric:
9 | def calc(self, gt_data, conv_data, thresh=8.0, iou_thresh=0.7):
10 | ### compute corners precision/recall
11 | gts = gt_data["corners"]
12 | dets = conv_data["corners"]
13 |
14 | per_sample_corner_tp = 0.0
15 | per_sample_corner_fp = 0.0
16 | per_sample_corner_length = gts.shape[0]
17 | found = [False] * gts.shape[0]
18 | c_det_annot = {}
19 |
20 | # for each corner detection
21 | for i, det in enumerate(dets):
22 | # get closest gt
23 | near_gt = [0, 999999.0, (0.0, 0.0)]
24 | for k, gt in enumerate(gts):
25 | dist = np.linalg.norm(gt - det)
26 | if dist < near_gt[1]:
27 | near_gt = [k, dist, gt]
28 | if near_gt[1] <= thresh and not found[near_gt[0]]:
29 | per_sample_corner_tp += 1.0
30 | found[near_gt[0]] = True
31 | c_det_annot[i] = near_gt[0]
32 | else:
33 | per_sample_corner_fp += 1.0
34 |
35 | per_corner_score = {
36 | "recall": per_sample_corner_tp / gts.shape[0],
37 | "precision": per_sample_corner_tp
38 | / (per_sample_corner_tp + per_sample_corner_fp + 1e-8),
39 | }
40 |
41 | ### compute edges precision/recall
42 | per_sample_edge_tp = 0.0
43 | per_sample_edge_fp = 0.0
44 | edge_corner_annots = gt_data["edges"]
45 | per_sample_edge_length = edge_corner_annots.shape[0]
46 |
47 | false_edge_ids = []
48 | match_pred_ids = set()
49 |
50 | for l, e_det in enumerate(conv_data["edges"]):
51 | c1, c2 = e_det
52 |
53 | # check if corners are mapped
54 | if (c1 not in c_det_annot.keys()) or (c2 not in c_det_annot.keys()):
55 | per_sample_edge_fp += 1.0
56 | false_edge_ids.append(l)
57 | continue
58 | # check hit
59 | c1_prime = c_det_annot[c1]
60 | c2_prime = c_det_annot[c2]
61 | is_hit = False
62 |
63 | for k, e_annot in enumerate(edge_corner_annots):
64 | c3, c4 = e_annot
65 | if ((c1_prime == c3) and (c2_prime == c4)) or (
66 | (c1_prime == c4) and (c2_prime == c3)
67 | ):
68 | is_hit = True
69 | match_pred_ids.add(l)
70 | break
71 |
72 | # hit
73 | if is_hit:
74 | per_sample_edge_tp += 1.0
75 | else:
76 | per_sample_edge_fp += 1.0
77 | false_edge_ids.append(l)
78 |
79 | per_edge_score = {
80 | "recall": per_sample_edge_tp / edge_corner_annots.shape[0],
81 | "precision": per_sample_edge_tp
82 | / (per_sample_edge_tp + per_sample_edge_fp + 1e-8),
83 | }
84 |
85 | # computer regions precision/recall
86 | conv_mask = render(
87 | corners=conv_data["corners"],
88 | edges=conv_data["edges"],
89 | render_pad=0,
90 | edge_linewidth=1,
91 | )[0]
92 | conv_mask = 1 - conv_mask
93 | conv_mask = conv_mask.astype(np.uint8)
94 | labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)
95 |
96 | # cv2.imwrite('mask-pred.png', region_mask.astype(np.uint8) * 20)
97 |
98 | background_label = region_mask[0, 0]
99 | all_conv_masks = []
100 | for region_i in range(1, labels):
101 | if region_i == background_label:
102 | continue
103 | the_region = region_mask == region_i
104 | if the_region.sum() < 20:
105 | continue
106 | all_conv_masks.append(the_region)
107 |
108 | gt_mask = render(
109 | corners=gt_data["corners"],
110 | edges=gt_data["edges"],
111 | render_pad=0,
112 | edge_linewidth=1,
113 | )[0]
114 | gt_mask = 1 - gt_mask
115 | gt_mask = gt_mask.astype(np.uint8)
116 | labels, region_mask = cv2.connectedComponents(gt_mask, connectivity=4)
117 |
118 | # cv2.imwrite('mask-gt.png', region_mask.astype(np.uint8) * 20)
119 |
120 | background_label = region_mask[0, 0]
121 | all_gt_masks = []
122 | for region_i in range(1, labels):
123 | if region_i == background_label:
124 | continue
125 | the_region = region_mask == region_i
126 | if the_region.sum() < 20:
127 | continue
128 | all_gt_masks.append(the_region)
129 |
130 | per_sample_region_tp = 0.0
131 | per_sample_region_fp = 0.0
132 | per_sample_region_length = len(all_gt_masks)
133 | found = [False] * len(all_gt_masks)
134 | for i, r_det in enumerate(all_conv_masks):
135 | # gt closest gt
136 | near_gt = [0, 0, None]
137 | for k, r_gt in enumerate(all_gt_masks):
138 | iou = np.logical_and(r_gt, r_det).sum() / float(
139 | np.logical_or(r_gt, r_det).sum()
140 | )
141 | if iou > near_gt[1]:
142 | near_gt = [k, iou, r_gt]
143 | if near_gt[1] >= iou_thresh and not found[near_gt[0]]:
144 | per_sample_region_tp += 1.0
145 | found[near_gt[0]] = True
146 | else:
147 | per_sample_region_fp += 1.0
148 |
149 | # per_region_score = {
150 | # 'recall': per_sample_region_tp / len(all_gt_masks),
151 | # 'precision': per_sample_region_tp / (per_sample_region_tp + per_sample_region_fp + 1e-8)
152 | # }
153 | per_region_score = {"recall": 1.0, "precision": 1.0}
154 |
155 | return {
156 | "corner_tp": per_sample_corner_tp,
157 | "corner_fp": per_sample_corner_fp,
158 | "corner_length": per_sample_corner_length,
159 | "edge_tp": per_sample_edge_tp,
160 | "edge_fp": per_sample_edge_fp,
161 | "edge_length": per_sample_edge_length,
162 | "region_tp": per_sample_region_tp,
163 | "region_fp": per_sample_region_fp,
164 | "region_length": per_sample_region_length,
165 | "corner": per_corner_score,
166 | "edge": per_edge_score,
167 | "region": per_region_score,
168 | # for visualizing edges
169 | "matched_pred_edge_ids": list(match_pred_ids),
170 | "false_pred_edge_ids": false_edge_ids,
171 | }
172 |
173 | def calc_corner(self, gt_data, conv_data, thresh=8.0, iou_thresh=0.7):
174 | ### compute corners precision/recall
175 | gts = gt_data["corners"]
176 | dets = conv_data["corners"]
177 |
178 | per_sample_corner_tp = 0.0
179 | per_sample_corner_fp = 0.0
180 | per_sample_corner_length = gts.shape[0]
181 | found = [False] * gts.shape[0]
182 | c_det_annot = {}
183 |
184 | # for each corner detection
185 | for i, det in enumerate(dets):
186 | # get closest gt
187 | near_gt = [0, 999999.0, (0.0, 0.0)]
188 | for k, gt in enumerate(gts):
189 | dist = np.linalg.norm(gt - det)
190 | if dist < near_gt[1]:
191 | near_gt = [k, dist, gt]
192 | if near_gt[1] <= thresh and not found[near_gt[0]]:
193 | per_sample_corner_tp += 1.0
194 | found[near_gt[0]] = True
195 | c_det_annot[i] = near_gt[0]
196 | else:
197 | per_sample_corner_fp += 1.0
198 |
199 | per_corner_score = {
200 | "recall": per_sample_corner_tp / gts.shape[0],
201 | "precision": per_sample_corner_tp
202 | / (per_sample_corner_tp + per_sample_corner_fp + 1e-8),
203 | }
204 |
205 | return {
206 | "corner_tp": per_sample_corner_tp,
207 | "corner_fp": per_sample_corner_fp,
208 | "corner_length": per_sample_corner_length,
209 | "corner": per_corner_score,
210 | }
211 |
212 |
213 | def compute_metrics(gt_data, pred_data, thresh=8.0):
214 | metric = Metric()
215 | score = metric.calc(gt_data, pred_data, thresh=thresh)
216 | return score
217 |
218 |
219 | def get_recall_and_precision(tp, fp, length):
220 | recall = tp / (length + 1e-8)
221 | precision = tp / (tp + fp + 1e-8)
222 | return recall, precision
223 |
224 |
225 | if __name__ == "__main__":
226 | base_path = "./"
227 | gt_datapath = "../data/cities_dataset/annot"
228 | metric = Metric()
229 | corner_tp = 0.0
230 | corner_fp = 0.0
231 | corner_length = 0.0
232 | edge_tp = 0.0
233 | edge_fp = 0.0
234 | edge_length = 0.0
235 | region_tp = 0.0
236 | region_fp = 0.0
237 | region_length = 0.0
238 | for file_name in os.listdir(base_path):
239 | if len(file_name) < 10:
240 | continue
241 | f = open(os.path.join(base_path, file_name), "rb")
242 | gt_data = np.load(
243 | os.path.join(gt_datapath, file_name + ".npy"), allow_pickle=True
244 | ).tolist()
245 | candidate = pickle.load(f)
246 | conv_corners = candidate.graph.getCornersArray()
247 | conv_edges = candidate.graph.getEdgesArray()
248 | conv_data = {"corners": conv_corners, "edges": conv_edges}
249 | score = metric.calc(gt_data, conv_data)
250 | corner_tp += score["corner_tp"]
251 | corner_fp += score["corner_fp"]
252 | corner_length += score["corner_length"]
253 | edge_tp += score["edge_tp"]
254 | edge_fp += score["edge_fp"]
255 | edge_length += score["edge_length"]
256 | region_tp += score["region_tp"]
257 | region_fp += score["region_fp"]
258 | region_length += score["region_length"]
259 |
260 | f = open(os.path.join(base_path, "score.txt"), "w")
261 | # corner
262 | recall, precision = get_recall_and_precision(corner_tp, corner_fp, corner_length)
263 | f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
264 | print(
265 | "corners - precision: %.3f recall: %.3f f_score: %.3f"
266 | % (precision, recall, f_score)
267 | )
268 | f.write(
269 | "corners - precision: %.3f recall: %.3f f_score: %.3f\n"
270 | % (precision, recall, f_score)
271 | )
272 |
273 | # edge
274 | recall, precision = get_recall_and_precision(edge_tp, edge_fp, edge_length)
275 | f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
276 | print(
277 | "edges - precision: %.3f recall: %.3f f_score: %.3f"
278 | % (precision, recall, f_score)
279 | )
280 | f.write(
281 | "edges - precision: %.3f recall: %.3f f_score: %.3f\n"
282 | % (precision, recall, f_score)
283 | )
284 |
285 | # region
286 | recall, precision = get_recall_and_precision(region_tp, region_fp, region_length)
287 | f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
288 | print(
289 | "regions - precision: %.3f recall: %.3f f_score: %.3f"
290 | % (precision, recall, f_score)
291 | )
292 | f.write(
293 | "regions - precision: %.3f recall: %.3f f_score: %.3f\n"
294 | % (precision, recall, f_score)
295 | )
296 |
297 | f.close()
298 |
--------------------------------------------------------------------------------
/code/learn/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/code/learn/models/__init__.py
--------------------------------------------------------------------------------
/code/learn/models/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 |
5 | def add_path(path):
6 | if path not in sys.path: sys.path.insert(0, path)
7 |
8 |
9 | this_dir = osp.dirname(__file__)
10 |
11 | project_path = osp.abspath(osp.join(this_dir, '..'))
12 | add_path(project_path)
--------------------------------------------------------------------------------
/code/learn/models/corner_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import math
6 | from models.deformable_transformer_original import (
7 | DeformableTransformerEncoderLayer,
8 | DeformableTransformerEncoder,
9 | DeformableTransformerDecoder,
10 | DeformableTransformerDecoderLayer,
11 | DeformableAttnDecoderLayer,
12 | )
13 | from models.ops.modules import MSDeformAttn
14 | from models.mlp import MLP
15 | from models.unet import convrelu
16 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
17 | from einops import rearrange, repeat
18 | from einops.layers.torch import Rearrange
19 | from utils.misc import NestedTensor
20 | import scipy.ndimage.filters as filters
21 |
22 |
23 | class CornerEnum(nn.Module):
24 | def __init__(
25 | self,
26 | input_dim,
27 | hidden_dim,
28 | num_feature_levels,
29 | backbone_strides,
30 | backbone_num_channels,
31 | ):
32 | super(CornerEnum, self).__init__()
33 | self.input_dim = input_dim
34 | self.hidden_dim = hidden_dim
35 | self.num_feature_levels = num_feature_levels
36 |
37 | if num_feature_levels > 1:
38 | num_backbone_outs = len(backbone_strides)
39 | input_proj_list = []
40 | for _ in range(num_backbone_outs):
41 | in_channels = backbone_num_channels[_]
42 | input_proj_list.append(
43 | nn.Sequential(
44 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
45 | nn.GroupNorm(32, hidden_dim),
46 | )
47 | )
48 | for _ in range(num_feature_levels - num_backbone_outs):
49 | input_proj_list.append(
50 | nn.Sequential(
51 | nn.Conv2d(
52 | in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
53 | ),
54 | nn.GroupNorm(32, hidden_dim),
55 | )
56 | )
57 | in_channels = hidden_dim
58 | self.input_proj = nn.ModuleList(input_proj_list)
59 | else:
60 | self.input_proj = nn.ModuleList(
61 | [
62 | nn.Sequential(
63 | nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
64 | nn.GroupNorm(32, hidden_dim),
65 | )
66 | ]
67 | )
68 |
69 | self.patch_size = 4
70 | patch_dim = (self.patch_size**2) * input_dim
71 | self.to_patch_embedding = nn.Sequential(
72 | Rearrange(
73 | "b (h p1) (w p2) c -> b (h w) (p1 p2 c)",
74 | p1=self.patch_size,
75 | p2=self.patch_size,
76 | ),
77 | nn.Linear(patch_dim, input_dim),
78 | nn.Linear(input_dim, hidden_dim),
79 | )
80 |
81 | self.pixel_pe_fc = nn.Linear(input_dim, hidden_dim)
82 | self.transformer = CornerTransformer(
83 | d_model=hidden_dim,
84 | nhead=8,
85 | num_encoder_layers=1,
86 | dim_feedforward=1024,
87 | dropout=0.1,
88 | )
89 |
90 | self.img_pos = PositionEmbeddingSine(hidden_dim // 2)
91 |
92 | @staticmethod
93 | def get_ms_feat(xs, img_mask):
94 |         out = {}  # Dict[str, NestedTensor]
95 | # out = list()
96 | for name, x in sorted(xs.items()):
97 | m = img_mask
98 | assert m is not None
99 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
100 | out[name] = NestedTensor(x, mask)
101 | # out.append(NestedTensor(x, mask))
102 | return out
103 |
104 | @staticmethod
105 | def get_decoder_reference_points(height, width, device):
106 | ref_y, ref_x = torch.meshgrid(
107 | torch.linspace(
108 | 0.5, height - 0.5, height, dtype=torch.float32, device=device
109 | ),
110 | torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
111 | )
112 | ref_y = ref_y.reshape(-1)[None] / height
113 | ref_x = ref_x.reshape(-1)[None] / width
114 | ref = torch.stack((ref_x, ref_y), -1)
115 | return ref
116 |
117 | def forward(self, image_feats, feat_mask, pixels_feat, pixels, all_image_feats):
118 | # process image features
119 | features = self.get_ms_feat(image_feats, feat_mask)
120 |
121 | srcs = []
122 | masks = []
123 | all_pos = []
124 |
125 | new_features = list()
126 | for name, x in sorted(features.items()):
127 | new_features.append(x)
128 | features = new_features
129 |
130 | for l, feat in enumerate(features):
131 | src, mask = feat.decompose()
132 | mask = mask.to(src.device)
133 | srcs.append(self.input_proj[l](src))
134 | pos = self.img_pos(src).to(src.dtype)
135 | all_pos.append(pos)
136 | masks.append(mask)
137 | assert mask is not None
138 |
139 | if self.num_feature_levels > len(srcs):
140 | _len_srcs = len(srcs)
141 | for l in range(_len_srcs, self.num_feature_levels):
142 | if l == _len_srcs:
143 | src = self.input_proj[l](features[-1].tensors)
144 | else:
145 | src = self.input_proj[l](srcs[-1])
146 | m = feat_mask
147 | mask = (
148 | F.interpolate(m[None].float(), size=src.shape[-2:])
149 | .to(torch.bool)[0]
150 | .to(src.device)
151 | )
152 | pos_l = self.img_pos(src).to(src.dtype)
153 | srcs.append(src)
154 | masks.append(mask)
155 | all_pos.append(pos_l)
156 |
157 | sp_inputs = self.to_patch_embedding(pixels_feat)
158 |
159 | # compute the reference points
160 | H_tgt = W_tgt = int(np.sqrt(sp_inputs.shape[1]))
161 | reference_points_s1 = self.get_decoder_reference_points(
162 | H_tgt, W_tgt, sp_inputs.device
163 | )
164 | rp_all_pixels = self.get_decoder_reference_points(256, 256, sp_inputs.device)
165 |
166 | corner_logits = self.transformer(
167 | srcs, masks, all_pos, sp_inputs, reference_points_s1, all_image_feats
168 | )
169 | return corner_logits
170 |
171 |
172 | class PositionEmbeddingSine(nn.Module):
173 | """
174 | This is a more standard version of the position embedding, very similar to the one
175 | used by the Attention is all you need paper, generalized to work on images.
176 | """
177 |
178 | def __init__(
179 | self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
180 | ):
181 | super().__init__()
182 | self.num_pos_feats = num_pos_feats
183 | self.temperature = temperature
184 | self.normalize = normalize
185 | if scale is not None and normalize is False:
186 | raise ValueError("normalize should be True if scale is passed")
187 | if scale is None:
188 | scale = 2 * math.pi
189 | self.scale = scale
190 |
191 | def forward(self, x):
192 | mask = torch.zeros([x.shape[0], x.shape[2], x.shape[3]]).bool().to(x.device)
193 | not_mask = ~mask
194 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
195 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
196 | if self.normalize:
197 | eps = 1e-6
198 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
199 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
200 |
201 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
202 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
203 |
204 | pos_x = x_embed[:, :, :, None] / dim_t
205 | pos_y = y_embed[:, :, :, None] / dim_t
206 | pos_x = torch.stack(
207 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
208 | ).flatten(3)
209 | pos_y = torch.stack(
210 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
211 | ).flatten(3)
212 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
213 | return pos
214 |
215 |
216 | class CornerTransformer(nn.Module):
217 | def __init__(
218 | self,
219 | d_model=512,
220 | nhead=8,
221 | num_encoder_layers=6,
222 | dim_feedforward=1024,
223 | dropout=0.1,
224 | activation="relu",
225 | return_intermediate_dec=False,
226 | num_feature_levels=4,
227 | dec_n_points=4,
228 | enc_n_points=4,
229 | ):
230 | super(CornerTransformer, self).__init__()
231 |
232 | encoder_layer = DeformableTransformerEncoderLayer(
233 | d_model,
234 | dim_feedforward,
235 | dropout,
236 | activation,
237 | num_feature_levels,
238 | nhead,
239 | enc_n_points,
240 | )
241 | self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
242 |
243 | decoder_attn_layer = DeformableAttnDecoderLayer(
244 | d_model,
245 | dim_feedforward,
246 | dropout,
247 | activation,
248 | num_feature_levels,
249 | nhead,
250 | dec_n_points,
251 | )
252 | self.per_edge_decoder = DeformableTransformerDecoder(
253 | decoder_attn_layer, 1, False, with_sa=False
254 | )
255 |
256 | self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
257 |
258 | # upconv layers
259 | self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
260 | self.conv_up1 = convrelu(256 + 256, 256, 3, 1)
261 | self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
262 | self.conv_original_size2 = convrelu(64 + 128, d_model, 3, 1)
263 | self.output_fc_1 = nn.Linear(d_model, 1)
264 | self.output_fc_2 = nn.Linear(d_model, 1)
265 |
266 | self._reset_parameters()
267 |
268 | def _reset_parameters(self):
269 | for p in self.parameters():
270 | if p.dim() > 1:
271 | nn.init.xavier_uniform_(p)
272 | for m in self.modules():
273 | if isinstance(m, MSDeformAttn):
274 | m._reset_parameters()
275 | # xavier_uniform_(self.reference_points.weight.data, gain=1.0)
276 | # constant_(self.reference_points.bias.data, 0.)
277 | normal_(self.level_embed)
278 |
279 | def get_valid_ratio(self, mask):
280 | _, H, W = mask.shape
281 | valid_H = torch.sum(~mask[:, :, 0], 1)
282 | valid_W = torch.sum(~mask[:, 0, :], 1)
283 | valid_ratio_h = valid_H.float() / H
284 | valid_ratio_w = valid_W.float() / W
285 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
286 | return valid_ratio
287 |
288 | def forward(
289 | self, srcs, masks, pos_embeds, query_embed, reference_points, all_image_feats
290 | ):
291 | # prepare input for encoder
292 | src_flatten = []
293 | mask_flatten = []
294 | lvl_pos_embed_flatten = []
295 | spatial_shapes = []
296 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
297 | bs, c, h, w = src.shape
298 | spatial_shape = (h, w)
299 | spatial_shapes.append(spatial_shape)
300 | src = src.flatten(2).transpose(1, 2)
301 | mask = mask.flatten(1)
302 | pos_embed = pos_embed.flatten(2).transpose(1, 2)
303 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
304 | lvl_pos_embed_flatten.append(lvl_pos_embed)
305 | src_flatten.append(src)
306 | mask_flatten.append(mask)
307 | src_flatten = torch.cat(src_flatten, 1)
308 | mask_flatten = torch.cat(mask_flatten, 1)
309 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
310 | spatial_shapes = torch.as_tensor(
311 | spatial_shapes, dtype=torch.long, device=src_flatten.device
312 | )
313 | level_start_index = torch.cat(
314 | (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
315 | )
316 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
317 |
318 | # encoder
319 | memory = self.encoder(
320 | src_flatten,
321 | spatial_shapes,
322 | level_start_index,
323 | valid_ratios,
324 | lvl_pos_embed_flatten,
325 | mask_flatten,
326 | )
327 |
328 | # prepare input for decoder
329 | bs, _, c = memory.shape
330 |
331 | tgt = query_embed
332 |
333 | # relational decoder
334 | hs_pixels_s1, _ = self.per_edge_decoder(
335 | tgt,
336 | reference_points,
337 | memory,
338 | spatial_shapes,
339 | level_start_index,
340 | valid_ratios,
341 | query_embed,
342 | mask_flatten,
343 | )
344 |
345 | feats_s1, preds_s1 = self.generate_corner_preds(hs_pixels_s1, all_image_feats)
346 |
347 | return preds_s1
348 |
349 | def generate_corner_preds(self, outputs, conv_outputs):
350 | B, L, C = outputs.shape
351 | side = int(np.sqrt(L))
352 | outputs = outputs.view(B, side, side, C)
353 | outputs = outputs.permute(0, 3, 1, 2)
354 | outputs = torch.cat([outputs, conv_outputs["layer1"]], dim=1)
355 | x = self.conv_up1(outputs)
356 |
357 | x = self.upsample(x)
358 | x = torch.cat([x, conv_outputs["layer0"]], dim=1)
359 | x = self.conv_up0(x)
360 |
361 | x = self.upsample(x)
362 | x = torch.cat([x, conv_outputs["x_original"]], dim=1)
363 | x = self.conv_original_size2(x)
364 |
365 | logits = x.permute(0, 2, 3, 1)
366 | preds = self.output_fc_1(logits)
367 | preds = preds.squeeze(-1).sigmoid()
368 | return logits, preds
369 |
--------------------------------------------------------------------------------
/code/learn/models/corner_to_edge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import scipy.ndimage.filters as filters
5 | import cv2
6 | import itertools
7 |
8 | NEIGHBOUR_SIZE = 5
9 | MATCH_THRESH = 5
10 | LOCAL_MAX_THRESH = 0.01
11 | viz_count = 0
12 |
13 | all_combibations = dict()
14 | for length in range(2, 351):
15 | ids = np.arange(length)
16 | combs = np.array(list(itertools.combinations(ids, 2)))
17 | all_combibations[length] = combs
18 |
19 | def prepare_edge_data(c_outputs, annots, images):
20 | bs = c_outputs.shape[0]
21 | # prepares parameters for each sample of the batch
22 | all_results = list()
23 |
24 | for b_i in range(bs):
25 | annot = annots[b_i]
26 | output = c_outputs[b_i]
27 | results = process_each_sample({'annot': annot, 'output': output, 'viz_img': images[b_i]})
28 | all_results.append(results)
29 |
30 | processed_corners = [item['corners'] for item in all_results]
31 | edge_coords = [item['edges'] for item in all_results]
32 | edge_labels = [item['labels'] for item in all_results]
33 |
34 | edge_info = {
35 | 'edge_coords': edge_coords,
36 | 'edge_labels': edge_labels,
37 | 'processed_corners': processed_corners
38 | }
39 |
40 | edge_data = collate_edge_info(edge_info)
41 | return edge_data
42 |
43 |
44 | def process_annot(annot, do_round=True):
45 | corners = np.array(list(annot.keys()))
46 | ind = np.lexsort(corners.T) # sort the g.t. corners to fix the order for the matching later
47 | corners = corners[ind] # sorted by y, then x
48 | corner_mapping = {tuple(k): v for v, k in enumerate(corners)}
49 |
50 | edges = list()
51 | for c, connections in annot.items():
52 | for other_c in connections:
53 | edge_pair = (corner_mapping[c], corner_mapping[tuple(other_c)])
54 | edges.append(edge_pair)
55 | corner_degrees = [len(annot[tuple(c)]) for c in corners]
56 | if do_round:
57 | corners = corners.round()
58 | return corners, edges, corner_degrees
59 |
60 |
61 | def process_each_sample(data):
62 | annot = data['annot']
63 | output = data['output']
64 | viz_img = data['viz_img']
65 |
66 | preds = output.detach().cpu().numpy()
67 |
68 | data_max = filters.maximum_filter(preds, NEIGHBOUR_SIZE)
69 | maxima = (preds == data_max)
70 | data_min = filters.minimum_filter(preds, NEIGHBOUR_SIZE)
71 | diff = ((data_max - data_min) > 0)
72 | maxima[diff == 0] = 0
73 | local_maximas = np.where((maxima > 0) & (preds > LOCAL_MAX_THRESH))
74 | pred_corners = np.stack(local_maximas, axis=-1)[:, [1, 0]] # to (x, y format)
75 |
76 |     # produce edge labels from the pred corners here
77 | #global viz_count
78 | #rand_num = np.random.rand()
79 | #if rand_num > 0.7 and len(pred_corners) > 0:
80 | #processed_corners, edges, labels = get_edge_label_det_only(pred_corners, annot) # only using det corners
81 | #output_path = './viz_edge_data/{}_example_det.png'.format(viz_count)
82 | #_visualize_edge_training_data(processed_corners, edges, labels, viz_img, output_path)
83 | #else:
84 | # # use g.t. corners, but mix with neg pred corners
85 | processed_corners, edges, labels = get_edge_label_mix_gt(pred_corners, annot)
86 | #output_path = './viz_training/{}_example_gt.png'.format(viz_count)
87 | #_visualize_edge_training_data(processed_corners, edges, labels, viz_img, output_path)
88 | #viz_count += 1
89 |
90 | results = {
91 | 'corners': processed_corners,
92 | 'edges': edges,
93 | 'labels': labels,
94 | }
95 | return results
96 |
97 |
98 | def get_edge_label_mix_gt(pred_corners, annot):
99 | ind = np.lexsort(pred_corners.T) # sort the pred corners to fix the order for matching
100 | pred_corners = pred_corners[ind] # sorted by y, then x
101 | gt_corners, edge_pairs, corner_degrees = process_annot(annot)
102 |
103 | output_to_gt = dict()
104 | gt_to_output = dict()
105 | diff = np.sqrt(((pred_corners[:, None] - gt_corners) ** 2).sum(-1))
106 | diff = diff.T
107 |
108 | if len(pred_corners) > 0:
109 | for target_i, target in enumerate(gt_corners):
110 | dist = diff[target_i]
111 | if len(output_to_gt) > 0:
112 | dist[list(output_to_gt.keys())] = 1000 # ignore already matched pred corners
113 | min_dist = dist.min()
114 | min_idx = dist.argmin()
115 | if min_dist < MATCH_THRESH and min_idx not in output_to_gt: # a positive match
116 | output_to_gt[min_idx] = (target_i, min_dist)
117 | gt_to_output[target_i] = min_idx
118 |
119 | all_corners = gt_corners.copy()
120 |
121 | # replace matched g.t. corners with pred corners
122 | for gt_i in range(len(gt_corners)):
123 | if gt_i in gt_to_output:
124 | all_corners[gt_i] = pred_corners[gt_to_output[gt_i]]
125 |
126 | nm_pred_ids = [i for i in range(len(pred_corners)) if i not in output_to_gt]
127 | nm_pred_ids = np.random.permutation(nm_pred_ids)
128 | if len(nm_pred_ids) > 0:
129 | nm_pred_corners = pred_corners[nm_pred_ids]
130 | if len(nm_pred_ids) + len(all_corners) <= 150:
131 | all_corners = np.concatenate([all_corners, nm_pred_corners], axis=0)
132 | else:
133 | all_corners = np.concatenate([all_corners, nm_pred_corners[:(150 - len(gt_corners)), :]], axis=0)
134 |
135 | processed_corners, edges, edge_ids, labels = _get_edges(all_corners, edge_pairs)
136 |
137 | return processed_corners, edges, labels
138 |
139 |
140 | def _get_edges(corners, edge_pairs):
141 | ind = np.lexsort(corners.T)
142 | corners = corners[ind] # sorted by y, then x
143 | corners = corners.round()
144 | id_mapping = {old: new for new, old in enumerate(ind)}
145 |
146 | all_ids = all_combibations[len(corners)]
147 | edges = corners[all_ids]
148 | labels = np.zeros(edges.shape[0])
149 |
150 | N = len(corners)
151 | edge_pairs = [(id_mapping[p[0]], id_mapping[p[1]]) for p in edge_pairs]
152 | edge_pairs = [p for p in edge_pairs if p[0] < p[1]]
153 |     pos_ids = [int((2 * N - 1 - p[0]) * p[0] / 2 + p[1] - p[0] - 1) for p in edge_pairs]  # index of pair (i, j), i < j, in the lexicographically ordered list of all C(N, 2) corner pairs
154 | labels[pos_ids] = 1
155 |
156 | edge_ids = np.array(all_ids)
157 | return corners, edges, edge_ids, labels
158 |
159 |
160 | def collate_edge_info(data):
161 | batched_data = {}
162 | lengths_info = {}
163 | for field in data.keys():
164 | batch_values = data[field]
165 | all_lens = [len(value) for value in batch_values]
166 | max_len = max(all_lens)
167 | pad_value = 0
168 | batch_values = [pad_sequence(value, max_len, pad_value) for value in batch_values]
169 | batch_values = np.stack(batch_values, axis=0)
170 |
171 | if field in ['edge_coords', 'edge_labels', 'gt_values']:
172 | batch_values = torch.Tensor(batch_values).long()
173 | if field in ['processed_corners', 'edge_coords']:
174 | lengths_info[field] = all_lens
175 | batched_data[field] = batch_values
176 |
177 | # Add length and mask into the data, the mask if for Transformers' input format, True means padding
178 | for field, lengths in lengths_info.items():
179 | lengths_str = field + '_lengths'
180 | batched_data[lengths_str] = torch.Tensor(lengths).long()
181 | mask = torch.arange(max(lengths))
182 | mask = mask.unsqueeze(0).repeat(batched_data[field].shape[0], 1)
183 | mask = mask >= batched_data[lengths_str].unsqueeze(-1)
184 | mask_str = field + '_mask'
185 | batched_data[mask_str] = mask
186 |
187 | return batched_data
188 |
189 |
190 | def pad_sequence(seq, length, pad_value=0):
191 | if len(seq) == length:
192 | return seq
193 | else:
194 | pad_len = length - len(seq)
195 | if len(seq.shape) == 1:
196 | if pad_value == 0:
197 | paddings = np.zeros([pad_len, ])
198 | else:
199 | paddings = np.ones([pad_len, ]) * pad_value
200 | else:
201 | if pad_value == 0:
202 | paddings = np.zeros([pad_len, ] + list(seq.shape[1:]))
203 | else:
204 | paddings = np.ones([pad_len, ] + list(seq.shape[1:])) * pad_value
205 | padded_seq = np.concatenate([seq, paddings], axis=0)
206 | return padded_seq
207 |
208 |
209 | def get_infer_edge_pairs(corners, confs):
210 | ind = np.lexsort(corners.T)
211 | corners = corners[ind] # sorted by y, then x
212 | confs = confs[ind]
213 |
214 | edge_ids = all_combibations[len(corners)]
215 | edge_coords = corners[edge_ids]
216 |
217 | edge_coords = torch.tensor(np.array(edge_coords)).unsqueeze(0).long()
218 | mask = torch.zeros([edge_coords.shape[0], edge_coords.shape[1]]).bool()
219 | edge_ids = torch.tensor(np.array(edge_ids))
220 | return corners, confs, edge_coords, mask, edge_ids
221 |
222 |
223 |
224 | def get_mlm_info(labels, mlm=True):
225 | """
226 | :param labels: original edge labels
227 | :param mlm: whether enable mlm
228 | :return: g.t. values: 0-known false, 1-known true, 2-unknown
229 | """
230 | if mlm: # For training / evaluting with MLM
231 | rand_ratio = np.random.rand() * 0.5 + 0.5
232 | labels = torch.Tensor(labels)
233 | gt_rand = torch.rand(labels.size())
234 | gt_flag = torch.zeros_like(labels)
235 | gt_value = torch.zeros_like(labels)
236 | gt_flag[torch.where(gt_rand >= rand_ratio)] = 1
237 | gt_idx = torch.where(gt_flag == 1)
238 | pred_idx = torch.where(gt_flag == 0)
239 | gt_value[gt_idx] = labels[gt_idx]
240 | gt_value[pred_idx] = 2 # use 2 to represent unknown value, need to predict
241 | else:
242 | labels = torch.Tensor(labels)
243 | gt_flag = torch.zeros_like(labels)
244 | gt_value = torch.zeros_like(labels)
245 | gt_flag[:] = 0
246 | gt_value[:] = 2
247 | return gt_value
248 |
249 |
250 | def _get_rand_midpoint(edge):
251 | c1, c2 = edge
252 | _rand = np.random.rand() * 0.4 + 0.3
253 | mid_x = int(np.round(c1[0] + (c2[0] - c1[0]) * _rand))
254 | mid_y = int(np.round(c1[1] + (c2[1] - c1[1]) * _rand))
255 | mid_point = (mid_x, mid_y)
256 | return mid_point
257 |
258 |
259 | def _visualize_edge_training_data(corners, edges, edge_labels, image, save_path):
260 | image = image.transpose([1, 2, 0])
261 | image = (image * 255).astype(np.uint8)
262 | image = np.ascontiguousarray(image)
263 |
264 | for edge, label in zip(edges, edge_labels):
265 | if label == 1:
266 |             cv2.line(image, tuple(edge[0].astype(int)), tuple(edge[1].astype(int)), (255, 255, 0), 2)
267 |
268 | for c in corners:
269 | cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
270 |
271 | cv2.imwrite(save_path, image)
272 |
--------------------------------------------------------------------------------
/code/learn/models/deformable_transformer_original.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch import nn, Tensor
4 | from models.ops.modules import MSDeformAttn
5 | import torch.nn.functional as F
6 |
7 |
8 | class DeformableTransformerEncoderLayer(nn.Module):
9 | def __init__(
10 | self,
11 | d_model=256,
12 | d_ffn=1024,
13 | dropout=0.1,
14 | activation="relu",
15 | n_levels=4,
16 | n_heads=8,
17 | n_points=4,
18 | ):
19 | super().__init__()
20 |
21 | # self attention
22 | self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
23 | self.dropout1 = nn.Dropout(dropout)
24 | self.norm1 = nn.LayerNorm(d_model)
25 |
26 | # ffn
27 | self.linear1 = nn.Linear(d_model, d_ffn)
28 | self.activation = _get_activation_fn(activation)
29 | self.dropout2 = nn.Dropout(dropout)
30 | self.linear2 = nn.Linear(d_ffn, d_model)
31 | self.dropout3 = nn.Dropout(dropout)
32 | self.norm2 = nn.LayerNorm(d_model)
33 |
34 | @staticmethod
35 | def with_pos_embed(tensor, pos):
36 | return tensor if pos is None else tensor + pos
37 |
38 | def forward_ffn(self, src):
39 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
40 | src = src + self.dropout3(src2)
41 | src = self.norm2(src)
42 | return src
43 |
44 | def forward(
45 | self,
46 | src,
47 | pos,
48 | reference_points,
49 | spatial_shapes,
50 | level_start_index,
51 | padding_mask=None,
52 | ):
53 | # self attention
54 | src2 = self.self_attn(
55 | self.with_pos_embed(src, pos),
56 | reference_points,
57 | src,
58 | spatial_shapes,
59 | level_start_index,
60 | padding_mask,
61 | )
62 | src = src + self.dropout1(src2)
63 | src = self.norm1(src)
64 |
65 | # ffn
66 | src = self.forward_ffn(src)
67 |
68 | return src
69 |
70 |
71 | class DeformableTransformerEncoder(nn.Module):
72 | def __init__(self, encoder_layer, num_layers):
73 | super().__init__()
74 | self.layers = _get_clones(encoder_layer, num_layers)
75 | self.num_layers = num_layers
76 |
77 | @staticmethod
78 | def get_reference_points(spatial_shapes, valid_ratios, device):
79 | reference_points_list = []
80 | for lvl, (H_, W_) in enumerate(spatial_shapes):
81 | ref_y, ref_x = torch.meshgrid(
82 | torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
83 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
84 | )
85 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
86 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
87 | ref = torch.stack((ref_x, ref_y), -1)
88 | reference_points_list.append(ref)
89 | reference_points = torch.cat(reference_points_list, 1)
90 | reference_points = reference_points[:, :, None] * valid_ratios[:, None]
91 | return reference_points
92 |
93 | def forward(
94 | self,
95 | src,
96 | spatial_shapes,
97 | level_start_index,
98 | valid_ratios,
99 | pos=None,
100 | padding_mask=None,
101 | ):
102 | output = src
103 | reference_points = self.get_reference_points(
104 | spatial_shapes, valid_ratios, device=src.device
105 | )
106 | for _, layer in enumerate(self.layers):
107 | output = layer(
108 | output,
109 | pos,
110 | reference_points,
111 | spatial_shapes,
112 | level_start_index,
113 | padding_mask,
114 | )
115 |
116 | return output
117 |
118 |
119 | class DeformableAttnDecoderLayer(nn.Module):
120 | def __init__(
121 | self,
122 | d_model=256,
123 | d_ffn=1024,
124 | dropout=0.1,
125 | activation="relu",
126 | n_levels=4,
127 | n_heads=8,
128 | n_points=4,
129 | ):
130 | super().__init__()
131 | # cross attention
132 | self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
133 | self.dropout1 = nn.Dropout(dropout)
134 | self.norm1 = nn.LayerNorm(d_model)
135 |
136 | # ffn
137 | self.linear1 = nn.Linear(d_model, d_ffn)
138 | self.activation = _get_activation_fn(activation)
139 | self.dropout3 = nn.Dropout(dropout)
140 | self.linear2 = nn.Linear(d_ffn, d_model)
141 | self.dropout4 = nn.Dropout(dropout)
142 | self.norm3 = nn.LayerNorm(d_model)
143 |
144 | @staticmethod
145 | def with_pos_embed(tensor, pos):
146 | return tensor if pos is None else tensor + pos
147 |
148 | def forward_ffn(self, tgt):
149 | tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
150 | tgt = tgt + self.dropout4(tgt2)
151 | tgt = self.norm3(tgt)
152 | return tgt
153 |
154 | def forward(
155 | self,
156 | tgt,
157 | query_pos,
158 | reference_points,
159 | src,
160 | src_spatial_shapes,
161 | level_start_index,
162 | src_padding_mask=None,
163 | key_padding_mask=None,
164 | ):
165 | # cross attention
166 | tgt2 = self.cross_attn(
167 | self.with_pos_embed(tgt, query_pos),
168 | reference_points,
169 | src,
170 | src_spatial_shapes,
171 | level_start_index,
172 | src_padding_mask,
173 | )
174 | tgt = tgt + self.dropout1(tgt2)
175 | tgt = self.norm1(tgt)
176 |
177 | # ffn
178 | tgt = self.forward_ffn(tgt)
179 |
180 | return tgt
181 |
182 |
183 | class DeformableTransformerDecoderLayer(nn.Module):
184 | def __init__(
185 | self,
186 | d_model=256,
187 | d_ffn=1024,
188 | dropout=0.1,
189 | activation="relu",
190 | n_levels=4,
191 | n_heads=8,
192 | n_points=4,
193 | ):
194 | super().__init__()
195 | # cross attention
196 | self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
197 | self.dropout1 = nn.Dropout(dropout)
198 | self.norm1 = nn.LayerNorm(d_model)
199 |
200 | # self attention
201 | self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
202 | self.dropout2 = nn.Dropout(dropout)
203 | self.norm2 = nn.LayerNorm(d_model)
204 |
205 | # ffn
206 | self.linear1 = nn.Linear(d_model, d_ffn)
207 | self.activation = _get_activation_fn(activation)
208 | self.dropout3 = nn.Dropout(dropout)
209 | self.linear2 = nn.Linear(d_ffn, d_model)
210 | self.dropout4 = nn.Dropout(dropout)
211 | self.norm3 = nn.LayerNorm(d_model)
212 |
213 | @staticmethod
214 | def with_pos_embed(tensor, pos):
215 | return tensor if pos is None else tensor + pos
216 |
217 | def forward_ffn(self, tgt):
218 | tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
219 | tgt = tgt + self.dropout4(tgt2)
220 | tgt = self.norm3(tgt)
221 | return tgt
222 |
223 | def forward(
224 | self,
225 | tgt,
226 | query_pos,
227 | reference_points,
228 | src,
229 | src_spatial_shapes,
230 | level_start_index,
231 | src_padding_mask=None,
232 | key_padding_mask=None,
233 | get_image_feat=True,
234 | ):
235 | # self attention
236 | q = k = self.with_pos_embed(tgt, query_pos)
237 | tgt2 = self.self_attn(
238 | q.transpose(0, 1),
239 | k.transpose(0, 1),
240 | tgt.transpose(0, 1),
241 | key_padding_mask=key_padding_mask,
242 | )[0].transpose(0, 1)
243 | tgt = tgt + self.dropout2(tgt2)
244 | tgt = self.norm2(tgt)
245 |
246 | if get_image_feat:
247 | # cross attention
248 | tgt2 = self.cross_attn(
249 | self.with_pos_embed(tgt, query_pos),
250 | reference_points,
251 | src,
252 | src_spatial_shapes,
253 | level_start_index,
254 | src_padding_mask,
255 | )
256 | tgt = tgt + self.dropout1(tgt2)
257 | tgt = self.norm1(tgt)
258 |
259 | # ffn
260 | tgt = self.forward_ffn(tgt)
261 |
262 | return tgt
263 |
264 |
265 | class DeformableTransformerDecoder(nn.Module):
266 | def __init__(
267 | self, decoder_layer, num_layers, return_intermediate=False, with_sa=True
268 | ):
269 | super().__init__()
270 | self.layers = _get_clones(decoder_layer, num_layers)
271 | self.num_layers = num_layers
272 | self.return_intermediate = return_intermediate
273 | # whether each decoder layer applies self-attention among queries before cross-attention
274 | self.with_sa = with_sa
275 |
276 | def forward(
277 | self,
278 | tgt,
279 | reference_points,
280 | src,
281 | src_spatial_shapes,
282 | src_level_start_index,
283 | src_valid_ratios,
284 | query_pos=None,
285 | src_padding_mask=None,
286 | key_padding_mask=None,
287 | get_image_feat=True,
288 | ):
289 | output = tgt
290 |
291 | intermediate = []
292 | intermediate_reference_points = []
293 | for lid, layer in enumerate(self.layers):
294 | if reference_points.shape[-1] == 4:
295 | reference_points_input = (
296 | reference_points[:, :, None]
297 | * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
298 | )
299 | else:
300 | assert reference_points.shape[-1] == 2
301 | reference_points_input = (
302 | reference_points[:, :, None] * src_valid_ratios[:, None]
303 | )
304 | if self.with_sa:
305 | output = layer(
306 | output,
307 | query_pos,
308 | reference_points_input,
309 | src,
310 | src_spatial_shapes,
311 | src_level_start_index,
312 | src_padding_mask,
313 | key_padding_mask,
314 | get_image_feat,
315 | )
316 | else:
317 | output = layer(
318 | output,
319 | query_pos,
320 | reference_points_input,
321 | src,
322 | src_spatial_shapes,
323 | src_level_start_index,
324 | src_padding_mask,
325 | key_padding_mask,
326 | )
327 |
328 | if self.return_intermediate:
329 | intermediate.append(output)
330 | intermediate_reference_points.append(reference_points)
331 |
332 | if self.return_intermediate:
333 | return torch.stack(intermediate), torch.stack(intermediate_reference_points)
334 |
335 | return output, reference_points
336 |
337 |
338 | def _get_clones(module, N):
339 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
340 |
341 |
342 | def _get_activation_fn(activation):
343 | """Return an activation function given a string"""
344 | if activation == "relu":
345 | return F.relu
346 | if activation == "gelu":
347 | return F.gelu
348 | if activation == "glu":
349 | return F.glu
350 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
351 |
--------------------------------------------------------------------------------
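As a quick reference for how the decoder pieces above fit together, here is a minimal sketch (not part of the repo) that stacks a `DeformableTransformerDecoderLayer` inside a `DeformableTransformerDecoder` and runs it on random multi-level features. It assumes the custom op in `models/ops` has been built (see `make.sh`) and that a CUDA device is available; all sizes are illustrative.

```python
# Illustrative only: wiring DeformableTransformerDecoderLayer into
# DeformableTransformerDecoder, assuming models/ops is built and a GPU exists.
import torch
from models.deformable_transformer import (
    DeformableTransformerDecoder,
    DeformableTransformerDecoderLayer,
)

d_model, n_levels, n_queries, bs = 256, 4, 50, 1
layer = DeformableTransformerDecoderLayer(d_model=d_model, n_levels=n_levels)
decoder = DeformableTransformerDecoder(layer, num_layers=6, with_sa=True).cuda()

# flattened multi-level image memory: levels of 32x32, 16x16, 8x8, 4x4
shapes = torch.as_tensor([[32, 32], [16, 16], [8, 8], [4, 4]], device="cuda")
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
memory = torch.randn(bs, int(shapes.prod(1).sum()), d_model, device="cuda")
valid_ratios = torch.ones(bs, n_levels, 2, device="cuda")

tgt = torch.randn(bs, n_queries, d_model, device="cuda")
query_pos = torch.randn(bs, n_queries, d_model, device="cuda")
reference_points = torch.rand(bs, n_queries, 2, device="cuda")  # normalized (x, y)

out, ref = decoder(tgt, reference_points, memory, shapes, level_start_index,
                   valid_ratios, query_pos=query_pos)
print(out.shape)  # (bs, n_queries, d_model)
```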
/code/learn/models/edge_models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from models.mlp import MLP
6 | from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \
7 | DeformableTransformerDecoder, DeformableTransformerDecoderLayer, DeformableAttnDecoderLayer
8 | from models.ops.modules import MSDeformAttn
9 | from models.corner_models import PositionEmbeddingSine
10 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
11 | import torch.nn.functional as F
12 | from utils.misc import NestedTensor
13 |
14 |
15 | class EdgeEnum(nn.Module):
16 | def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels, ):
17 | super(EdgeEnum, self).__init__()
18 | self.input_dim = input_dim
19 | self.hidden_dim = hidden_dim
20 | self.num_feature_levels = num_feature_levels
21 |
22 | if num_feature_levels > 1:
23 | num_backbone_outs = len(backbone_strides)
24 | input_proj_list = []
25 | for _ in range(num_backbone_outs):
26 | in_channels = backbone_num_channels[_]
27 | input_proj_list.append(nn.Sequential(
28 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
29 | nn.GroupNorm(32, hidden_dim),
30 | ))
31 | for _ in range(num_feature_levels - num_backbone_outs):
32 | input_proj_list.append(nn.Sequential(
33 | nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
34 | nn.GroupNorm(32, hidden_dim),
35 | ))
36 | in_channels = hidden_dim
37 | self.input_proj = nn.ModuleList(input_proj_list)
38 | else:
39 | self.input_proj = nn.ModuleList([
40 | nn.Sequential(
41 | nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
42 | nn.GroupNorm(32, hidden_dim),
43 | )])
44 |
45 | self.img_pos = PositionEmbeddingSine(hidden_dim // 2)
46 |
47 | self.edge_input_fc = nn.Linear(input_dim * 2, hidden_dim)
48 | self.output_fc = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim // 2, output_dim=2, num_layers=2)
49 |
50 | # self.mlm_embedding = nn.Embedding(3, input_dim)
51 | self.transformer = EdgeTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1,
52 | num_decoder_layers=6, dim_feedforward=1024, dropout=0.1)
53 |
54 | @staticmethod
55 | def get_ms_feat(xs, img_mask):
56 | out: Dict[str, NestedTensor] = {}
57 | # out = list()
58 | for name, x in sorted(xs.items()):
59 | m = img_mask
60 | assert m is not None
61 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
62 | out[name] = NestedTensor(x, mask)
63 | # out.append(NestedTensor(x, mask))
64 | return out
65 |
66 | def forward(self, image_feats, feat_mask, corner_outputs, edge_coords, edge_masks, gt_values, corner_nums,
67 | max_candidates, do_inference=False, get_hs=False):
68 | # Prepare ConvNet features
69 | features = self.get_ms_feat(image_feats, feat_mask)
70 |
71 | srcs = []
72 | masks = []
73 | all_pos = []
74 |
75 | new_features = list()
76 | for name, x in sorted(features.items()):
77 | new_features.append(x)
78 | features = new_features
79 |
80 | for l, feat in enumerate(features):
81 | src, mask = feat.decompose()
82 | mask = mask.to(src.device)
83 | srcs.append(self.input_proj[l](src))
84 | pos = self.img_pos(src).to(src.dtype)
85 | all_pos.append(pos)
86 | masks.append(mask)
87 | assert mask is not None
88 |
89 | if self.num_feature_levels > len(srcs):
90 | _len_srcs = len(srcs)
91 | for l in range(_len_srcs, self.num_feature_levels):
92 | if l == _len_srcs:
93 | src = self.input_proj[l](features[-1].tensors)
94 | else:
95 | src = self.input_proj[l](srcs[-1])
96 | m = feat_mask
97 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device)
98 | pos_l = self.img_pos(src).to(src.dtype)
99 | srcs.append(src)
100 | masks.append(mask)
101 | all_pos.append(pos_l)
102 |
103 | bs = edge_masks.size(0)
104 | num_edges = edge_masks.size(1)
105 |
106 | corner_feats = corner_outputs
107 | edge_feats = list()
108 | for b_i in range(bs):
109 | feats = corner_feats[b_i, edge_coords[b_i, :, :, 1], edge_coords[b_i, :, :, 0], :]
110 | edge_feats.append(feats)
111 | edge_feats = torch.stack(edge_feats, dim=0)
112 | edge_feats = edge_feats.view(bs, num_edges, -1)
113 |
114 | edge_inputs = self.edge_input_fc(edge_feats.view(bs * num_edges, -1))
115 | edge_inputs = edge_inputs.view(bs, num_edges, -1)
116 |
117 | edge_center = (edge_coords[:, :, 0, :].float() + edge_coords[:, :, 1, :].float()) / 2
118 | edge_center = edge_center / feat_mask.shape[1]
119 |
120 | return self.transformer(srcs, masks, all_pos, edge_inputs, edge_center, gt_values,
121 | edge_masks, corner_nums, max_candidates, do_inference, get_hs)
122 |
123 |
124 | class EdgeTransformer(nn.Module):
125 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
126 | num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
127 | activation="relu", return_intermediate_dec=False,
128 | num_feature_levels=4, dec_n_points=4, enc_n_points=4,
129 | ):
130 | super(EdgeTransformer, self).__init__()
131 |
132 | encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
133 | dropout, activation,
134 | num_feature_levels, nhead, enc_n_points)
135 | self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
136 |
137 | decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward,
138 | dropout, activation,
139 | num_feature_levels, nhead, dec_n_points)
140 | self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False)
141 |
142 | decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
143 | dropout, activation,
144 | num_feature_levels, nhead, dec_n_points)
145 | self.relational_decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers,
146 | return_intermediate_dec, with_sa=True)
147 |
148 | self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
149 | # self.reference_points = nn.Linear(d_model, 2)
150 |
151 | self.gt_label_embed = nn.Embedding(3, d_model)
152 |
153 | self.input_fc_hb = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
154 | self.input_fc_rel = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
155 |
156 | self.output_fc_1 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
157 | self.output_fc_2 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
158 | self.output_fc_3 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
159 | self._reset_parameters()
160 |
161 | def _reset_parameters(self):
162 | for p in self.parameters():
163 | if p.dim() > 1:
164 | nn.init.xavier_uniform_(p)
165 | for m in self.modules():
166 | if isinstance(m, MSDeformAttn):
167 | m._reset_parameters()
168 | # xavier_uniform_(self.reference_points.weight.data, gain=1.0)
169 | # constant_(self.reference_points.bias.data, 0.)
170 | normal_(self.level_embed)
171 |
172 | def get_valid_ratio(self, mask):
173 | _, H, W = mask.shape
174 | valid_H = torch.sum(~mask[:, :, 0], 1)
175 | valid_W = torch.sum(~mask[:, 0, :], 1)
176 | valid_ratio_h = valid_H.float() / H
177 | valid_ratio_w = valid_W.float() / W
178 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
179 | return valid_ratio
180 |
181 | def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, labels, key_padding_mask, corner_nums,
182 | max_candidates, do_inference=False, get_hs=False):
183 | # prepare input for encoder
184 | src_flatten = []
185 | mask_flatten = []
186 | lvl_pos_embed_flatten = []
187 | spatial_shapes = []
188 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
189 | bs, c, h, w = src.shape
190 | spatial_shape = (h, w)
191 | spatial_shapes.append(spatial_shape)
192 | src = src.flatten(2).transpose(1, 2)
193 | mask = mask.flatten(1)
194 | pos_embed = pos_embed.flatten(2).transpose(1, 2)
195 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
196 | lvl_pos_embed_flatten.append(lvl_pos_embed)
197 | src_flatten.append(src)
198 | mask_flatten.append(mask)
199 | src_flatten = torch.cat(src_flatten, 1)
200 | mask_flatten = torch.cat(mask_flatten, 1)
201 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
202 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
203 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
204 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
205 |
206 | # encoder
207 | memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten,
208 | mask_flatten)
209 |
210 | # prepare input for decoder
211 | bs, _, c = memory.shape
212 |
213 | tgt = query_embed
214 |
215 | # reference_points = self.reference_points(query_embed).sigmoid()
216 | init_reference_out = reference_points
217 |
218 | # relational decoder
219 | hs_per_edge, _ = self.per_edge_decoder(tgt, reference_points, memory,
220 | spatial_shapes, level_start_index, valid_ratios, query_embed,
221 | mask_flatten)
222 |
223 | logits_per_edge = self.output_fc_1(hs_per_edge).permute(0, 2, 1)
224 |
225 | if get_hs:
226 | return hs_per_edge, logits_per_edge
227 |
228 | filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids = self.candidate_filtering(
229 | logits_per_edge,
230 | hs_per_edge, query_embed, reference_points,
231 | labels,
232 | key_padding_mask, corner_nums, max_candidates)
233 |
234 | if not do_inference:
235 | filtered_gt_values = self.generate_gt_masking(filtered_labels, filtered_mask)
236 | else:
237 | filtered_gt_values = filtered_labels
238 |
239 | # relational decoder with image feature
240 | gt_info = self.gt_label_embed(filtered_gt_values)
241 | hybrid_prim_hs = self.input_fc_hb(torch.cat([filtered_hs, gt_info], dim=-1))
242 |
243 | hs, inter_references = self.relational_decoder(hybrid_prim_hs, filtered_rp, memory,
244 | spatial_shapes, level_start_index, valid_ratios, filtered_query,
245 | mask_flatten,
246 | key_padding_mask=filtered_mask, get_image_feat=True)
247 |
248 | logits_final_hb = self.output_fc_2(hs).permute(0, 2, 1)
249 |
250 | # relational decoder without image feature
251 | rel_prim_hs = self.input_fc_rel(torch.cat([filtered_query, gt_info], dim=-1))
252 |
253 | hs_rel, _ = self.relational_decoder(rel_prim_hs, filtered_rp, memory,
254 | spatial_shapes, level_start_index, valid_ratios, filtered_query,
255 | mask_flatten,
256 | key_padding_mask=filtered_mask, get_image_feat=False)
257 |
258 | logits_final_rel = self.output_fc_3(hs_rel).permute(0, 2, 1)
259 |
260 | return logits_per_edge, logits_final_hb, logits_final_rel, selected_ids, filtered_mask, filtered_gt_values
261 |
262 | @staticmethod
263 | def candidate_filtering(logits, hs, query, rp, labels, key_padding_mask, corner_nums, max_candidates):
264 | B, L, _ = hs.shape
265 | preds = logits.detach().softmax(1)[:, 1, :] # BxL
266 | preds[key_padding_mask == True] = -1 # ignore the masking parts
267 | sorted_ids = torch.argsort(preds, dim=-1, descending=True)
268 | filtered_hs = list()
269 | filtered_mask = list()
270 | filtered_query = list()
271 | filtered_rp = list()
272 | filtered_labels = list()
273 | selected_ids = list()
274 | for b_i in range(B):
275 | num_candidates = corner_nums[b_i] * 3
276 | ids = sorted_ids[b_i, :max_candidates[b_i]]
277 | filtered_hs.append(hs[b_i][ids])
278 | new_mask = key_padding_mask[b_i][ids]
279 | new_mask[num_candidates:] = True
280 | filtered_mask.append(new_mask)
281 | filtered_query.append(query[b_i][ids])
282 | filtered_rp.append(rp[b_i][ids])
283 | filtered_labels.append(labels[b_i][ids])
284 | selected_ids.append(ids)
285 | filtered_hs = torch.stack(filtered_hs, dim=0)
286 | filtered_mask = torch.stack(filtered_mask, dim=0)
287 | filtered_query = torch.stack(filtered_query, dim=0)
288 | filtered_rp = torch.stack(filtered_rp, dim=0)
289 | filtered_labels = torch.stack(filtered_labels, dim=0)
290 | selected_ids = torch.stack(selected_ids, dim=0)
291 |
292 | return filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids
293 |
294 |
295 | @staticmethod
296 | def generate_gt_masking(labels, mask):
297 | bs = labels.shape[0]
298 | gt_values = torch.zeros_like(mask).long()
299 | for b_i in range(bs):
300 | edge_length = (mask[b_i] == 0).sum()
301 | rand_ratio = np.random.rand() * 0.5 + 0.5
302 | gt_rand = torch.rand(edge_length)
303 | gt_flag = torch.zeros(edge_length)
304 | gt_flag[torch.where(gt_rand >= rand_ratio)] = 1
305 | gt_idx = torch.where(gt_flag == 1)
306 | pred_idx = torch.where(gt_flag == 0)
307 | gt_values[b_i, gt_idx[0]] = labels[b_i, gt_idx[0]]
308 | gt_values[b_i, pred_idx[0]] = 2 # use 2 to represent unknown value, need to predict
309 | return gt_values
310 |
--------------------------------------------------------------------------------
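The three-way ground-truth masking used above (0/1 = revealed edge label, 2 = unknown and to be predicted, embedded via `gt_label_embed`) can be illustrated without loading the model. The snippet below is a standalone restatement of the logic in `EdgeTransformer.generate_gt_masking` with made-up sizes and no padding, not a call into the repo.

```python
# Standalone illustration of the label-masking convention: a random 50-100% of
# valid edges keep the value 2 ("unknown", to be predicted); the rest reveal
# their 0/1 ground-truth label.
import numpy as np
import torch

labels = torch.randint(0, 2, (1, 12))          # per-edge ground-truth labels
mask = torch.zeros(1, 12, dtype=torch.bool)    # no padded edges in this toy case

gt_values = torch.zeros_like(mask).long()
for b_i in range(labels.shape[0]):
    edge_length = int((mask[b_i] == 0).sum())
    rand_ratio = np.random.rand() * 0.5 + 0.5          # hide 50-100% of edges
    gt_rand = torch.rand(edge_length)
    revealed = gt_rand >= rand_ratio
    gt_values[b_i, revealed] = labels[b_i, revealed]    # revealed labels (0/1)
    gt_values[b_i, ~revealed] = 2                       # unknown, must be predicted

print(gt_values)
```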
/code/learn/models/full_models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from models.mlp import MLP
6 | from models.deformable_transformer_full import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \
7 | DeformableTransformerDecoder, DeformableTransformerDecoderLayer, DeformableAttnDecoderLayer
8 | from models.ops.modules import MSDeformAttn
9 | from models.corner_models import PositionEmbeddingSine
10 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
11 | import torch.nn.functional as F
12 | from utils.misc import NestedTensor
13 | from utils.nn_utils import pos_encode_2d
14 |
15 |
16 | class EdgeEnum(nn.Module):
17 |
18 | def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels):
19 | super(EdgeEnum, self).__init__()
20 | self.input_dim = input_dim
21 | self.hidden_dim = hidden_dim
22 | self.num_feature_levels = num_feature_levels
23 |
24 | if num_feature_levels > 1:
25 | num_backbone_outs = len(backbone_strides)
26 | input_proj_list = []
27 | for _ in range(num_backbone_outs):
28 | in_channels = backbone_num_channels[_]
29 | input_proj_list.append(nn.Sequential(
30 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
31 | nn.GroupNorm(32, hidden_dim),
32 | ))
33 | for _ in range(num_feature_levels - num_backbone_outs):
34 | input_proj_list.append(nn.Sequential(
35 | nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
36 | nn.GroupNorm(32, hidden_dim),
37 | ))
38 | in_channels = hidden_dim
39 | self.input_proj = nn.ModuleList(input_proj_list)
40 | else:
41 | self.input_proj = nn.ModuleList([
42 | nn.Sequential(
43 | nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
44 | nn.GroupNorm(32, hidden_dim),
45 | )])
46 |
47 | self.img_pos = PositionEmbeddingSine(hidden_dim // 2)
48 | self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, hidden_dim))
49 | self.edge_input_fc = nn.Linear(input_dim * 2, hidden_dim)
50 |
51 | self.output_fc = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim // 2, output_dim=2, num_layers=2)
52 |
53 | # self.mlm_embedding = nn.Embedding(3, input_dim)
54 | self.transformer = EdgeTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1,
55 | num_decoder_layers=6, dim_feedforward=1024, dropout=0.1)
56 |
57 | self._reset_parameters()
58 |
59 |
60 | def _reset_parameters(self):
61 | # for p in self.parameters():
62 | # if p.dim() > 1:
63 | # nn.init.xavier_uniform_(p)
64 | # for m in self.modules():
65 | # if isinstance(m, MSDeformAttn):
66 | # m._reset_parameters()
67 |
68 | # xavier_uniform_(self.reference_points.weight.data, gain=1.0)
69 | # constant_(self.reference_points.bias.data, 0.)
70 | normal_(self.level_embed)
71 |
72 |
73 | @staticmethod
74 | def get_ms_feat(xs, img_mask):
75 | out: Dict[str, NestedTensor] = {}
76 | # out = list()
77 | for name, x in sorted(xs.items()):
78 | m = img_mask
79 | assert m is not None
80 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
81 | out[name] = NestedTensor(x, mask)
82 | # out.append(NestedTensor(x, mask))
83 | return out
84 |
85 |
86 | def get_valid_ratio(self, mask):
87 | _, H, W = mask.shape
88 | valid_H = torch.sum(~mask[:, :, 0], 1)
89 | valid_W = torch.sum(~mask[:, 0, :], 1)
90 | valid_ratio_h = valid_H.float() / H
91 | valid_ratio_w = valid_W.float() / W
92 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
93 | return valid_ratio
94 |
95 |
96 | def prepare_image_feat(self, crop, backbone):
97 | image = torch.tensor(crop['image']).cuda().unsqueeze(0)
98 | with torch.no_grad():
99 | image_feats, feat_mask, _ = backbone(image)
100 | features = self.get_ms_feat(image_feats, feat_mask)
101 |
102 | srcs = []
103 | masks = []
104 | all_pos = []
105 |
106 | new_features = list()
107 | for name, x in sorted(features.items()):
108 | new_features.append(x)
109 | features = new_features
110 |
111 | for l, feat in enumerate(features):
112 | src, mask = feat.decompose()
113 | mask = mask.to(src.device)
114 | srcs.append(self.input_proj[l](src))
115 | pos = self.img_pos(src).to(src.dtype)
116 | all_pos.append(pos)
117 | masks.append(mask)
118 | assert mask is not None
119 |
120 | if self.num_feature_levels > len(srcs):
121 | _len_srcs = len(srcs)
122 | for l in range(_len_srcs, self.num_feature_levels):
123 | if l == _len_srcs:
124 | src = self.input_proj[l](features[-1].tensors)
125 | else:
126 | src = self.input_proj[l](srcs[-1])
127 | m = feat_mask
128 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device)
129 | pos_l = self.img_pos(src).to(src.dtype)
130 | srcs.append(src)
131 | masks.append(mask)
132 | all_pos.append(pos_l)
133 |
134 | # prepare input for encoder
135 | pos_embeds = all_pos
136 |
137 | src_flatten = []
138 | mask_flatten = []
139 | lvl_pos_embed_flatten = []
140 | spatial_shapes = []
141 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
142 | bs, c, h, w = src.shape
143 | spatial_shape = (h, w)
144 | spatial_shapes.append(spatial_shape)
145 | src = src.flatten(2).transpose(1, 2)
146 | mask = mask.flatten(1)
147 | pos_embed = pos_embed.flatten(2).transpose(1, 2)
148 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
149 | lvl_pos_embed_flatten.append(lvl_pos_embed)
150 | src_flatten.append(src)
151 | mask_flatten.append(mask)
152 | src_flatten = torch.cat(src_flatten, 1)
153 | mask_flatten = torch.cat(mask_flatten, 1)
154 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
155 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
156 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
157 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
158 |
159 | image_data = {
160 | 'src_flatten': src_flatten,
161 | 'spatial_shapes': spatial_shapes,
162 | 'level_start_index': level_start_index,
163 | 'valid_ratios': valid_ratios,
164 | 'lvl_pos_embed_flatten': lvl_pos_embed_flatten,
165 | 'mask_flatten': mask_flatten
166 | }
167 | return image_data
168 |
169 |
170 | def forward(self, data, backbone):
171 | # encode edges
172 | edge_coords = data['edge']
173 | edge_feats_a = pos_encode_2d(x=edge_coords[:,0], y=edge_coords[:,1])
174 | edge_feats_b = pos_encode_2d(x=edge_coords[:,2], y=edge_coords[:,3])
175 | edge_feats = torch.cat([edge_feats_a, edge_feats_b], dim=-1)
176 | edge_feats = self.edge_input_fc(edge_feats)
177 |
178 | # prepare image features
179 | for batch_crops in data['crops']:
180 | for crop in batch_crops:
181 | crop['image_data'] = self.prepare_image_feat(crop, backbone)
182 |
183 | # edge_center = (edge_coords[:, :, 0, :].float() + edge_coords[:, :, 1, :].float()) / 2
184 | # edge_center = edge_center / feat_mask.shape[1]
185 |
186 | # logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values = self.transformer(srcs,
187 | # masks,
188 | # all_pos,
189 | # edge_inputs,
190 | # edge_center,
191 | # gt_values,
192 | # edge_masks,
193 | # corner_nums,
194 | # max_candidates,
195 | # do_inference)
196 |
197 | logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values = self.transformer(edge_feats, data)
198 |
199 | return logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values
200 |
201 |
202 | class EdgeTransformer(nn.Module):
203 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
204 | num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
205 | activation="relu", return_intermediate_dec=False,
206 | num_feature_levels=4, dec_n_points=4, enc_n_points=4,
207 | ):
208 | super(EdgeTransformer, self).__init__()
209 |
210 | encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
211 | dropout, activation,
212 | num_feature_levels, nhead, enc_n_points)
213 | self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
214 |
215 | decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward,
216 | dropout, activation,
217 | num_feature_levels, nhead, dec_n_points)
218 | self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False)
219 |
220 | decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
221 | dropout, activation,
222 | num_feature_levels, nhead, dec_n_points)
223 | self.relational_decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers,
224 | return_intermediate_dec, with_sa=True)
225 |
226 | # self.reference_points = nn.Linear(d_model, 2)
227 |
228 | self.gt_label_embed = nn.Embedding(3, d_model)
229 |
230 | self.input_fc_hb = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
231 | self.input_fc_rel = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
232 |
233 | self.output_fc_1 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
234 | self.output_fc_2 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
235 | self.output_fc_3 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
236 |
237 | self._reset_parameters()
238 |
239 | def _reset_parameters(self):
240 | for p in self.parameters():
241 | if p.dim() > 1:
242 | nn.init.xavier_uniform_(p)
243 | for m in self.modules():
244 | if isinstance(m, MSDeformAttn):
245 | m._reset_parameters()
246 |
247 | def forward(self, edge_feats, data):
248 | src_flatten = []
249 | mask_flatten = []
250 | lvl_pos_embed_flatten = []
251 |
252 | for batch_crops in data['crops']:
253 | for crop in batch_crops:
254 | image_data = crop['image_data']
255 | with torch.no_grad():
256 | memory = self.encoder(image_data['src_flatten'],
257 | image_data['spatial_shapes'],
258 | image_data['level_start_index'],
259 | image_data['valid_ratios'],
260 | image_data['lvl_pos_embed_flatten'],
261 | image_data['mask_flatten'])
262 | image_data['memory'] = memory
263 |
264 | hs_per_edge, _ = self.per_edge_decoder(edge_feats, data)
265 |
266 | logits_per_edge = self.output_fc_1(hs_per_edge).permute(0, 2, 1)
267 |
268 | # relational decoder with image feature
269 | gt_info = self.gt_label_embed(filtered_gt_values)
270 | hybrid_prim_hs = self.input_fc_hb(torch.cat([filtered_hs, gt_info], dim=-1))
271 |
272 | hs, inter_references = self.relational_decoder(hybrid_prim_hs, filtered_rp, memory,
273 | spatial_shapes, level_start_index, valid_ratios, filtered_query,
274 | mask_flatten,
275 | key_padding_mask=filtered_mask, get_image_feat=True)
276 |
277 | logits_final_hb = self.output_fc_2(hs).permute(0, 2, 1)
278 |
279 | # relational decoder without image feature
280 | rel_prim_hs = self.input_fc_rel(torch.cat([filtered_query, gt_info], dim=-1))
281 |
282 | hs_rel, _ = self.relational_decoder(rel_prim_hs, filtered_rp, memory,
283 | spatial_shapes, level_start_index, valid_ratios, filtered_query,
284 | mask_flatten,
285 | key_padding_mask=filtered_mask, get_image_feat=False)
286 |
287 | logits_final_rel = self.output_fc_3(hs_rel).permute(0, 2, 1)
288 |
289 | return logits_per_edge, logits_final_hb, logits_final_rel, selected_ids, filtered_mask, filtered_gt_values
--------------------------------------------------------------------------------
/code/learn/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scipy.optimize import linear_sum_assignment
3 | from torch import nn
4 | import torch.nn.functional as F
5 | import scipy.ndimage.filters as filters
6 | import numpy as np
7 | from utils.geometry_utils import edge_acc
8 |
9 |
10 | class CornerCriterion(nn.Module):
11 | def __init__(self, image_size):
12 | super().__init__()
13 | self.gamma = 1
14 | self.loss_rate = 9
15 |
16 | # def forward(self, outputs_s1, outputs_s2, s2_mask, s2_candidates, targets, gauss_targets, epoch=0):
17 | def forward(self, outputs_s1, targets, gauss_targets, epoch=0):
18 | # Compute the acc first, use the acc to guide the setup of loss weight
19 | preds_s1 = (outputs_s1 >= 0.5).float()
20 | pos_target_ids = torch.where(targets == 1)
21 | correct = (preds_s1[pos_target_ids] == targets[pos_target_ids]).float().sum()
22 | # acc = correct / (preds.shape[0] * preds.shape[1] * preds.shape[2])
23 | # num_pos_preds = (preds == 1).sum()
24 | recall_s1 = correct / len(pos_target_ids[0])
25 | # prec = correct / num_pos_preds if num_pos_preds > 0 else torch.tensor(0).to(correct.device)
26 | # f_score = 2.0 * prec * recall / (recall + prec + 1e-8)
27 |
28 | rate = self.loss_rate
29 |
30 | loss_weight = (gauss_targets > 0.5).float() * rate + 1
31 | loss_s1 = F.binary_cross_entropy(
32 | outputs_s1, gauss_targets, weight=loss_weight, reduction="none"
33 | )
34 | loss_s1 = loss_s1.sum(-1).sum(-1).mean()
35 |
36 | # loss for stage-2
37 | # B, H, W = gauss_targets.shape
38 | # gauss_targets_1d = gauss_targets.view(B, H*W)
39 | # s2_ids = s2_candidates[:, :, 1] * H + s2_candidates[:, :, 0]
40 | # s2_labels = torch.gather(gauss_targets_1d, 1, s2_ids)
41 | # # try an aggressive labeling for s2
42 | # s2_th = 0.1
43 | # s2_labels = (s2_labels > s2_th).float()
44 | # loss_weight = (s2_labels > 0.5).float() * rate + 1
45 | # loss_s2 = F.binary_cross_entropy(outputs_s2, s2_labels, weight=loss_weight, reduction='none')
46 | # loss_s2[torch.where(s2_mask == True)] = 0
47 | # loss_s2 = loss_s2.sum(-1).sum(-1).mean()
48 |
49 | return loss_s1, recall_s1
50 | # , loss_s2, recall_s1
51 |
52 |
53 | class CornerCriterion4D(nn.Module):
54 | def __init__(self):
55 | super().__init__()
56 |
57 | def forward(self, outputs, corner_labels, offset_labels):
58 | loss_dict = {"loss_jloc": 0.0, "loss_joff": 0.0}
59 | for _, output in enumerate(outputs):
60 | loss_dict["loss_jloc"] += cross_entropy_loss_for_junction(
61 | output[:, :2], corner_labels
62 | )
63 | loss_dict["loss_joff"] += sigmoid_l1_loss(
64 | output[:, 2:4], offset_labels, -0.5, corner_labels.float()
65 | )
66 | return loss_dict
67 |
68 |
69 | def cross_entropy_loss_for_junction(logits, positive):
70 | nlogp = -F.log_softmax(logits, dim=1)
71 | loss = positive * nlogp[:, None, 1] + (1 - positive) * nlogp[:, None, 0]
72 | pos_rate = 4
73 | weights = (positive == 1) * pos_rate + 1
74 | loss = loss * weights
75 |
76 | loss = loss.mean()
77 | return loss
78 |
79 |
80 | def sigmoid_l1_loss(logits, targets, offset=0.0, mask=None):
81 | logp = torch.sigmoid(logits) + offset
82 | loss = torch.abs(logp - targets)
83 |
84 | if mask is not None:
85 | w = mask.mean(3, True).mean(2, True)
86 | w[w == 0] = 1
87 | loss = loss * (mask / w)
88 |
89 | loss = loss.mean() # avg over batch dim
90 | return loss
91 |
92 |
93 | class EdgeCriterion(nn.Module):
94 | def __init__(self):
95 | super().__init__()
96 | self.edge_loss = nn.CrossEntropyLoss(
97 | weight=torch.tensor([0.33, 1.0]).cuda(), reduction="none"
98 | )
99 | self.width_loss = nn.CrossEntropyLoss(reduction="none", ignore_index=0)
100 | self.gamma = 1.0 # used by focal loss (if enabled)
101 |
102 | def forward(
103 | self,
104 | logits_s1,
105 | logits_edge_hybrid,
106 | logits_edge_rel,
107 | logits_width_hybrid,
108 | logits_width_rel,
109 | s2_ids,
110 | s2_edge_mask,
111 | edge_labels,
112 | width_labels,
113 | edge_lengths,
114 | edge_mask,
115 | s2_gt_values,
116 | ):
117 | # ensure batch size of 1 and no padding
118 | assert len(logits_s1) == 1
119 | assert not s2_edge_mask.any()
120 |
121 | # loss for stage-1: edge filtering
122 | s1_losses = self.edge_loss(logits_s1, edge_labels)
123 | s1_losses[torch.where(edge_mask == True)] = 0
124 | s1_losses = s1_losses[torch.where(s1_losses > 0)].sum() / edge_mask.shape[0]
125 | gt_values = torch.ones_like(edge_mask).long() * 2
126 | s1_acc = edge_acc(logits_s1, edge_labels, edge_lengths, gt_values)
127 |
128 | # loss for stage-2
129 | s2_edge_labels = torch.gather(edge_labels, 1, s2_ids)
130 | s2_width_labels = torch.gather(width_labels, 1, s2_ids)
131 |
132 | ### Edge losses ###
133 |
134 | s2_losses_e_hybrid = self.edge_loss(logits_edge_hybrid, s2_edge_labels)
135 | s2_losses_e_hybrid[
136 | torch.where((s2_edge_mask == True) | (s2_gt_values != 2))
137 | ] = 0
138 | # aggregate the loss into the final scalar
139 | s2_losses_e_hybrid = (
140 | s2_losses_e_hybrid[torch.where(s2_losses_e_hybrid > 0)].sum()
141 | / s2_edge_mask.shape[0]
142 | )
143 | s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
144 | # compute edge-level f1-score
145 | s2_acc_e_hybrid = edge_acc(
146 | logits_edge_hybrid, s2_edge_labels, s2_edge_lengths, s2_gt_values
147 | )
148 |
149 | s2_losses_e_rel = self.edge_loss(logits_edge_rel, s2_edge_labels)
150 | s2_losses_e_rel[torch.where((s2_edge_mask == True) | (s2_gt_values != 2))] = 0
151 | # aggregate the loss into the final scalar
152 | s2_losses_e_rel = (
153 | s2_losses_e_rel[torch.where(s2_losses_e_rel > 0)].sum()
154 | / s2_edge_mask.shape[0]
155 | )
156 | s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
157 | # compute edge-level f1-score
158 | s2_acc_e_rel = edge_acc(
159 | logits_edge_rel, s2_edge_labels, s2_edge_lengths, s2_gt_values
160 | )
161 |
162 | ### Width losses ###
163 |
164 | s2_losses_w_hybrid = self.width_loss(logits_width_hybrid, s2_width_labels)
165 | s2_losses_w_hybrid[torch.where(s2_edge_mask == True)] = 0
166 | # aggregate the loss into the final scalar
167 | s2_losses_w_hybrid = (
168 | s2_losses_w_hybrid[torch.where(s2_losses_w_hybrid > 0)].sum()
169 | / s2_edge_mask.shape[0]
170 | )
171 | s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
172 | # compute edge-level f1-score
173 | preds_width_hybrid = logits_width_hybrid.argmax(1)
174 | s2_acc_w_hybrid = (
175 | (
176 | preds_width_hybrid[s2_width_labels != 0]
177 | == s2_width_labels[s2_width_labels != 0]
178 | )
179 | .float()
180 | .mean()
181 | )
182 |
183 | s2_losses_w_rel = self.width_loss(logits_width_rel, s2_width_labels)
184 | s2_losses_w_rel[torch.where(s2_edge_mask == True)] = 0
185 | # aggregate the loss into the final scalar
186 | s2_losses_w_rel = (
187 | s2_losses_w_rel[torch.where(s2_losses_w_rel > 0)].sum()
188 | / s2_edge_mask.shape[0]
189 | )
190 | s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
191 | # compute edge-level f1-score
192 | preds_width_rel = logits_width_rel.argmax(1)
193 | s2_acc_w_rel = (
194 | (
195 | preds_width_rel[s2_width_labels != 0]
196 | == s2_width_labels[s2_width_labels != 0]
197 | )
198 | .float()
199 | .mean()
200 | )
201 |
202 | return (
203 | s1_losses,
204 | s1_acc,
205 | s2_losses_e_hybrid,
206 | s2_acc_e_hybrid,
207 | s2_losses_e_rel,
208 | s2_acc_e_rel,
209 | s2_losses_w_hybrid,
210 | s2_acc_w_hybrid,
211 | s2_losses_w_rel,
212 | s2_acc_w_rel,
213 | )
214 |
--------------------------------------------------------------------------------
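A small sketch (shapes assumed, since the label tensors are produced elsewhere in the pipeline) exercising the junction losses above on dummy per-pixel maps; it only needs `code/learn` on the Python path, as in the training scripts.

```python
# Dummy 2-class junction logits and sub-pixel offset targets over a 64x64 grid.
import torch
from models.loss import cross_entropy_loss_for_junction, sigmoid_l1_loss

logits = torch.randn(2, 2, 64, 64)                       # per-pixel junction logits
positive = torch.randint(0, 2, (2, 1, 64, 64)).float()   # 1 where a corner lies
offsets = torch.rand(2, 2, 64, 64) - 0.5                 # offset targets in [-0.5, 0.5)

loss_jloc = cross_entropy_loss_for_junction(logits, positive)
loss_joff = sigmoid_l1_loss(torch.randn(2, 2, 64, 64), offsets, -0.5, positive)
print(loss_jloc.item(), loss_joff.item())
```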
/code/learn/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class MLP(nn.Module):
6 | """ Very simple multi-layer perceptron (also called FFN)"""
7 |
8 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
9 | super(MLP, self).__init__()
10 | self.output_dim = output_dim
11 | self.num_layers = num_layers
12 | h = [hidden_dim] * (num_layers - 1)
13 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
14 |
15 | def forward(self, x):
16 | B, N, D = x.size()
17 | x = x.reshape(B*N, D)
18 | for i, layer in enumerate(self.layers):
19 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
20 | x = x.view(B, N, self.output_dim)
21 | return x
22 |
--------------------------------------------------------------------------------
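For reference, a minimal usage sketch of the MLP head above, with sizes mirroring the 2-way edge classifiers (`output_fc_*`) in the edge transformer; the numbers are illustrative.

```python
# 2-layer head mapping 256-d edge features to 2-way logits.
import torch
from models.mlp import MLP

head = MLP(input_dim=256, hidden_dim=128, output_dim=2, num_layers=2)
x = torch.randn(1, 8, 256)      # (batch, num_edges, feature_dim)
logits = head(x)                # (1, 8, 2)
print(logits.shape)
```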
/code/learn/models/ops/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | dist/
3 | MultiScale*/
--------------------------------------------------------------------------------
/code/learn/models/ops/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn_func import MSDeformAttnFunction
10 |
11 |
--------------------------------------------------------------------------------
/code/learn/models/ops/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.autograd import Function
16 | from torch.autograd.function import once_differentiable
17 |
18 | import MultiScaleDeformableAttention as MSDA
19 |
20 |
21 | class MSDeformAttnFunction(Function):
22 | @staticmethod
23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24 | ctx.im2col_step = im2col_step
25 | output = MSDA.ms_deform_attn_forward(
26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28 | return output
29 |
30 | @staticmethod
31 | @once_differentiable
32 | def backward(ctx, grad_output):
33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34 | grad_value, grad_sampling_loc, grad_attn_weight = \
35 | MSDA.ms_deform_attn_backward(
36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37 |
38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39 |
40 |
41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42 | # for debug and test only,
43 | # need to use cuda version instead
44 | N_, S_, M_, D_ = value.shape
45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47 | sampling_grids = 2 * sampling_locations - 1
48 | sampling_value_list = []
49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54 | # N_*M_, D_, Lq_, P_
55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56 | mode='bilinear', padding_mode='zeros', align_corners=False)
57 | sampling_value_list.append(sampling_value_l_)
58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61 | return output.transpose(1, 2).contiguous()
62 |
--------------------------------------------------------------------------------
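`ms_deform_attn_core_pytorch` above is the pure-PyTorch reference path kept for debugging. The sketch below feeds it random tensors whose shapes follow the comments in the function; note that importing this module still requires the compiled `MultiScaleDeformableAttention` extension because of the top-level import.

```python
# Sketch of the pure-PyTorch reference path with made-up shapes.
import torch
from models.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

N, M, D = 2, 8, 32                       # batch, heads, channels per head
shapes = [(16, 16), (8, 8)]              # two feature levels (H, W)
S = sum(h * w for h, w in shapes)        # total number of keys
L, Lq, P = len(shapes), 10, 4            # levels, queries, points per level

value = torch.randn(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized to [0, 1]
attention_weights = torch.randn(N, Lq, M, L, P).flatten(-2).softmax(-1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)     # (N, Lq, M * D)
```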
/code/learn/models/ops/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # ------------------------------------------------------------------------------------------------
3 | # Deformable DETR
4 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------------------------------
7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | # ------------------------------------------------------------------------------------------------
9 |
10 | python3 setup.py build install
11 |
--------------------------------------------------------------------------------
/code/learn/models/ops/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn import MSDeformAttn
10 |
--------------------------------------------------------------------------------
/code/learn/models/ops/modules/ms_deform_attn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import warnings
14 | import math
15 |
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 | from torch.nn.init import xavier_uniform_, constant_
20 |
21 | from ..functions import MSDeformAttnFunction
22 |
23 |
24 | def _is_power_of_2(n):
25 | if (not isinstance(n, int)) or (n < 0):
26 | raise ValueError(
27 | "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
28 | )
29 | return (n & (n - 1) == 0) and n != 0
30 |
31 |
32 | class MSDeformAttn(nn.Module):
33 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
34 | """
35 | Multi-Scale Deformable Attention Module
36 | :param d_model hidden dimension
37 | :param n_levels number of feature levels
38 | :param n_heads number of attention heads
39 | :param n_points number of sampling points per attention head per feature level
40 | """
41 | super().__init__()
42 | if d_model % n_heads != 0:
43 | raise ValueError(
44 | "d_model must be divisible by n_heads, but got {} and {}".format(
45 | d_model, n_heads
46 | )
47 | )
48 | _d_per_head = d_model // n_heads
49 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
50 | if not _is_power_of_2(_d_per_head):
51 | warnings.warn(
52 | "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
53 | "which is more efficient in our CUDA implementation."
54 | )
55 |
56 | self.im2col_step = 64
57 |
58 | self.d_model = d_model
59 | self.n_levels = n_levels
60 | self.n_heads = n_heads
61 | self.n_points = n_points
62 |
63 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
64 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
65 | self.value_proj = nn.Linear(d_model, d_model)
66 | self.output_proj = nn.Linear(d_model, d_model)
67 |
68 | self._reset_parameters()
69 |
70 | def _reset_parameters(self):
71 | constant_(self.sampling_offsets.weight.data, 0.0)
72 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
73 | 2.0 * math.pi / self.n_heads
74 | )
75 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
76 | grid_init = (
77 | (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
78 | .view(self.n_heads, 1, 1, 2)
79 | .repeat(1, self.n_levels, self.n_points, 1)
80 | )
81 | for i in range(self.n_points):
82 | grid_init[:, :, i, :] *= i + 1
83 | with torch.no_grad():
84 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
85 | constant_(self.attention_weights.weight.data, 0.0)
86 | constant_(self.attention_weights.bias.data, 0.0)
87 | xavier_uniform_(self.value_proj.weight.data)
88 | constant_(self.value_proj.bias.data, 0.0)
89 | xavier_uniform_(self.output_proj.weight.data)
90 | constant_(self.output_proj.bias.data, 0.0)
91 |
92 | def forward(
93 | self,
94 | query,
95 | reference_points,
96 | input_flatten,
97 | input_spatial_shapes,
98 | input_level_start_index,
99 | input_padding_mask=None,
100 | return_sample_locs=False,
101 | ):
102 | """
103 | :param query (N, Length_{query}, C)
104 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
105 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
106 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
107 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
108 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
109 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
110 |
111 | :return output (N, Length_{query}, C)
112 | """
113 | N, Len_q, _ = query.shape
114 | N, Len_in, _ = input_flatten.shape
115 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
116 |
117 | value = self.value_proj(input_flatten)
118 | if input_padding_mask is not None:
119 | value = value.masked_fill(input_padding_mask[..., None], float(0))
120 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
121 | sampling_offsets = self.sampling_offsets(query).view(
122 | N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
123 | )
124 | attention_weights = self.attention_weights(query).view(
125 | N, Len_q, self.n_heads, self.n_levels * self.n_points
126 | )
127 | attention_weights = F.softmax(attention_weights, -1).view(
128 | N, Len_q, self.n_heads, self.n_levels, self.n_points
129 | )
130 | # N, Len_q, n_heads, n_levels, n_points, 2
131 | if reference_points.shape[-1] == 2:
132 | offset_normalizer = torch.stack(
133 | [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
134 | )
135 | sampling_locations = (
136 | reference_points[:, :, None, :, None, :]
137 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
138 | )
139 | elif reference_points.shape[-1] == 4:
140 | sampling_locations = (
141 | reference_points[:, :, None, :, None, :2]
142 | + sampling_offsets
143 | / self.n_points
144 | * reference_points[:, :, None, :, None, 2:]
145 | * 0.5
146 | )
147 | else:
148 | raise ValueError(
149 | "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
150 | reference_points.shape[-1]
151 | )
152 | )
153 | output = MSDeformAttnFunction.apply(
154 | value,
155 | input_spatial_shapes,
156 | input_level_start_index,
157 | sampling_locations,
158 | attention_weights,
159 | self.im2col_step,
160 | )
161 | output = self.output_proj(output)
162 |
163 | if return_sample_locs:
164 | return output, sampling_locations.detach().cpu()
165 | else:
166 | return output
167 |
--------------------------------------------------------------------------------
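For orientation, a hypothetical call into `MSDeformAttn` with shapes taken from the `forward()` docstring; it assumes the ops package has been built and a GPU is present.

```python
# Illustrative MSDeformAttn call over two feature levels (16x16 and 8x8).
import torch
from models.ops.modules import MSDeformAttn

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).cuda()

shapes = torch.as_tensor([[16, 16], [8, 8]], device="cuda")
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
Len_in = int(shapes.prod(1).sum())

query = torch.randn(1, 20, 256, device="cuda")
reference_points = torch.rand(1, 20, 2, 2, device="cuda")   # (N, Len_q, n_levels, 2)
input_flatten = torch.randn(1, Len_in, 256, device="cuda")

out = attn(query, reference_points, input_flatten, shapes, level_start_index)
print(out.shape)   # (1, 20, 256)
```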
/code/learn/models/ops/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | import os
10 | import glob
11 |
12 | import torch
13 |
14 | from torch.utils.cpp_extension import CUDA_HOME
15 | from torch.utils.cpp_extension import CppExtension
16 | from torch.utils.cpp_extension import CUDAExtension
17 |
18 | from setuptools import find_packages
19 | from setuptools import setup
20 |
21 | requirements = ["torch", "torchvision"]
22 |
23 | def get_extensions():
24 | this_dir = os.path.dirname(os.path.abspath(__file__))
25 | extensions_dir = os.path.join(this_dir, "src")
26 |
27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
30 |
31 | sources = main_file + source_cpu
32 | extension = CppExtension
33 | extra_compile_args = {"cxx": []}
34 | define_macros = []
35 |
36 | if torch.cuda.is_available() and CUDA_HOME is not None:
37 | extension = CUDAExtension
38 | sources += source_cuda
39 | define_macros += [("WITH_CUDA", None)]
40 | extra_compile_args["nvcc"] = [
41 | "-DCUDA_HAS_FP16=1",
42 | "-D__CUDA_NO_HALF_OPERATORS__",
43 | "-D__CUDA_NO_HALF_CONVERSIONS__",
44 | "-D__CUDA_NO_HALF2_OPERATORS__",
45 | ]
46 | else:
47 | raise NotImplementedError('CUDA is not available')
48 |
49 | sources = [os.path.join(extensions_dir, s) for s in sources]
50 | include_dirs = [extensions_dir]
51 | ext_modules = [
52 | extension(
53 | "MultiScaleDeformableAttention",
54 | sources,
55 | include_dirs=include_dirs,
56 | define_macros=define_macros,
57 | extra_compile_args=extra_compile_args,
58 | )
59 | ]
60 | return ext_modules
61 |
62 | setup(
63 | name="MultiScaleDeformableAttention",
64 | version="1.0",
65 | author="Weijie Su",
66 | url="https://github.com/fundamentalvision/Deformable-DETR",
67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
68 | packages=find_packages(exclude=("configs", "tests",)),
69 | ext_modules=get_extensions(),
70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
71 | )
72 |
--------------------------------------------------------------------------------
/code/learn/models/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 |
13 | #include <ATen/ATen.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 |
16 |
17 | at::Tensor
18 | ms_deform_attn_cpu_forward(
19 | const at::Tensor &value,
20 | const at::Tensor &spatial_shapes,
21 | const at::Tensor &level_start_index,
22 | const at::Tensor &sampling_loc,
23 | const at::Tensor &attn_weight,
24 | const int im2col_step)
25 | {
26 | AT_ERROR("Not implemented on CPU");
27 | }
28 |
29 | std::vector<at::Tensor>
30 | ms_deform_attn_cpu_backward(
31 | const at::Tensor &value,
32 | const at::Tensor &spatial_shapes,
33 | const at::Tensor &level_start_index,
34 | const at::Tensor &sampling_loc,
35 | const at::Tensor &attn_weight,
36 | const at::Tensor &grad_output,
37 | const int im2col_step)
38 | {
39 | AT_ERROR("Not implemented on CPU");
40 | }
41 |
42 |
--------------------------------------------------------------------------------
/code/learn/models/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor
15 | ms_deform_attn_cpu_forward(
16 | const at::Tensor &value,
17 | const at::Tensor &spatial_shapes,
18 | const at::Tensor &level_start_index,
19 | const at::Tensor &sampling_loc,
20 | const at::Tensor &attn_weight,
21 | const int im2col_step);
22 |
23 | std::vector<at::Tensor>
24 | ms_deform_attn_cpu_backward(
25 | const at::Tensor &value,
26 | const at::Tensor &spatial_shapes,
27 | const at::Tensor &level_start_index,
28 | const at::Tensor &sampling_loc,
29 | const at::Tensor &attn_weight,
30 | const at::Tensor &grad_output,
31 | const int im2col_step);
32 |
33 |
34 |
--------------------------------------------------------------------------------
/code/learn/models/ops/src/cuda/ms_deform_attn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 | #include "cuda/ms_deform_im2col_cuda.cuh"
13 |
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 |
19 |
20 | at::Tensor ms_deform_attn_cuda_forward(
21 | const at::Tensor &value,
22 | const at::Tensor &spatial_shapes,
23 | const at::Tensor &level_start_index,
24 | const at::Tensor &sampling_loc,
25 | const at::Tensor &attn_weight,
26 | const int im2col_step)
27 | {
28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33 |
34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39 |
40 | const int batch = value.size(0);
41 | const int spatial_size = value.size(1);
42 | const int num_heads = value.size(2);
43 | const int channels = value.size(3);
44 |
45 | const int num_levels = spatial_shapes.size(0);
46 |
47 | const int num_query = sampling_loc.size(1);
48 | const int num_point = sampling_loc.size(4);
49 |
50 | const int im2col_step_ = std::min(batch, im2col_step);
51 |
52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53 |
54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55 |
56 | const int batch_n = im2col_step_;
57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58 | auto per_value_size = spatial_size * num_heads * channels;
59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61 | for (int n = 0; n < batch/im2col_step_; ++n)
62 | {
63 | auto columns = output_n.select(0, n);
64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66 |             value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67 |             spatial_shapes.data<int64_t>(),
68 |             level_start_index.data<int64_t>(),
69 |             sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70 |             attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71 |             batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72 |             columns.data<scalar_t>());
73 |
74 | }));
75 | }
76 |
77 | output = output.view({batch, num_query, num_heads*channels});
78 |
79 | return output;
80 | }
81 |
82 |
83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84 | const at::Tensor &value,
85 | const at::Tensor &spatial_shapes,
86 | const at::Tensor &level_start_index,
87 | const at::Tensor &sampling_loc,
88 | const at::Tensor &attn_weight,
89 | const at::Tensor &grad_output,
90 | const int im2col_step)
91 | {
92 |
93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99 |
100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106 |
107 | const int batch = value.size(0);
108 | const int spatial_size = value.size(1);
109 | const int num_heads = value.size(2);
110 | const int channels = value.size(3);
111 |
112 | const int num_levels = spatial_shapes.size(0);
113 |
114 | const int num_query = sampling_loc.size(1);
115 | const int num_point = sampling_loc.size(4);
116 |
117 | const int im2col_step_ = std::min(batch, im2col_step);
118 |
119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120 |
121 | auto grad_value = at::zeros_like(value);
122 | auto grad_sampling_loc = at::zeros_like(sampling_loc);
123 | auto grad_attn_weight = at::zeros_like(attn_weight);
124 |
125 | const int batch_n = im2col_step_;
126 | auto per_value_size = spatial_size * num_heads * channels;
127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130 |
131 | for (int n = 0; n < batch/im2col_step_; ++n)
132 | {
133 | auto grad_output_g = grad_output_n.select(0, n);
134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136 |             grad_output_g.data<scalar_t>(),
137 |             value.data<scalar_t>() + n * im2col_step_ * per_value_size,
138 |             spatial_shapes.data<int64_t>(),
139 |             level_start_index.data<int64_t>(),
140 |             sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141 |             attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142 |             batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143 |             grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
144 |             grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145 |             grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146 |
147 | }));
148 | }
149 |
150 | return {
151 | grad_value, grad_sampling_loc, grad_attn_weight
152 | };
153 | }
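
Both CUDA entry points above clamp `im2col_step` to the batch size and then require the batch to be evenly divisible by it, processing `im2col_step` samples per kernel launch. A minimal sketch of picking a valid step on the Python side (hypothetical helper, not part of the repo, which passes `im2col_step` through from the caller):

```python
def valid_im2col_step(batch_size: int, requested: int = 64) -> int:
    """Largest step <= requested that evenly divides the batch size.

    Mirrors the clamp and divisibility assert in ms_deform_attn_cuda_forward/backward.
    """
    step = min(batch_size, requested)
    while batch_size % step != 0:
        step -= 1
    return step

assert valid_im2col_step(12, 64) == 12  # whole batch in one launch
assert valid_im2col_step(12, 8) == 6    # 8 does not divide 12, fall back to 6
```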
--------------------------------------------------------------------------------
/code/learn/models/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor ms_deform_attn_cuda_forward(
15 | const at::Tensor &value,
16 | const at::Tensor &spatial_shapes,
17 | const at::Tensor &level_start_index,
18 | const at::Tensor &sampling_loc,
19 | const at::Tensor &attn_weight,
20 | const int im2col_step);
21 |
22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23 | const at::Tensor &value,
24 | const at::Tensor &spatial_shapes,
25 | const at::Tensor &level_start_index,
26 | const at::Tensor &sampling_loc,
27 | const at::Tensor &attn_weight,
28 | const at::Tensor &grad_output,
29 | const int im2col_step);
30 |
31 |
--------------------------------------------------------------------------------
/code/learn/models/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 |
13 | #include "cpu/ms_deform_attn_cpu.h"
14 |
15 | #ifdef WITH_CUDA
16 | #include "cuda/ms_deform_attn_cuda.h"
17 | #endif
18 |
19 |
20 | at::Tensor
21 | ms_deform_attn_forward(
22 | const at::Tensor &value,
23 | const at::Tensor &spatial_shapes,
24 | const at::Tensor &level_start_index,
25 | const at::Tensor &sampling_loc,
26 | const at::Tensor &attn_weight,
27 | const int im2col_step)
28 | {
29 | if (value.type().is_cuda())
30 | {
31 | #ifdef WITH_CUDA
32 | return ms_deform_attn_cuda_forward(
33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | AT_ERROR("Not implemented on the CPU");
39 | }
40 |
41 | std::vector<at::Tensor>
42 | ms_deform_attn_backward(
43 | const at::Tensor &value,
44 | const at::Tensor &spatial_shapes,
45 | const at::Tensor &level_start_index,
46 | const at::Tensor &sampling_loc,
47 | const at::Tensor &attn_weight,
48 | const at::Tensor &grad_output,
49 | const int im2col_step)
50 | {
51 | if (value.type().is_cuda())
52 | {
53 | #ifdef WITH_CUDA
54 | return ms_deform_attn_cuda_backward(
55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56 | #else
57 | AT_ERROR("Not compiled with GPU support");
58 | #endif
59 | }
60 | AT_ERROR("Not implemented on the CPU");
61 | }
62 |
63 |
--------------------------------------------------------------------------------
/code/learn/models/ops/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include "ms_deform_attn.h"
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
16 | }
17 |
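
`vision.cpp` only exports the two C++ entry points; on the Python side they are normally wrapped in a `torch.autograd.Function` so that the CUDA kernels participate in autograd (the repo does this in `functions/ms_deform_attn_func.py`). A minimal forward-only sketch, assuming the extension is built under the name `MultiScaleDeformableAttention` (the usual name in the upstream Deformable DETR `setup.py`):

```python
from torch.autograd import Function

import MultiScaleDeformableAttention as MSDA  # module name assumed from the upstream setup.py


class MSDeformAttnForwardOnly(Function):
    """Forward-only sketch; the repo's own wrapper also implements backward
    via MSDA.ms_deform_attn_backward."""

    @staticmethod
    def forward(ctx, value, spatial_shapes, level_start_index,
                sampling_loc, attn_weight, im2col_step):
        # argument order matches the pybind11 export in vision.cpp
        return MSDA.ms_deform_attn_forward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, im2col_step)
```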
--------------------------------------------------------------------------------
/code/learn/models/ops/test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import time
14 | import torch
15 | import torch.nn as nn
16 | from torch.autograd import gradcheck
17 |
18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
19 |
20 |
21 | N, M, D = 1, 2, 2
22 | Lq, L, P = 2, 2, 2
23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
25 | S = sum([(H*W).item() for H, W in shapes])
26 |
27 |
28 | torch.manual_seed(3)
29 |
30 |
31 | @torch.no_grad()
32 | def check_forward_equal_with_pytorch_double():
33 | value = torch.rand(N, S, M, D).cuda() * 0.01
34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
37 | im2col_step = 2
38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
40 | fwdok = torch.allclose(output_cuda, output_pytorch)
41 | max_abs_err = (output_cuda - output_pytorch).abs().max()
42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
43 |
44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
45 |
46 |
47 | @torch.no_grad()
48 | def check_forward_equal_with_pytorch_float():
49 | value = torch.rand(N, S, M, D).cuda() * 0.01
50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
53 | im2col_step = 2
54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
57 | max_abs_err = (output_cuda - output_pytorch).abs().max()
58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
59 |
60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
61 |
62 |
63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
64 |
65 | value = torch.rand(N, S, M, channels).cuda() * 0.01
66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
69 | im2col_step = 2
70 | func = MSDeformAttnFunction.apply
71 |
72 | value.requires_grad = grad_value
73 | sampling_locations.requires_grad = grad_sampling_loc
74 | attention_weights.requires_grad = grad_attn_weight
75 |
76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
77 |
78 | print(f'* {gradok} check_gradient_numerical(D={channels})')
79 |
80 |
81 | if __name__ == '__main__':
82 | check_forward_equal_with_pytorch_double()
83 | check_forward_equal_with_pytorch_float()
84 |
85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
86 | check_gradient_numerical(channels, True, True, True)
87 |
88 |
89 |
90 |
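
The single-letter constants at the top of `test.py` are the shape convention used throughout the ops: `N` batch, `M` attention heads, `D` channels per head, `Lq` queries, `L` feature levels, `P` sampling points, and `S` the total number of flattened spatial positions over all levels. Summarized as a comment-only reference (not part of the repo):

```python
# Shape reference for the multi-scale deformable attention inputs:
#   value:              (N, S, M, D)          flattened multi-level feature map
#   spatial_shapes:     (L, 2)                (H_l, W_l) per level, S = sum(H_l * W_l)
#   level_start_index:  (L,)                  offset of each level inside S
#   sampling_locations: (N, Lq, M, L, P, 2)   normalized to [0, 1]
#   attention_weights:  (N, Lq, M, L, P)      sums to 1 over the last two dims
#   output:             (N, Lq, M * D)
```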
--------------------------------------------------------------------------------
/code/learn/models/order_class_models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from json import encoder
3 | from sys import flags
4 | import cv2
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | from tqdm import tqdm
9 | from models.mlp import MLP
10 |
11 | from models.deformable_transformer_full import (
12 | DeformableTransformerEncoderLayer,
13 | DeformableTransformerEncoder,
14 | DeformableTransformerDecoder,
15 | DeformableTransformerDecoderLayer,
16 | DeformableAttnDecoderLayer,
17 | )
18 | from models.ops.modules import MSDeformAttn
19 | from models.corner_models import PositionEmbeddingSine
20 | from models.unet import ResNetBackbone
21 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
22 | import torch.nn.functional as F
23 |
24 | # from utils.misc import NestedTensor
25 | from models.utils import pos_encode_1d, pos_encode_2d
26 |
27 | from shapely.geometry import Point, LineString, box, MultiLineString
28 | from shapely import affinity
29 |
30 |
31 | def unnormalize_edges(edges, normalize_param):
32 | lines = [((x0, y0), (x1, y1)) for (x0, y0, x1, y1) in edges]
33 | lines = MultiLineString(lines)
34 |
35 |     # undo the scaling (edges were normalized so the longest is 1000, keeping aspect ratio)
36 | (xfact, yfact) = normalize_param["scale"]
37 | xfact = 1 / xfact
38 | yfact = 1 / yfact
39 | lines = affinity.scale(lines, xfact=xfact, yfact=yfact, origin=(0, 0))
40 |
41 | # center edges around 0
42 | (xoff, yoff) = normalize_param["translation"]
43 | lines = affinity.translate(lines, xoff=-xoff, yoff=-yoff)
44 |
45 | # rotation
46 | if "rotation" in normalize_param.keys():
47 | angle = normalize_param["rotation"]
48 | lines = affinity.rotate(lines, -angle, origin=(0, 0))
49 |
50 | new_coords = [list(line.coords) for line in lines.geoms]
51 | new_coords = np.array(new_coords).reshape(-1, 4)
52 |
53 | return new_coords
54 |
55 |
56 | def vis_data(data):
57 | import matplotlib.pyplot as plt
58 |
59 | bs = len(data["floor_name"])
60 |
61 | for b_i in range(bs):
62 | edge_coords = data["edge_coords"][b_i].cpu().numpy()
63 | edge_mask = data["edge_coords_mask"][b_i].cpu().numpy()
64 | edge_order = data["edge_order"][b_i].cpu().numpy()
65 | label = data["label"][b_i].cpu().numpy()
66 |
67 | max_order = edge_order.max()
68 |
69 | for edge_i, (x0, y0, x1, y1) in enumerate(edge_coords):
70 | if not edge_mask[edge_i]:
71 | if edge_order[edge_i] == max_order:
72 | plt.plot([x0, x1], [y0, y1], "-or")
73 | else:
74 | plt.plot([x0, x1], [y0, y1], "-oy")
75 | plt.text(
76 | (x0 + x1) / 2, (y0 + y1) / 2, str(edge_order[edge_i]), color="c"
77 | )
78 |
79 | plt.title("%d" % label)
80 | plt.tight_layout()
81 | plt.show()
82 | plt.close()
83 |
84 |
85 | class EdgeTransformer(nn.Module):
86 | def __init__(
87 | self,
88 | d_model=512,
89 | nhead=8,
90 | num_encoder_layers=6,
91 | num_decoder_layers=6,
92 | dim_feedforward=1024,
93 | dropout=0.1,
94 | activation="relu",
95 | return_intermediate_dec=False,
96 | num_feature_levels=4,
97 | dec_n_points=4,
98 | enc_n_points=4,
99 | ):
100 | super(EdgeTransformer, self).__init__()
101 | self.d_model = d_model
102 |
103 | decoder_1_layer = DeformableTransformerDecoderLayer(
104 | d_model,
105 | dim_feedforward,
106 | dropout,
107 | activation,
108 | num_feature_levels,
109 | nhead,
110 | dec_n_points,
111 | )
112 | self.decoder_1 = DeformableTransformerDecoder(
113 | decoder_1_layer, num_decoder_layers, return_intermediate_dec, with_sa=True
114 | )
115 |
116 | self.type_embed = nn.Embedding(3, 128)
117 | self.input_head = nn.Linear(128 * 4, d_model)
118 |
119 | # self.final_head = MLP(
120 | # input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2
121 | # )
122 | self.final_head = nn.Linear(d_model, 2)
123 |
124 | self._reset_parameters()
125 |
126 | def _reset_parameters(self):
127 | for p in self.parameters():
128 | if p.dim() > 1:
129 | nn.init.xavier_uniform_(p)
130 | for m in self.modules():
131 | if isinstance(m, MSDeformAttn):
132 | m._reset_parameters()
133 |
134 | def get_geom_feats(self, coords):
135 | (bs, N, _) = coords.shape
136 | _coords = coords.reshape(bs * N, -1)
137 | _geom_enc_a = pos_encode_2d(x=_coords[:, 0], y=_coords[:, 1])
138 | _geom_enc_b = pos_encode_2d(x=_coords[:, 2], y=_coords[:, 3])
139 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1)
140 | geom_feats = _geom_feats.reshape(bs, N, -1)
141 |
142 | return geom_feats
143 |
144 | def forward(self, data):
145 | # vis_data(data)
146 |
147 | # obtain edge positional features
148 | edge_coords = data["edge_coords"]
149 | edge_mask = data["edge_coords_mask"]
150 | edge_order = data["edge_order"]
151 |
152 | dtype = edge_coords.dtype
153 | device = edge_coords.device
154 |
155 | # three types of nodes
156 | # 1. dummy node
157 | # 2. modelled node
158 | # 3. sequence node
159 |
160 | bs = len(edge_coords)
161 |
162 | # geom
163 | geom_feats = self.get_geom_feats(edge_coords)
164 | dummy_geom = torch.zeros(
165 | [bs, 1, geom_feats.shape[-1]], dtype=dtype, device=device
166 | )
167 | geom_feats = torch.cat([dummy_geom, geom_feats], dim=1)
168 |
169 | # node type (NEED TO BE BEFORE ORDER)
170 | max_order = 10
171 | node_type = torch.zeros_like(edge_order)
172 | node_type[edge_order < max_order] = 1
173 | node_type[edge_order == max_order] = 2
174 | assert not (node_type == 0).any()
175 |
176 | dummy_type = torch.zeros([bs, 1], dtype=dtype, device=device)
177 | node_type = torch.cat([dummy_type, node_type], dim=1)
178 | type_feats = self.type_embed(node_type.long())
179 |
180 | # order
181 | dummy_order = torch.zeros([bs, 1], dtype=dtype, device=device)
182 | edge_order = torch.cat([dummy_order, edge_order], dim=1)
183 | order_feats = pos_encode_1d(edge_order.flatten())
184 | order_feats = order_feats.reshape(edge_order.shape[0], edge_order.shape[1], -1)
185 |
186 | # need to pad the mask
187 | dummy_mask = torch.zeros([bs, 1], dtype=edge_mask.dtype, device=device)
188 | edge_mask = torch.cat([dummy_mask, edge_mask], dim=-1)
189 |
190 | # combine features
191 | edge_feats = torch.cat([geom_feats, type_feats, order_feats], dim=-1)
192 | edge_feats = self.input_head(edge_feats)
193 |
194 | # first do self-attention without any flags
195 | hs = self.decoder_1(
196 | edge_feats=edge_feats,
197 | geom_feats=geom_feats,
198 | image_feats=None,
199 | key_padding_mask=edge_mask,
200 | get_image_feat=False,
201 | )
202 |
203 | hs = self.final_head(hs)
204 |
205 | # return only the dummy node
206 | return hs[:, 0]
207 |
208 |
209 | class EdgeTransformer2(nn.Module):
210 | def __init__(
211 | self,
212 | d_model=256,
213 | nhead=8,
214 | num_encoder_layers=3,
215 | num_decoder_layers=3,
216 | dim_feedforward=1024,
217 | dropout=0.1,
218 | activation="relu",
219 | return_intermediate_dec=False,
220 | num_feature_levels=4,
221 | dec_n_points=4,
222 | enc_n_points=4,
223 | ):
224 | super(EdgeTransformer2, self).__init__()
225 | self.d_model = d_model
226 |
227 | encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
228 | self.encoder = nn.TransformerEncoder(
229 | encoder_layer, num_layers=num_encoder_layers
230 | )
231 |
232 | decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
233 | self.decoder = nn.TransformerDecoder(
234 | decoder_layer, num_layers=num_decoder_layers
235 | )
236 |
237 | self.type_embed = nn.Embedding(3, 128)
238 | self.input_head = nn.Linear(128 * 3 + 2, d_model)
239 |
240 | # self.final_head = MLP(
241 | # input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2
242 | # )
243 | self.final_head = nn.Linear(d_model, 2)
244 |
245 | self._reset_parameters()
246 |
247 | def _reset_parameters(self):
248 | for p in self.parameters():
249 | if p.dim() > 1:
250 | nn.init.xavier_uniform_(p)
251 | # for m in self.modules():
252 | # if isinstance(m, MSDeformAttn):
253 | # m._reset_parameters()
254 |
255 | def forward(self, data):
256 | # vis_data(data)
257 |
258 | # obtain edge positional features
259 | edge_coords = data["edge_coords"]
260 | edge_mask = data["edge_coords_mask"]
261 | modelled_coords = data["modelled_coords"]
262 | modelled_mask = data["modelled_coords_mask"]
263 |
264 | dtype = edge_coords.dtype
265 | device = edge_coords.device
266 |
267 | # first self-attention among the modelled edges
268 | (bs, N, _) = modelled_coords.shape
269 | _modelled_coords = modelled_coords.reshape(bs * N, -1)
270 | _geom_enc_a = pos_encode_2d(x=_modelled_coords[:, 0], y=_modelled_coords[:, 1])
271 | _geom_enc_b = pos_encode_2d(x=_modelled_coords[:, 2], y=_modelled_coords[:, 3])
272 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1)
273 | modelled_geom = _geom_feats.reshape(bs, N, -1)
274 | modelled_geom = modelled_geom.permute([1, 0, 2]) # bs,S,E -> S,bs,E
275 |
276 | memory = self.encoder(modelled_geom, src_key_padding_mask=modelled_mask)
277 |
278 | # then SA and CA among the sequence edges
279 | (bs, N, _) = edge_coords.shape
280 | _edge_coords = edge_coords.reshape(bs * N, -1)
281 | _geom_enc_a = pos_encode_2d(x=_edge_coords[:, 0], y=_edge_coords[:, 1])
282 | _geom_enc_b = pos_encode_2d(x=_edge_coords[:, 2], y=_edge_coords[:, 3])
283 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1)
284 | seq_geom = _geom_feats.reshape(bs, N, -1)
285 | seq_geom = seq_geom.permute([1, 0, 2]) # bs,S,E -> S,bs,E
286 |
287 | dummy_geom = torch.zeros(
288 | [1, bs, seq_geom.shape[-1]], dtype=dtype, device=device
289 | )
290 | seq_geom = torch.cat([dummy_geom, seq_geom], dim=0)
291 |
292 | # for edge positions, dummy node is 0, modelled edges are 0, sequence start at 1
293 | pos_inds = torch.arange(N + 1, dtype=dtype, device=device)
294 | pos_feats = pos_encode_1d(pos_inds)
295 | pos_feats = pos_feats.unsqueeze(1).repeat(1, bs, 1)
296 |
297 |         # edge-type flag: dummy node is 0 and every real edge is 1 (one-hot over two classes)
298 | flag_feats = torch.zeros([N + 1, bs], dtype=torch.int64, device=device)
299 | flag_feats[1:, :] = 1
300 | flag_feats = nn.functional.one_hot(flag_feats, num_classes=2)
301 |
302 | edge_feats = torch.cat([seq_geom, pos_feats, flag_feats], dim=-1)
303 | edge_feats = self.input_head(edge_feats)
304 |
305 | # need to pad the mask
306 | dummy_mask = torch.zeros([bs, 1], dtype=edge_mask.dtype, device=device)
307 | edge_mask = torch.cat([dummy_mask, edge_mask], dim=-1)
308 |
309 | # first do self-attention without any flags
310 | hs = self.decoder(
311 | tgt=edge_feats,
312 | memory=memory,
313 | tgt_key_padding_mask=edge_mask,
314 | memory_key_padding_mask=modelled_mask,
315 | )
316 | return self.final_head(hs[0])
317 |
318 | # hs = self.final_head(hs)
319 |
320 | # # return only the dummy node
321 | # return hs[:, 0]
322 |
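
`EdgeTransformer` consumes a batch dictionary rather than raw tensors; the keys read in `forward` are `edge_coords`, `edge_coords_mask`, and `edge_order`, and it emits a single 2-way logit per sample from the prepended dummy node. A minimal sketch of the expected batch layout, with shapes inferred from the code above (running the model itself additionally requires the compiled deformable-attention ops and a GPU):

```python
import torch

bs, num_edges = 2, 50
data = {
    # endpoints (x0, y0, x1, y1) of every candidate edge, in image coordinates
    "edge_coords": torch.rand(bs, num_edges, 4) * 255,
    # True where an entry is padding and should be ignored by attention
    "edge_coords_mask": torch.zeros(bs, num_edges, dtype=torch.bool),
    # position in the modelling sequence: already-modelled edges use values below
    # the max order (10), edges of the current sequence carry the max order
    "edge_order": torch.randint(1, 11, (bs, num_edges)).float(),
    # binary target read by the training loop, not by the model itself
    "label": torch.randint(0, 2, (bs,)),
}
# logits = EdgeTransformer()(data)   # -> (bs, 2) logits from the dummy node
```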
--------------------------------------------------------------------------------
/code/learn/models/order_metric_models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from json import encoder
3 | from sys import flags
4 | import cv2
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | from tqdm import tqdm
9 | from models.mlp import MLP
10 |
11 | from models.deformable_transformer_full import (
12 | DeformableTransformerEncoderLayer,
13 | DeformableTransformerEncoder,
14 | DeformableTransformerDecoder,
15 | DeformableTransformerDecoderLayer,
16 | DeformableAttnDecoderLayer,
17 | )
18 | from models.ops.modules import MSDeformAttn
19 | from models.corner_models import PositionEmbeddingSine
20 | from models.unet import ResNetBackbone
21 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
22 | import torch.nn.functional as F
23 |
24 | # from utils.misc import NestedTensor
25 | from models.utils import pos_encode_1d, pos_encode_2d
26 |
27 | from shapely.geometry import Point, LineString, box, MultiLineString
28 | from shapely import affinity
29 |
30 |
31 | def unnormalize_edges(edges, normalize_param):
32 | lines = [((x0, y0), (x1, y1)) for (x0, y0, x1, y1) in edges]
33 | lines = MultiLineString(lines)
34 |
35 |     # undo the scaling (edges were normalized so the longest is 1000, keeping aspect ratio)
36 | (xfact, yfact) = normalize_param["scale"]
37 | xfact = 1 / xfact
38 | yfact = 1 / yfact
39 | lines = affinity.scale(lines, xfact=xfact, yfact=yfact, origin=(0, 0))
40 |
41 | # center edges around 0
42 | (xoff, yoff) = normalize_param["translation"]
43 | lines = affinity.translate(lines, xoff=-xoff, yoff=-yoff)
44 |
45 | # rotation
46 | if "rotation" in normalize_param.keys():
47 | angle = normalize_param["rotation"]
48 | lines = affinity.rotate(lines, -angle, origin=(0, 0))
49 |
50 | new_coords = [list(line.coords) for line in lines.geoms]
51 | new_coords = np.array(new_coords).reshape(-1, 4)
52 |
53 | return new_coords
54 |
55 |
56 | def vis_data(data):
57 | import matplotlib.pyplot as plt
58 |
59 | bs = len(data["floor_name"])
60 |
61 | for b_i in range(bs):
62 | edge_coords = data["edge_coords"][b_i].cpu().numpy()
63 | edge_mask = data["edge_coords_mask"][b_i].cpu().numpy()
64 | edge_order = data["edge_order"][b_i].cpu().numpy()
65 | label = data["label"][b_i].cpu().numpy()
66 |
67 | for edge_i, (x0, y0, x1, y1) in enumerate(edge_coords):
68 | if not edge_mask[edge_i]:
69 | if label[edge_i] == 0:
70 | assert edge_order[edge_i] == 0
71 | plt.plot([x0, x1], [y0, y1], "-or")
72 | elif label[edge_i] == 3:
73 | assert edge_order[edge_i] > 0
74 | plt.plot([x0, x1], [y0, y1], "--oc")
75 | plt.text(
76 | (x0 + x1) / 2, (y0 + y1) / 2, str(edge_order[edge_i]), color="c"
77 | )
78 | elif label[edge_i] == 1:
79 | assert edge_order[edge_i] == 1
80 | plt.plot([x0, x1], [y0, y1], "-oy")
81 | plt.text(
82 | (x0 + x1) / 2, (y0 + y1) / 2, str(edge_order[edge_i]), color="c"
83 | )
84 | elif label[edge_i] == 2:
85 | assert edge_order[edge_i] == 0
86 | plt.plot([x0, x1], [y0, y1], "-og")
87 | plt.text(
88 | (x0 + x1) / 2, (y0 + y1) / 2, str(edge_order[edge_i]), color="c"
89 | )
90 | else:
91 | raise Exception
92 |
93 | # plt.title("%d" % label)
94 | plt.tight_layout()
95 | plt.show()
96 | plt.close()
97 |
98 |
99 | class EdgeTransformer(nn.Module):
100 | def __init__(
101 | self,
102 | d_model=256,
103 | nhead=8,
104 | num_encoder_layers=6,
105 | num_decoder_layers=6,
106 | dim_feedforward=1024,
107 | dropout=0.1,
108 | activation="relu",
109 | return_intermediate_dec=False,
110 | num_feature_levels=4,
111 | dec_n_points=4,
112 | enc_n_points=4,
113 | ):
114 | super(EdgeTransformer, self).__init__()
115 | self.d_model = d_model
116 |
117 | decoder_1_layer = DeformableTransformerDecoderLayer(
118 | d_model,
119 | dim_feedforward,
120 | dropout,
121 | activation,
122 | num_feature_levels,
123 | nhead,
124 | dec_n_points,
125 | )
126 | self.decoder_1 = DeformableTransformerDecoder(
127 | decoder_1_layer, num_decoder_layers, return_intermediate_dec, with_sa=True
128 | )
129 |
130 | self.type_embed = nn.Embedding(3, 128)
131 | self.input_head = nn.Linear(128 * 4, d_model)
132 |
133 | # self.final_head = MLP(
134 | # input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2
135 | # )
136 | self.final_head = nn.Linear(d_model, d_model)
137 |
138 | self._reset_parameters()
139 |
140 | def _reset_parameters(self):
141 | for p in self.parameters():
142 | if p.dim() > 1:
143 | nn.init.xavier_uniform_(p)
144 | for m in self.modules():
145 | if isinstance(m, MSDeformAttn):
146 | m._reset_parameters()
147 |
148 | def get_geom_feats(self, coords):
149 | (bs, N, _) = coords.shape
150 | _coords = coords.reshape(bs * N, -1)
151 | _geom_enc_a = pos_encode_2d(x=_coords[:, 0], y=_coords[:, 1])
152 | _geom_enc_b = pos_encode_2d(x=_coords[:, 2], y=_coords[:, 3])
153 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1)
154 | geom_feats = _geom_feats.reshape(bs, N, -1)
155 |
156 | return geom_feats
157 |
158 | def forward(self, data):
159 | # vis_data(data)
160 |
161 | # obtain edge positional features
162 | edge_coords = data["edge_coords"]
163 | edge_mask = data["edge_coords_mask"]
164 | edge_order = data["edge_order"]
165 |
166 | dtype = edge_coords.dtype
167 | device = edge_coords.device
168 |
169 | # three types of nodes
170 | # 1. dummy node
171 | # 2. modelled node
172 | # 3. sequence node
173 |
174 | bs = len(edge_coords)
175 |
176 | # geom
177 | geom_feats = self.get_geom_feats(edge_coords)
178 |
179 | # node type (NEED TO BE BEFORE ORDER)
180 | max_order = 10
181 | node_type = torch.ones_like(edge_order)
182 | node_type[edge_order == 0] = 0
183 | node_type[edge_order == max_order] = 2
184 | type_feats = self.type_embed(node_type.long())
185 |
186 | # order
187 | order_feats = pos_encode_1d(edge_order.flatten())
188 | order_feats = order_feats.reshape(edge_order.shape[0], edge_order.shape[1], -1)
189 |
190 |         # combine features (no dummy node is added in this variant)
191 | edge_feats = torch.cat([geom_feats, type_feats, order_feats], dim=-1)
192 | edge_feats = self.input_head(edge_feats)
193 |
194 | # first do self-attention without any flags
195 | hs = self.decoder_1(
196 | edge_feats=edge_feats,
197 | geom_feats=geom_feats,
198 | image_feats=None,
199 | key_padding_mask=edge_mask,
200 | get_image_feat=False,
201 | )
202 |
203 | hs = self.final_head(hs)
204 |
205 |         # return an embedding for every edge
206 | return hs
207 |
208 |
209 | class EdgeTransformer2(nn.Module):
210 | def __init__(
211 | self,
212 | d_model=256,
213 | nhead=8,
214 | num_encoder_layers=3,
215 | num_decoder_layers=3,
216 | dim_feedforward=1024,
217 | dropout=0.1,
218 | activation="relu",
219 | return_intermediate_dec=False,
220 | num_feature_levels=4,
221 | dec_n_points=4,
222 | enc_n_points=4,
223 | ):
224 | super(EdgeTransformer2, self).__init__()
225 | self.d_model = d_model
226 |
227 | encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
228 | self.encoder = nn.TransformerEncoder(
229 | encoder_layer, num_layers=num_encoder_layers
230 | )
231 |
232 | decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
233 | self.decoder = nn.TransformerDecoder(
234 | decoder_layer, num_layers=num_decoder_layers
235 | )
236 |
237 | self.type_embed = nn.Embedding(3, 128)
238 | self.input_head = nn.Linear(128 * 3 + 2, d_model)
239 |
240 | # self.final_head = MLP(
241 | # input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2
242 | # )
243 | self.final_head = nn.Linear(d_model, 2)
244 |
245 | self._reset_parameters()
246 |
247 | def _reset_parameters(self):
248 | for p in self.parameters():
249 | if p.dim() > 1:
250 | nn.init.xavier_uniform_(p)
251 | # for m in self.modules():
252 | # if isinstance(m, MSDeformAttn):
253 | # m._reset_parameters()
254 |
255 | def forward(self, data):
256 | # vis_data(data)
257 |
258 | # obtain edge positional features
259 | edge_coords = data["edge_coords"]
260 | edge_mask = data["edge_coords_mask"]
261 | modelled_coords = data["modelled_coords"]
262 | modelled_mask = data["modelled_coords_mask"]
263 |
264 | dtype = edge_coords.dtype
265 | device = edge_coords.device
266 |
267 | # first self-attention among the modelled edges
268 | (bs, N, _) = modelled_coords.shape
269 | _modelled_coords = modelled_coords.reshape(bs * N, -1)
270 | _geom_enc_a = pos_encode_2d(x=_modelled_coords[:, 0], y=_modelled_coords[:, 1])
271 | _geom_enc_b = pos_encode_2d(x=_modelled_coords[:, 2], y=_modelled_coords[:, 3])
272 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1)
273 | modelled_geom = _geom_feats.reshape(bs, N, -1)
274 | modelled_geom = modelled_geom.permute([1, 0, 2]) # bs,S,E -> S,bs,E
275 |
276 | memory = self.encoder(modelled_geom, src_key_padding_mask=modelled_mask)
277 |
278 | # then SA and CA among the sequence edges
279 | (bs, N, _) = edge_coords.shape
280 | _edge_coords = edge_coords.reshape(bs * N, -1)
281 | _geom_enc_a = pos_encode_2d(x=_edge_coords[:, 0], y=_edge_coords[:, 1])
282 | _geom_enc_b = pos_encode_2d(x=_edge_coords[:, 2], y=_edge_coords[:, 3])
283 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1)
284 | seq_geom = _geom_feats.reshape(bs, N, -1)
285 | seq_geom = seq_geom.permute([1, 0, 2]) # bs,S,E -> S,bs,E
286 |
287 | dummy_geom = torch.zeros(
288 | [1, bs, seq_geom.shape[-1]], dtype=dtype, device=device
289 | )
290 | seq_geom = torch.cat([dummy_geom, seq_geom], dim=0)
291 |
292 | # for edge positions, dummy node is 0, modelled edges are 0, sequence start at 1
293 | pos_inds = torch.arange(N + 1, dtype=dtype, device=device)
294 | pos_feats = pos_encode_1d(pos_inds)
295 | pos_feats = pos_feats.unsqueeze(1).repeat(1, bs, 1)
296 |
297 |         # edge-type flag: dummy node is 0 and every real edge is 1 (one-hot over two classes)
298 | flag_feats = torch.zeros([N + 1, bs], dtype=torch.int64, device=device)
299 | flag_feats[1:, :] = 1
300 | flag_feats = nn.functional.one_hot(flag_feats, num_classes=2)
301 |
302 | edge_feats = torch.cat([seq_geom, pos_feats, flag_feats], dim=-1)
303 | edge_feats = self.input_head(edge_feats)
304 |
305 | # need to pad the mask
306 | dummy_mask = torch.zeros([bs, 1], dtype=edge_mask.dtype, device=device)
307 | edge_mask = torch.cat([dummy_mask, edge_mask], dim=-1)
308 |
309 | # first do self-attention without any flags
310 | hs = self.decoder(
311 | tgt=edge_feats,
312 | memory=memory,
313 | tgt_key_padding_mask=edge_mask,
314 | memory_key_padding_mask=modelled_mask,
315 | )
316 | return self.final_head(hs[0])
317 |
318 | # hs = self.final_head(hs)
319 |
320 | # # return only the dummy node
321 | # return hs[:, 0]
322 |
--------------------------------------------------------------------------------
/code/learn/models/stacked_hg.py:
--------------------------------------------------------------------------------
1 | """
2 | Hourglass network inserted in the pre-activated Resnet
3 | Use lr=0.01 for current version
4 | (c) Nan Xue (HAWP)
5 | (c) Yichao Zhou (LCNN)
6 | (c) YANG, Wei
7 | """
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | __all__ = ["HourglassNet", "hg"]
13 |
14 |
15 | class Bottleneck2D(nn.Module):
16 | expansion = 2
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None):
19 | super(Bottleneck2D, self).__init__()
20 |
21 | self.bn1 = nn.BatchNorm2d(inplanes)
22 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
25 | self.bn3 = nn.BatchNorm2d(planes)
26 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = downsample
29 | self.stride = stride
30 |
31 | def forward(self, x):
32 | residual = x
33 |
34 | out = self.bn1(x)
35 | out = self.relu(out)
36 | out = self.conv1(out)
37 |
38 | out = self.bn2(out)
39 | out = self.relu(out)
40 | out = self.conv2(out)
41 |
42 | out = self.bn3(out)
43 | out = self.relu(out)
44 | out = self.conv3(out)
45 |
46 | if self.downsample is not None:
47 | residual = self.downsample(x)
48 |
49 | out += residual
50 |
51 | return out
52 |
53 |
54 | class Hourglass(nn.Module):
55 | def __init__(self, block, num_blocks, planes, depth):
56 | super(Hourglass, self).__init__()
57 | self.depth = depth
58 | self.block = block
59 | self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
60 |
61 | def _make_residual(self, block, num_blocks, planes):
62 | layers = []
63 | for i in range(0, num_blocks):
64 | layers.append(block(planes * block.expansion, planes))
65 | return nn.Sequential(*layers)
66 |
67 | def _make_hour_glass(self, block, num_blocks, planes, depth):
68 | hg = []
69 | for i in range(depth):
70 | res = []
71 | for j in range(3):
72 | res.append(self._make_residual(block, num_blocks, planes))
73 | if i == 0:
74 | res.append(self._make_residual(block, num_blocks, planes))
75 | hg.append(nn.ModuleList(res))
76 | return nn.ModuleList(hg)
77 |
78 | def _hour_glass_forward(self, n, x):
79 | up1 = self.hg[n - 1][0](x)
80 | low1 = F.max_pool2d(x, 2, stride=2)
81 | low1 = self.hg[n - 1][1](low1)
82 |
83 | if n > 1:
84 | low2 = self._hour_glass_forward(n - 1, low1)
85 | else:
86 | low2 = self.hg[n - 1][3](low1)
87 | low3 = self.hg[n - 1][2](low2)
88 | up2 = F.interpolate(low3, scale_factor=2)
89 | out = up1 + up2
90 | return out
91 |
92 | def forward(self, x):
93 | return self._hour_glass_forward(self.depth, x)
94 |
95 |
96 | class HourglassNet(nn.Module):
97 | """Hourglass model from Newell et al ECCV 2016"""
98 |
99 | def __init__(self, inplanes, num_feats, block, head, depth, num_stacks, num_blocks, num_classes):
100 | super(HourglassNet, self).__init__()
101 |
102 | self.inplanes = inplanes
103 | self.num_feats = num_feats
104 | self.num_stacks = num_stacks
105 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
106 | self.bn1 = nn.BatchNorm2d(self.inplanes)
107 | self.relu = nn.ReLU(inplace=True)
108 | self.layer1 = self._make_residual(block, self.inplanes, 1)
109 | self.layer2 = self._make_residual(block, self.inplanes, 1)
110 | self.layer3 = self._make_residual(block, self.num_feats, 1)
111 | self.maxpool = nn.MaxPool2d(2, stride=2)
112 |
113 | # build hourglass modules
114 | ch = self.num_feats * block.expansion
115 | # vpts = []
116 | hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
117 | for i in range(num_stacks):
118 | hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
119 | res.append(self._make_residual(block, self.num_feats, num_blocks))
120 | fc.append(self._make_fc(ch, ch))
121 | score.append(head(ch, num_classes))
122 | # vpts.append(VptsHead(ch))
123 | # vpts.append(nn.Linear(ch, 9))
124 | # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
125 | # score[i].bias.data[0] += 4.6
126 | # score[i].bias.data[2] += 4.6
127 | if i < num_stacks - 1:
128 | fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
129 | score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
130 | self.hg = nn.ModuleList(hg)
131 | self.res = nn.ModuleList(res)
132 | self.fc = nn.ModuleList(fc)
133 | self.score = nn.ModuleList(score)
134 | # self.vpts = nn.ModuleList(vpts)
135 | self.fc_ = nn.ModuleList(fc_)
136 | self.score_ = nn.ModuleList(score_)
137 |
138 | def _make_residual(self, block, planes, blocks, stride=1):
139 | downsample = None
140 | if stride != 1 or self.inplanes != planes * block.expansion:
141 | downsample = nn.Sequential(
142 | nn.Conv2d(
143 | self.inplanes,
144 | planes * block.expansion,
145 | kernel_size=1,
146 | stride=stride,
147 | )
148 | )
149 |
150 | layers = []
151 | layers.append(block(self.inplanes, planes, stride, downsample))
152 | self.inplanes = planes * block.expansion
153 | for i in range(1, blocks):
154 | layers.append(block(self.inplanes, planes))
155 |
156 | return nn.Sequential(*layers)
157 |
158 | def _make_fc(self, inplanes, outplanes):
159 | bn = nn.BatchNorm2d(inplanes)
160 | conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
161 | return nn.Sequential(conv, bn, self.relu)
162 |
163 | def forward(self, x):
164 | out = []
165 | x = self.conv1(x)
166 | x = self.bn1(x)
167 | x = self.relu(x)
168 |
169 | x = self.layer1(x)
170 | x = self.maxpool(x)
171 | x = self.layer2(x)
172 | x = self.layer3(x)
173 |
174 | for i in range(self.num_stacks):
175 | y = self.hg[i](x)
176 | y = self.res[i](y)
177 | y = self.fc[i](y)
178 | score = self.score[i](y)
179 | out.append(score)
180 |
181 | if i < self.num_stacks - 1:
182 | fc_ = self.fc_[i](y)
183 | score_ = self.score_[i](score)
184 | x = x + fc_ + score_
185 |
186 | return out[::-1], y
187 |
188 | def train(self, mode=True):
189 | # Override train so that the training mode is set as we want
190 | nn.Module.train(self, mode)
191 | if mode:
192 | # fix all bn layers
193 | def set_bn_eval(m):
194 | classname = m.__class__.__name__
195 | if classname.find('BatchNorm') != -1:
196 | m.eval()
197 |
198 | self.apply(set_bn_eval)
199 |
200 |
201 | class MultitaskHead(nn.Module):
202 | def __init__(self, input_channels, num_class, head_size):
203 | super(MultitaskHead, self).__init__()
204 |
205 | m = int(input_channels / 4)
206 | heads = []
207 | for output_channels in sum(head_size, []):
208 | heads.append(
209 | nn.Sequential(
210 | nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
211 | nn.ReLU(inplace=True),
212 | nn.Conv2d(m, output_channels, kernel_size=1),
213 | )
214 | )
215 | self.heads = nn.ModuleList(heads)
216 | assert num_class == sum(sum(head_size, []))
217 |
218 | def forward(self, x):
219 | return torch.cat([head(x) for head in self.heads], dim=1)
220 |
221 |
222 | def build_hg():
223 | inplanes = 64
224 | num_feats = 256 //2
225 | depth = 4
226 | num_stacks = 2
227 | num_blocks = 1
228 | head_size = [[2], [2]]
229 |
230 | out_feature_channels = 256
231 |
232 | num_class = sum(sum(head_size, []))
233 | model = HourglassNet(
234 | block=Bottleneck2D,
235 | inplanes = inplanes,
236 | num_feats= num_feats,
237 | depth=depth,
238 | head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),
239 | num_stacks = num_stacks,
240 | num_blocks = num_blocks,
241 | num_classes = num_class)
242 |
243 | model.out_feature_channels = out_feature_channels
244 |
245 | return model
246 |
247 |
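
`build_hg()` wires the two-stack hourglass with a `MultitaskHead` of `head_size=[[2], [2]]`, i.e. four output channels per stack, and the network returns the per-stack score maps (last stack first) together with the final 256-channel feature map at 1/4 of the input resolution. A quick smoke test (hypothetical, assuming the `learn/` directory is on `PYTHONPATH` as in the training scripts):

```python
import torch
from models.stacked_hg import build_hg

model = build_hg().eval()
image = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    scores, feats = model(image)

print([tuple(s.shape) for s in scores])  # [(1, 4, 128, 128), (1, 4, 128, 128)], last stack first
print(tuple(feats.shape))                # (1, 256, 128, 128)
```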
--------------------------------------------------------------------------------
/code/learn/models/tcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils import weight_norm
4 |
5 |
6 | def get_activation(name):
7 | def hook(model, input, output):
8 | activation[name] = output.detach()
9 |
10 | return hook
11 |
12 |
13 | class Chomp1d(nn.Module):
14 | def __init__(self, chomp_size):
15 | super(Chomp1d, self).__init__()
16 | self.chomp_size = chomp_size
17 |
18 | def forward(self, x):
19 | return x[:, :, : -self.chomp_size].contiguous()
20 |
21 |
22 | class TemporalBlock(nn.Module):
23 | def __init__(
24 | self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2
25 | ):
26 | super(TemporalBlock, self).__init__()
27 | self.conv1 = weight_norm(
28 | nn.Conv1d(
29 | n_inputs,
30 | n_outputs,
31 | kernel_size,
32 | stride=stride,
33 | padding=padding,
34 | dilation=dilation,
35 | )
36 | )
37 | self.chomp1 = Chomp1d(padding)
38 | self.relu1 = nn.ReLU()
39 | self.dropout1 = nn.Dropout(dropout)
40 |
41 | self.conv2 = weight_norm(
42 | nn.Conv1d(
43 | n_outputs,
44 | n_outputs,
45 | kernel_size,
46 | stride=stride,
47 | padding=padding,
48 | dilation=dilation,
49 | )
50 | )
51 | self.chomp2 = Chomp1d(padding)
52 | self.relu2 = nn.ReLU()
53 | self.dropout2 = nn.Dropout(dropout)
54 |
55 | self.net = nn.Sequential(
56 | self.conv1,
57 | self.chomp1,
58 | self.relu1,
59 | self.dropout1,
60 | self.conv2,
61 | self.chomp2,
62 | self.relu2,
63 | self.dropout2,
64 | )
65 | self.downsample = (
66 | nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
67 | )
68 | self.relu = nn.ReLU()
69 | self.init_weights()
70 |
71 | def init_weights(self):
72 | self.conv1.weight.data.normal_(0, 0.01)
73 | self.conv2.weight.data.normal_(0, 0.01)
74 | if self.downsample is not None:
75 | self.downsample.weight.data.normal_(0, 0.01)
76 |
77 | def forward(self, x):
78 | out = self.net(x)
79 | res = x if self.downsample is None else self.downsample(x)
80 | return self.relu(out + res)
81 |
82 |
83 | class TemporalConvNet(nn.Module):
84 | def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
85 | super(TemporalConvNet, self).__init__()
86 | layers = []
87 | num_levels = len(num_channels)
88 | for i in range(num_levels):
89 | dilation_size = 2**i
90 | in_channels = num_inputs if i == 0 else num_channels[i - 1]
91 | out_channels = num_channels[i]
92 | layers += [
93 | TemporalBlock(
94 | in_channels,
95 | out_channels,
96 | kernel_size,
97 | stride=1,
98 | dilation=dilation_size,
99 | padding=(kernel_size - 1) * dilation_size,
100 | dropout=dropout,
101 | )
102 | ]
103 |
104 | self.network = nn.Sequential(*layers)
105 |
106 | def forward(self, x):
107 | hidden = {}
108 |
109 | x_ = x
110 | for layer_i, layer in enumerate(self.network.children()):
111 | x_ = layer(x_)
112 | hidden[layer_i] = x_
113 |
114 | return x_, hidden
115 |
116 |
117 | class TCN(nn.Module):
118 | def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
119 | super(TCN, self).__init__()
120 | self.tcn = TemporalConvNet(
121 | input_size, num_channels, kernel_size=kernel_size, dropout=dropout
122 | )
123 | self.linear = nn.Linear(num_channels[-1], output_size)
124 | self.init_weights()
125 |
126 | def init_weights(self):
127 | self.linear.weight.data.normal_(0, 0.01)
128 |
129 | def forward(self, x):
130 | y1, _ = self.tcn(x)
131 | return self.linear(y1.transpose(1, 2))
132 |
133 | def get_hidden(self, x):
134 | # self.tcn.network[layer_idx].register_forward_hook(get_activation("Dropout"))
135 | _, hidden = self.tcn(x)
136 | return hidden
137 |
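
`TCN` stacks `TemporalBlock`s with exponentially growing dilation and applies a final linear layer per time step, so a `(batch, channels, length)` input comes back as `(batch, length, output_size)`. A small usage sketch with hypothetical sizes:

```python
import torch
from models.tcn import TCN

# 16-dim input sequence, three dilated blocks of 32 channels, 2-way output per step
model = TCN(input_size=16, output_size=2, num_channels=[32, 32, 32],
            kernel_size=2, dropout=0.2)

x = torch.randn(4, 16, 100)      # (batch, channels, length)
y = model(x)                     # (4, 100, 2)
hidden = model.get_hidden(x)     # dict: block index -> (4, 32, 100) activations
print(y.shape, hidden[2].shape)
```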
--------------------------------------------------------------------------------
/code/learn/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models
4 | import torch.nn.functional as F
5 | from torchvision.models._utils import IntermediateLayerGetter
6 | from utils.misc import NestedTensor
7 |
8 |
9 | def convrelu(in_channels, out_channels, kernel, padding):
10 | return nn.Sequential(
11 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
12 | nn.ReLU(inplace=True),
13 | )
14 |
15 |
16 | class ResNetBackbone(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 | self.base_model = models.resnet50(pretrained=True)
20 | self.base_layers = list(self.base_model.children())
21 |
22 | self.conv_original_size0 = convrelu(3, 64, 3, 1)
23 | self.conv_original_size1 = convrelu(64, 64, 3, 1)
24 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
25 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
26 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8)
27 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16)
28 | self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32)
29 |
30 | self.strides = [8, 16, 32]
31 | self.num_channels = [512, 1024, 2048]
32 |
33 | def forward(self, inputs):
34 | x_original = self.conv_original_size0(inputs)
35 | x_original = self.conv_original_size1(x_original)
36 | layer0 = self.layer0(inputs)
37 | layer1 = self.layer1(layer0)
38 | layer2 = self.layer2(layer1)
39 | layer3 = self.layer3(layer2)
40 | layer4 = self.layer4(layer3)
41 |
42 | xs = {"0": layer2, "1": layer3, "2": layer4}
43 | all_feats = {'layer0': layer0, 'layer1': layer1, 'layer2': layer2,
44 | 'layer3': layer3, 'layer4': layer4, 'x_original': x_original}
45 |
46 | mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device)
47 | return xs, mask, all_feats
48 |
49 | def train(self, mode=True):
50 | # Override train so that the training mode is set as we want
51 | nn.Module.train(self, mode)
52 | if mode:
53 | # fix all bn layers
54 | def set_bn_eval(m):
55 | classname = m.__class__.__name__
56 | if classname.find('BatchNorm') != -1:
57 | m.eval()
58 |
59 | self.apply(set_bn_eval)
60 |
61 |
62 | class ResNetUNet(nn.Module):
63 | def __init__(self, n_class, out_dim=None, ms_feat=False):
64 | super().__init__()
65 |
66 | self.return_ms_feat = ms_feat
67 | self.out_dim = out_dim
68 |
69 | self.base_model = models.resnet50(pretrained=True)
70 | self.base_layers = list(self.base_model.children())
71 |
72 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
73 | # self.layer0_1x1 = convrelu(64, 64, 1, 0)
74 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
75 | # self.layer1_1x1 = convrelu(256, 256, 1, 0)
76 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8)
77 | # self.layer2_1x1 = convrelu(512, 512, 1, 0)
78 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16)
79 | # self.layer3_1x1 = convrelu(1024, 1024, 1, 0)
80 | self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32)
81 | # self.layer4_1x1 = convrelu(2048, 2048, 1, 0)
82 |
83 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
84 |
85 | self.conv_up3 = convrelu(1024 + 2048, 1024, 3, 1)
86 | self.conv_up2 = convrelu(512 + 1024, 512, 3, 1)
87 | self.conv_up1 = convrelu(256 + 512, 256, 3, 1)
88 | self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
89 | # self.conv_up1 = convrelu(512, 256, 3, 1)
90 | # self.conv_up0 = convrelu(256, 128, 3, 1)
91 |
92 | self.conv_original_size0 = convrelu(3, 64, 3, 1)
93 | self.conv_original_size1 = convrelu(64, 64, 3, 1)
94 | self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
95 | # self.conv_last = nn.Conv2d(128, n_class, 1)
96 | self.conv_last = nn.Conv2d(64, n_class, 1)
97 | if out_dim:
98 | self.conv_out = nn.Conv2d(64, out_dim, 1)
99 | # self.conv_out = nn.Conv2d(128, out_dim, 1)
100 |
101 | # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
102 | self.strides = [8, 16, 32]
103 | self.num_channels = [512, 1024, 2048]
104 |
105 | def forward(self, inputs):
106 | x_original = self.conv_original_size0(inputs)
107 | x_original = self.conv_original_size1(x_original)
108 |
109 | layer0 = self.layer0(inputs)
110 | layer1 = self.layer1(layer0)
111 | layer2 = self.layer2(layer1)
112 | layer3 = self.layer3(layer2)
113 | layer4 = self.layer4(layer3)
114 |
115 | # layer4 = self.layer4_1x1(layer4)
116 | x = self.upsample(layer4)
117 | # layer3 = self.layer3_1x1(layer3)
118 | x = torch.cat([x, layer3], dim=1)
119 | x = self.conv_up3(x)
120 | layer3_up = x
121 |
122 | x = self.upsample(x)
123 | # layer2 = self.layer2_1x1(layer2)
124 | x = torch.cat([x, layer2], dim=1)
125 | x = self.conv_up2(x)
126 | layer2_up = x
127 |
128 | x = self.upsample(x)
129 | # layer1 = self.layer1_1x1(layer1)
130 | x = torch.cat([x, layer1], dim=1)
131 | x = self.conv_up1(x)
132 |
133 | x = self.upsample(x)
134 | # layer0 = self.layer0_1x1(layer0)
135 | x = torch.cat([x, layer0], dim=1)
136 | x = self.conv_up0(x)
137 |
138 | x = self.upsample(x)
139 | x = torch.cat([x, x_original], dim=1)
140 | x = self.conv_original_size2(x)
141 |
142 | out = self.conv_last(x)
143 | out = out.sigmoid().squeeze(1)
144 |
145 | # xs = {"0": layer2, "1": layer3, "2": layer4}
146 | xs = {"0": layer2_up, "1": layer3_up, "2": layer4}
147 | mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device)
148 | # ms_feats = self.ms_feat(xs, mask)
149 |
150 | if self.return_ms_feat:
151 | if self.out_dim:
152 | out_feat = self.conv_out(x)
153 | out_feat = out_feat.permute(0, 2, 3, 1)
154 | return xs, mask, out, out_feat
155 | else:
156 | return xs, mask, out
157 | else:
158 | return out
159 |
160 | def train(self, mode=True):
161 | # Override train so that the training mode is set as we want
162 | nn.Module.train(self, mode)
163 | if mode:
164 | # fix all bn layers
165 | def set_bn_eval(m):
166 | classname = m.__class__.__name__
167 | if classname.find('BatchNorm') != -1:
168 | m.eval()
169 |
170 | self.apply(set_bn_eval)
171 |
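
`ResNetUNet` is a ResNet-50 encoder with a convolutional decoder; with the default `ms_feat=False` it returns only the sigmoid mask, while `ms_feat=True` additionally returns the multi-scale encoder features and, if `out_dim` is set, a per-pixel feature map. A minimal sketch (assumes `learn/` is on `PYTHONPATH`; torchvision downloads the pretrained ResNet-50 weights on first use, and the input side must be divisible by 32):

```python
import torch
from models.unet import ResNetUNet

net = ResNetUNet(n_class=1, out_dim=64, ms_feat=True).eval()
image = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    ms_feats, mask, out, out_feat = net(image)

print(tuple(out.shape))       # (1, 512, 512)     sigmoid mask, channel dim squeezed
print(tuple(out_feat.shape))  # (1, 512, 512, 64) per-pixel features (NHWC)
print({k: tuple(v.shape) for k, v in ms_feats.items()})
# {'0': (1, 512, 64, 64), '1': (1, 1024, 32, 32), '2': (1, 2048, 16, 16)}
```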
--------------------------------------------------------------------------------
/code/learn/models/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 |
5 | def pos_encode_1d(x, d_model=128):
6 | pe = torch.zeros(len(x), d_model, device=x.device)
7 |
8 | div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model))
9 | div_term = div_term.to(x.device)
10 |
11 | pos_w = x.clone().float().unsqueeze(1)
12 |
13 | pe[:, 0:d_model:2] = torch.sin(pos_w * div_term)
14 | pe[:, 1:d_model:2] = torch.cos(pos_w * div_term)
15 |
16 | return pe
17 |
18 |
19 | def pos_encode_2d(x, y, d_model=128):
20 | assert len(x) == len(y)
21 | pe = torch.zeros(len(x), d_model, device=x.device)
22 |
23 | d_model = int(d_model / 2)
24 | div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model))
25 | div_term = div_term.to(x.device)
26 |
27 | pos_w = x.clone().float().unsqueeze(1)
28 | pos_h = y.clone().float().unsqueeze(1)
29 |
30 | pe[:, 0:d_model:2] = torch.sin(pos_w * div_term)
31 | pe[:, 1:d_model:2] = torch.cos(pos_w * div_term)
32 | pe[:, d_model::2] = torch.sin(pos_h * div_term)
33 | pe[:, d_model + 1 :: 2] = torch.cos(pos_h * div_term)
34 |
35 | return pe
36 |
37 |
38 | def get_geom_feats(coords):
39 | (bs, N, _) = coords.shape
40 | _coords = coords.reshape(bs * N, -1)
41 | _geom_enc_a = pos_encode_2d(x=_coords[:, 0], y=_coords[:, 1])
42 | _geom_enc_b = pos_encode_2d(x=_coords[:, 2], y=_coords[:, 3])
43 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1)
44 | geom_feats = _geom_feats.reshape(bs, N, -1)
45 |
46 | return geom_feats
47 |
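
`pos_encode_1d` and `pos_encode_2d` produce fixed sinusoidal embeddings, one row per input element, and `get_geom_feats` concatenates the 2-D encodings of both endpoints of every edge into a 256-dim geometry feature. A short sketch of the shapes involved (hypothetical example values):

```python
import torch
from models.utils import pos_encode_1d, pos_encode_2d, get_geom_feats

order = torch.arange(6).float()
print(pos_encode_1d(order).shape)     # (6, 128)

x, y = torch.rand(6) * 255, torch.rand(6) * 255
print(pos_encode_2d(x, y).shape)      # (6, 128)  first half encodes x, second half y

edges = torch.rand(2, 50, 4) * 255    # (bs, N, 4) edge endpoints
print(get_geom_feats(edges).shape)    # (2, 50, 256)  128 dims per endpoint
```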
--------------------------------------------------------------------------------
/code/learn/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | class Timer:
4 |
5 | def __init__(self, name):
6 | self.name = name
7 | self._start = None
8 | self._end = None
9 |
10 |
11 | def __enter__(self):
12 | self._start = time.time()
13 |
14 |
15 | def __exit__(self, *exc_info):
16 | self._end = time.time()
17 |
18 | print('%s: %f' % (self.name, self._end - self._start))
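
`Timer` is a tiny context manager that prints the wall-clock time of the wrapped block on exit. For example, with `code/learn` on the path:

```python
import time
from timer import Timer

with Timer("sleep"):
    time.sleep(0.5)
# prints something like: sleep: 0.500123
```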
--------------------------------------------------------------------------------
/code/learn/train_order_class.py:
--------------------------------------------------------------------------------
1 | from turtle import pos
2 | import torch
3 | import torch.nn as nn
4 | import os
5 | import time
6 | import numpy as np
7 | import datetime
8 | import argparse
9 | from tqdm import tqdm
10 | from pathlib import Path
11 | from torch.utils.data import DataLoader
12 | from torch.utils.tensorboard import SummaryWriter
13 | from datasets.building_order_enum import (
14 | BuildingCornerDataset,
15 | collate_fn_seq,
16 | get_pixel_features,
17 | )
18 | from models.corner_models import CornerEnum
19 | from models.order_class_models import EdgeTransformer, EdgeTransformer2
20 | from models.loss import CornerCriterion, EdgeCriterion
21 | from models.corner_to_edge import prepare_edge_data
22 | import utils.misc as utils
23 | from utils.nn_utils import pos_encode_2d
24 | import torch.nn.functional as F
25 |
26 | # from infer_full import FloorHEAT
27 |
28 | # for debugging NaNs
29 | # torch.autograd.set_detect_anomaly(True)
30 |
31 |
32 | def get_args_parser():
33 | parser = argparse.ArgumentParser("Set transformer detector", add_help=False)
34 | parser.add_argument("--lr", default=2e-4, type=float)
35 | parser.add_argument("--batch_size", default=128, type=int)
36 | parser.add_argument("--weight_decay", default=1e-5, type=float)
37 | parser.add_argument("--epochs", default=80, type=int)
38 | parser.add_argument("--lr_drop", default=50, type=int)
39 | parser.add_argument(
40 | "--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm"
41 | )
42 |
43 | parser.add_argument("--resume", action="store_true")
44 | parser.add_argument(
45 | "--output_dir",
46 | default="../../ckpts/order_class",
47 | help="path where to save, empty for no saving",
48 | )
49 |
50 | # my own
51 | parser.add_argument("--test_idx", type=int, default=0)
52 | parser.add_argument("--grad_accum", type=int, default=1)
53 | parser.add_argument("--rand_aug", action="store_true", default=True)
54 | parser.add_argument("--last_first", type=bool, default=False)
55 | parser.add_argument("--lock_pos", type=str, default="none")
56 | parser.add_argument("--normalize_by_seq", action="store_true")
57 |
58 | return parser
59 |
60 |
61 | def train_one_epoch(
62 | image_size,
63 | edge_model,
64 | edge_criterion,
65 | data_loader,
66 | optimizer,
67 | writer,
68 | epoch,
69 | max_norm,
70 | args,
71 | ):
72 | # backbone.train()
73 | edge_model.train()
74 | edge_criterion.train()
75 | optimizer.zero_grad()
76 |
77 | acc_avg = 0
78 | loss_avg = 0
79 |
80 | pbar = tqdm(data_loader)
81 | for batch_i, data in enumerate(pbar):
82 | logits, loss, acc = run_model(data, edge_model, epoch, edge_criterion)
83 |
84 | loss = loss / args.grad_accum
85 | loss.backward()
86 |
87 | num_iter = epoch * len(data_loader) + batch_i
88 | writer.add_scalar("train/loss_mb", loss, num_iter)
89 | writer.add_scalar("train/acc_mb", acc, num_iter)
90 | pbar.set_description("Train Loss: %.3f Acc: %.3f" % (loss, acc))
91 |
92 | acc_avg += acc.item()
93 | loss_avg += loss.item()
94 |
95 | if ((batch_i + 1) % args.grad_accum == 0) or (
96 | (batch_i + 1) == len(data_loader)
97 | ):
98 | if max_norm > 0:
99 | # torch.nn.utils.clip_grad_norm_(backbone.parameters(), max_norm)
100 | torch.nn.utils.clip_grad_norm_(edge_model.parameters(), max_norm)
101 |
102 | optimizer.step()
103 | optimizer.zero_grad()
104 |
105 | acc_avg /= len(data_loader)
106 | loss_avg /= len(data_loader)
107 | writer.add_scalar("train/acc", acc_avg, epoch)
108 | writer.add_scalar("train/loss", loss_avg, epoch)
109 |
110 | print("Train loss: %.3f acc: %.3f" % (loss_avg, acc_avg))
111 |
112 | return -1
113 |
114 |
115 | @torch.no_grad()
116 | def evaluate(image_size, edge_model, edge_criterion, data_loader, writer, epoch, args):
117 | # backbone.train()
118 | edge_model.eval()
119 | edge_criterion.eval()
120 |
121 | loss_total = 0
122 | acc_total = 0
123 |
124 | pbar = tqdm(data_loader)
125 | for batch_i, data in enumerate(pbar):
126 | logits, loss, acc = run_model(data, edge_model, epoch, edge_criterion)
127 | pbar.set_description("Eval Loss: %.3f" % loss)
128 |
129 | loss_total += loss.item()
130 | acc_total += acc.item()
131 |
132 | loss_avg = loss_total / len(data_loader)
133 | acc_avg = acc_total / len(data_loader)
134 |
135 | print("Val loss: %.3f acc: %.3f\n" % (loss_avg, acc_avg))
136 | writer.add_scalar("eval/loss", loss_avg, epoch)
137 | writer.add_scalar("eval/acc", acc_avg, epoch)
138 |
139 | return loss_avg, acc_avg
140 |
141 |
142 | def run_model(data, edge_model, epoch, edge_criterion):
143 | for key in data.keys():
144 | if type(data[key]) is torch.Tensor:
145 | data[key] = data[key].cuda()
146 |
147 | # run the edge model
148 | assert (data["label"] >= 0).all() and (data["label"] <= 1).all()
149 |
150 | logits = edge_model(data)
151 | loss = edge_criterion(logits, data["label"])
152 | acc = (logits.argmax(-1) == data["label"]).float().mean()
153 |
154 | return logits, loss, acc
155 |
156 |
157 | def main(args):
158 | DATAPATH = "../../data"
159 | REVIT_ROOT = ""
160 | image_size = 512
161 |
162 | # prepare datasets
163 | train_dataset = BuildingCornerDataset(
164 | DATAPATH,
165 | REVIT_ROOT,
166 | phase="train",
167 | image_size=image_size,
168 | rand_aug=args.rand_aug,
169 | test_idx=args.test_idx,
170 | loss_type="class",
171 | )
172 | train_dataloader = DataLoader(
173 | train_dataset,
174 | batch_size=args.batch_size,
175 | shuffle=True,
176 | num_workers=8,
177 | collate_fn=collate_fn_seq,
178 | )
179 |
180 | test_dataset = BuildingCornerDataset(
181 | DATAPATH,
182 | REVIT_ROOT,
183 | phase="valid",
184 | image_size=image_size,
185 | rand_aug=False,
186 | test_idx=args.test_idx,
187 | loss_type="class",
188 | )
189 | test_dataloader = DataLoader(
190 | test_dataset,
191 | batch_size=args.batch_size,
192 | shuffle=False,
193 | num_workers=8,
194 | collate_fn=collate_fn_seq,
195 | )
196 |
197 | edge_model = EdgeTransformer(d_model=256)
198 | edge_model = edge_model.cuda()
199 |
200 | edge_criterion = nn.CrossEntropyLoss()
201 |
202 | edge_params = [p for p in edge_model.parameters()]
203 |
204 | all_params = edge_params # + backbone_params
205 | optimizer = torch.optim.AdamW(
206 | all_params, lr=args.lr, weight_decay=args.weight_decay
207 | )
208 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
209 | start_epoch = 0
210 |
211 | if args.resume:
212 | ckpt = torch.load(args.resume)
213 | edge_model.load_state_dict(ckpt["edge_model"])
214 | optimizer.load_state_dict(ckpt["optimizer"])
215 | lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
216 | lr_scheduler.step_size = args.lr_drop
217 |
218 | print(
219 | "Resume from ckpt file {}, starting from epoch {}".format(
220 | args.resume, ckpt["epoch"]
221 | )
222 | )
223 | start_epoch = ckpt["epoch"] + 1
224 |
225 | else:
226 | ckpt = torch.load(
227 | "../../ckpts/pretrained/%d/checkpoint_best.pth" % args.test_idx
228 | )
229 |
230 | replacements = [
231 | ["decoder_1", "module.transformer.relational_decoder"],
232 | ["decoder_2", "module.transformer.relational_decoder"],
233 | ]
234 |
235 | edge_model_dict = edge_model.state_dict()
236 | for key in edge_model_dict.keys():
237 | replaced = False
238 | for old, new in replacements:
239 | if old in key:
240 | assert not replaced
241 | new_key = key.replace(old, new)
242 | edge_model_dict[key] = ckpt["edge_model"][new_key]
243 | replaced = True
244 | print(key)
245 |
246 | edge_model.load_state_dict(edge_model_dict)
247 | print("Resume from pre-trained checkpoints")
248 |
249 | n_edge_parameters = sum(p.numel() for p in edge_params if p.requires_grad)
250 | n_all_parameters = sum(p.numel() for p in all_params if p.requires_grad)
251 | print("number of trainable edge params:", n_edge_parameters)
252 | print("number of all trainable params:", n_all_parameters)
253 |
254 | print("Start training")
255 | start_time = time.time()
256 |
257 | output_dir = Path("%s/%d" % (args.output_dir, args.test_idx))
258 | if not os.path.exists(output_dir):
259 | os.makedirs(output_dir)
260 |
261 | # prepare summary writer
262 | writer = SummaryWriter(log_dir=output_dir)
263 |
264 | best_acc = 0
265 | for epoch in range(start_epoch, args.epochs):
266 | print("Epoch: %d" % epoch)
267 | train_one_epoch(
268 | image_size,
269 | edge_model,
270 | edge_criterion,
271 | train_dataloader,
272 | optimizer,
273 | writer,
274 | epoch,
275 | args.clip_max_norm,
276 | args,
277 | )
278 | lr_scheduler.step()
279 |
280 | val_loss, val_acc = evaluate(
281 | image_size,
282 | edge_model,
283 | edge_criterion,
284 | test_dataloader,
285 | writer,
286 | epoch,
287 | args,
288 | )
289 |
290 | if val_acc > best_acc:
291 | is_best = True
292 | best_acc = val_acc
293 | else:
294 | is_best = False
295 |
296 | if args.output_dir:
297 | checkpoint_paths = [output_dir / ("checkpoint_latest.pth")]
298 | checkpoint_paths.append(output_dir / ("checkpoint_%03d.pth" % epoch))
299 | if is_best:
300 | checkpoint_paths.append(output_dir / "checkpoint_best.pth")
301 |
302 | for checkpoint_path in checkpoint_paths:
303 | torch.save(
304 | {
305 | "edge_model": edge_model.state_dict(),
306 | "optimizer": optimizer.state_dict(),
307 | "lr_scheduler": lr_scheduler.state_dict(),
308 | "epoch": epoch,
309 | "args": args,
310 | "val_acc": val_acc,
311 | },
312 | checkpoint_path,
313 | )
314 |
315 | total_time = time.time() - start_time
316 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
317 | print("Training time {}".format(total_time_str))
318 |
319 |
320 | if __name__ == "__main__":
321 | parser = argparse.ArgumentParser(
322 | "GeoVAE training and evaluation script", parents=[get_args_parser()]
323 | )
324 | args = parser.parse_args()
325 | main(args)
326 |
--------------------------------------------------------------------------------
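Stripped of the data plumbing, `run_model` above is a standard classification step: the `EdgeTransformer` emits per-example logits, `nn.CrossEntropyLoss` scores them against the binary order label, and accuracy is an argmax match. A minimal sketch with dummy tensors standing in for the model output (shapes are assumptions):

```python
import torch
import torch.nn as nn

logits = torch.randn(8, 2)          # (batch, num_classes), stand-in for edge_model(data)
labels = torch.randint(0, 2, (8,))  # binary labels, consistent with the assert in run_model

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
acc = (logits.argmax(-1) == labels).float().mean()
print(loss.item(), acc.item())
```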
/code/learn/train_order_metric.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import os
5 | import time
6 | import numpy as np
7 | import datetime
8 | import argparse
9 | from tqdm import tqdm
10 | from pathlib import Path
11 | from torch.utils.data import DataLoader
12 | from torch.utils.tensorboard import SummaryWriter
13 | from datasets.building_order_enum import (
14 | BuildingCornerDataset,
15 | collate_fn_seq,
16 | get_pixel_features,
17 | )
18 | from models.corner_models import CornerEnum
19 | from models.order_metric_models import EdgeTransformer, EdgeTransformer2
20 | from models.corner_to_edge import prepare_edge_data
21 | import utils.misc as utils
22 | from utils.nn_utils import pos_encode_2d
23 | import torch.nn.functional as F
24 |
25 | # from infer_full import FloorHEAT
26 |
27 | # for debugging NaNs
28 | # torch.autograd.set_detect_anomaly(True)
29 |
30 |
31 | def get_args_parser():
32 | parser = argparse.ArgumentParser("Set transformer detector", add_help=False)
33 | parser.add_argument("--lr", default=2e-4, type=float)
34 | parser.add_argument("--batch_size", default=128, type=int)
35 | parser.add_argument("--weight_decay", default=1e-5, type=float)
36 | parser.add_argument("--epochs", default=100, type=int)
37 | parser.add_argument("--lr_drop", default=50, type=int)
38 | parser.add_argument(
39 | "--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm"
40 | )
41 |
42 | parser.add_argument("--resume", action="store_true")
43 | parser.add_argument(
44 | "--output_dir",
45 | default="../../ckpts/order_metric",
46 | help="path where to save, empty for no saving",
47 | )
48 |
49 | # my own
50 | parser.add_argument("--test_idx", type=int, default=0)
51 | parser.add_argument("--grad_accum", type=int, default=1)
52 | parser.add_argument("--rand_aug", action="store_true", default=True)
53 | parser.add_argument("--last_first", type=bool, default=False)
54 | parser.add_argument("--lock_pos", type=str, default="none")
55 | parser.add_argument("--normalize_by_seq", action="store_true")
56 |
57 | return parser
58 |
59 |
60 | def train_one_epoch(
61 | image_size,
62 | edge_model,
63 | loss_fn,
64 | miner,
65 | dist_fn,
66 | data_loader,
67 | optimizer,
68 | writer,
69 | epoch,
70 | max_norm,
71 | args,
72 | ):
73 | # backbone.train()
74 | edge_model.train()
75 | loss_fn.train()
76 | optimizer.zero_grad()
77 |
78 | acc_avg = 0
79 | loss_avg = 0
80 |
81 | pbar = tqdm(data_loader)
82 | for batch_i, data in enumerate(pbar):
83 | logits, loss, acc = run_model(data, edge_model, epoch, loss_fn, miner, dist_fn)
84 |
85 | loss = loss / args.grad_accum
86 | loss.backward()
87 |
88 | num_iter = epoch * len(data_loader) + batch_i
89 | writer.add_scalar("train/loss_mb", loss, num_iter)
90 | writer.add_scalar("train/acc_mb", acc, num_iter)
91 | pbar.set_description("Train Loss: %.3f Acc: %.3f" % (loss, acc))
92 |
93 | acc_avg += acc.item()
94 | loss_avg += loss.item()
95 |
96 | if ((batch_i + 1) % args.grad_accum == 0) or (
97 | (batch_i + 1) == len(data_loader)
98 | ):
99 | if max_norm > 0:
100 | # torch.nn.utils.clip_grad_norm_(backbone.parameters(), max_norm)
101 | torch.nn.utils.clip_grad_norm_(edge_model.parameters(), max_norm)
102 |
103 | optimizer.step()
104 | optimizer.zero_grad()
105 |
106 |     acc_avg /= len(data_loader)
107 |     loss_avg /= len(data_loader)
108 |     writer.add_scalar("train/acc", acc_avg, epoch)
109 |     writer.add_scalar("train/loss", loss_avg, epoch)
110 |
111 |     print("Train loss: %.3f acc: %.3f" % (loss_avg, acc_avg))
112 |
113 |     return -1
114 |
115 |
116 | @torch.no_grad()
117 | def evaluate(
118 | image_size,
119 | edge_model,
120 | loss_fn,
121 | miner,
122 | dist_fn,
123 | data_loader,
124 | writer,
125 | epoch,
126 | args,
127 | ):
128 | # backbone.train()
129 | edge_model.eval()
130 | loss_fn.eval()
131 |
132 | loss_total = 0
133 | acc_total = 0
134 |
135 | pbar = tqdm(data_loader)
136 | for batch_i, data in enumerate(pbar):
137 | logits, loss, acc = run_model(data, edge_model, epoch, loss_fn, miner, dist_fn)
138 | pbar.set_description("Eval Loss: %.3f" % loss)
139 |
140 | loss_total += loss.item()
141 | acc_total += acc.item()
142 |
143 | loss_avg = loss_total / len(data_loader)
144 | acc_avg = acc_total / len(data_loader)
145 |
146 | print("Val loss: %.3f acc: %.3f\n" % (loss_avg, acc_avg))
147 | writer.add_scalar("eval/loss", loss_avg, epoch)
148 | writer.add_scalar("eval/acc", acc_avg, epoch)
149 |
150 | return loss_avg, acc_avg
151 |
152 |
153 | def run_model(data, edge_model, epoch, loss_fn, miner, dist_fn):
154 | for key in data.keys():
155 | if type(data[key]) is torch.Tensor:
156 | data[key] = data[key].cuda()
157 |
158 | # run the edge model
159 | embeddings = edge_model(data)
160 |
161 | # ignore the provided edges
162 | all_neg = []
163 | all_anc = []
164 | all_pos = []
165 |
166 | total_loss = 0
167 | total_acc = 0
168 |
169 | for ex_i in range(len(embeddings)):
170 | mask = ~data["edge_coords_mask"][ex_i]
171 |
172 | _embeddings = embeddings[ex_i][mask]
173 | _label = data["edge_label"][ex_i][mask]
174 |
175 | neg = _embeddings[_label == 0]
176 | anc = _embeddings[_label == 1].repeat(len(neg), 1)
177 | pos = _embeddings[_label == 2].repeat(len(neg), 1)
178 |
179 | all_neg.append(neg)
180 | all_anc.append(anc)
181 | all_pos.append(pos)
182 |
183 | anc2pos = dist_fn(anc, pos)
184 | anc2neg = dist_fn(anc, neg)
185 | total_acc += (anc2pos < anc2neg).all().float()
186 |
187 | all_neg = torch.cat(all_neg, dim=0)
188 | all_anc = torch.cat(all_anc, dim=0)
189 | all_pos = torch.cat(all_pos, dim=0)
190 |
191 | total_loss = loss_fn(all_anc, all_pos, all_neg)
192 | total_loss /= len(embeddings)
193 |
194 | total_acc = total_acc / len(embeddings)
195 |
196 | return embeddings, total_loss, total_acc
197 |
198 |
199 | def main(args):
200 | DATAPATH = "../../data"
201 | REVIT_ROOT = ""
202 | image_size = 512
203 |
204 | # prepare datasets
205 | train_dataset = BuildingCornerDataset(
206 | DATAPATH,
207 | REVIT_ROOT,
208 | phase="train",
209 | image_size=image_size,
210 | rand_aug=args.rand_aug,
211 | test_idx=args.test_idx,
212 | loss_type="metric",
213 | )
214 | train_dataloader = DataLoader(
215 | train_dataset,
216 | batch_size=args.batch_size,
217 | shuffle=True,
218 | num_workers=8,
219 | collate_fn=collate_fn_seq,
220 | )
221 |
222 | # for blah in train_dataset:
223 | # pass
224 |
225 | test_dataset = BuildingCornerDataset(
226 | DATAPATH,
227 | REVIT_ROOT,
228 | phase="valid",
229 | image_size=image_size,
230 | rand_aug=False,
231 | test_idx=args.test_idx,
232 | loss_type="metric",
233 | )
234 | test_dataloader = DataLoader(
235 | test_dataset,
236 | batch_size=args.batch_size,
237 | shuffle=False,
238 | num_workers=8,
239 | collate_fn=collate_fn_seq,
240 | )
241 |
242 | edge_model = EdgeTransformer(d_model=256)
243 | edge_model = edge_model.cuda()
244 |
245 | # loss_fn = losses.TripletMarginLoss(distance=dist_fn)
246 | # miner = miners.MultiSimilarityMiner()
247 | dist_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y)
248 | loss_fn = nn.TripletMarginWithDistanceLoss(
249 | distance_function=dist_fn, reduction="sum"
250 | )
251 | miner = None
252 |
253 | edge_params = [p for p in edge_model.parameters()]
254 |
255 | all_params = edge_params # + backbone_params
256 | optimizer = torch.optim.AdamW(
257 | all_params, lr=args.lr, weight_decay=args.weight_decay
258 | )
259 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
260 | start_epoch = 0
261 |
262 | if args.resume:
263 | ckpt = torch.load(args.resume)
264 | edge_model.load_state_dict(ckpt["edge_model"])
265 | optimizer.load_state_dict(ckpt["optimizer"])
266 | lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
267 | lr_scheduler.step_size = args.lr_drop
268 |
269 | print(
270 | "Resume from ckpt file {}, starting from epoch {}".format(
271 | args.resume, ckpt["epoch"]
272 | )
273 | )
274 | start_epoch = ckpt["epoch"] + 1
275 |
276 | else:
277 | ckpt = torch.load(
278 | "../../ckpts/pretrained/%d/checkpoint_best.pth" % args.test_idx
279 | )
280 |
281 | replacements = [
282 | ["decoder_1", "module.transformer.relational_decoder"],
283 | ["decoder_2", "module.transformer.relational_decoder"],
284 | ]
285 |
286 | edge_model_dict = edge_model.state_dict()
287 | for key in edge_model_dict.keys():
288 | replaced = False
289 | for old, new in replacements:
290 | if old in key:
291 | assert not replaced
292 | new_key = key.replace(old, new)
293 | edge_model_dict[key] = ckpt["edge_model"][new_key]
294 | replaced = True
295 | print(key)
296 |
297 | edge_model.load_state_dict(edge_model_dict)
298 | print("Resume from pre-trained checkpoints")
299 |
300 | n_edge_parameters = sum(p.numel() for p in edge_params if p.requires_grad)
301 | n_all_parameters = sum(p.numel() for p in all_params if p.requires_grad)
302 | print("number of trainable edge params:", n_edge_parameters)
303 | print("number of all trainable params:", n_all_parameters)
304 |
305 | print("Start training")
306 | start_time = time.time()
307 |
308 | output_dir = Path("%s/%d" % (args.output_dir, args.test_idx))
309 | if not os.path.exists(output_dir):
310 | os.makedirs(output_dir)
311 |
312 | # prepare summary writer
313 | writer = SummaryWriter(log_dir=output_dir)
314 |
315 | best_acc = 0
316 | for epoch in range(start_epoch, args.epochs):
317 | print("Epoch: %d" % epoch)
318 | train_one_epoch(
319 | image_size,
320 | edge_model,
321 | loss_fn,
322 | miner,
323 | dist_fn,
324 | train_dataloader,
325 | optimizer,
326 | writer,
327 | epoch,
328 | args.clip_max_norm,
329 | args,
330 | )
331 | lr_scheduler.step()
332 |
333 | # is_best = False
334 | # val_acc = 0
335 | val_loss, val_acc = evaluate(
336 | image_size,
337 | edge_model,
338 | loss_fn,
339 | miner,
340 | dist_fn,
341 | test_dataloader,
342 | writer,
343 | epoch,
344 | args,
345 | )
346 |
347 | if val_acc > best_acc:
348 | is_best = True
349 | best_acc = val_acc
350 | else:
351 | is_best = False
352 |
353 | if args.output_dir:
354 | checkpoint_paths = [output_dir / ("checkpoint_latest.pth")]
355 | # checkpoint_paths.append(output_dir / ("checkpoint_%03d.pth" % epoch))
356 | if is_best:
357 | checkpoint_paths.append(output_dir / "checkpoint_best.pth")
358 |
359 | for checkpoint_path in checkpoint_paths:
360 | torch.save(
361 | {
362 | "edge_model": edge_model.state_dict(),
363 | "optimizer": optimizer.state_dict(),
364 | "lr_scheduler": lr_scheduler.state_dict(),
365 | "epoch": epoch,
366 | "args": args,
367 | "val_acc": val_acc,
368 | },
369 | checkpoint_path,
370 | )
371 |
372 | total_time = time.time() - start_time
373 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
374 | print("Training time {}".format(total_time_str))
375 |
376 |
377 | if __name__ == "__main__":
378 | parser = argparse.ArgumentParser(
379 | "GeoVAE training and evaluation script", parents=[get_args_parser()]
380 | )
381 | args = parser.parse_args()
382 | main(args)
383 |
--------------------------------------------------------------------------------
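This variant swaps the classifier for a metric-learning objective: within each example, the anchor embedding (label 1) should be closer, under cosine distance, to the positive (label 2) than to every negative (label 0). A minimal sketch of that objective on dummy embeddings (the embedding size and counts are assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dist_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y)
loss_fn = nn.TripletMarginWithDistanceLoss(distance_function=dist_fn, reduction="sum")

num_neg, d = 5, 256
anc = torch.randn(1, d).repeat(num_neg, 1)  # anchor repeated once per negative, as in run_model
pos = torch.randn(1, d).repeat(num_neg, 1)  # positive repeated the same way
neg = torch.randn(num_neg, d)

loss = loss_fn(anc, pos, neg)
acc = (dist_fn(anc, pos) < dist_fn(anc, neg)).all().float()  # the ranking check used for accuracy
print(loss.item(), acc.item())
```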
/code/learn/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/code/learn/utils/__init__.py
--------------------------------------------------------------------------------
/code/learn/utils/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 |
5 | def add_path(path):
6 | if path not in sys.path: sys.path.insert(0, path)
7 |
8 |
9 | this_dir = osp.dirname(__file__)
10 |
11 | project_path = osp.abspath(osp.join(this_dir, '..'))
12 | add_path(project_path)
--------------------------------------------------------------------------------
/code/learn/utils/geometry_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | def building_metric(logits, label):
7 | preds = torch.argmax(logits, dim=-1)
8 | true_ids = torch.where(label==1)
9 | num_true = true_ids[0].shape[0]
10 | tp = (preds[true_ids] == 1).sum().double()
11 | recall = tp / num_true
12 | prec = tp / (preds == 1).sum()
13 | fscore = 2 * recall * prec / (prec + recall)
14 | return recall, prec, fscore
15 |
16 |
17 | def edge_acc(logits, label, lengths, gt_values):
18 | all_acc = list()
19 | for i in range(logits.shape[0]):
20 | length = lengths[i]
21 | gt_value = gt_values[i, :length]
22 | pred_idx = torch.where(gt_value == 2)
23 | if len(pred_idx[0]) == 0:
24 | continue
25 | else:
26 | preds = torch.argmax(logits[i, :, :length][:, pred_idx[0]], dim=0)
27 | gts = label[i, :length][pred_idx[0]]
28 | pos_ids = torch.where(gts == 1)
29 | correct = (preds[pos_ids] == gts[pos_ids]).sum().float()
30 | num_pos_gt = len(pos_ids[0])
31 | recall = correct / num_pos_gt if num_pos_gt > 0 else torch.tensor(0)
32 | num_pos_pred = (preds == 1).sum().float()
33 | prec = correct / num_pos_pred if num_pos_pred > 0 else torch.tensor(0)
34 | f_score = 2.0 * prec * recall / (recall + prec + 1e-8)
35 | f_score = f_score.cpu()
36 | all_acc.append(f_score)
37 | if len(all_acc) > 1:
38 | all_acc = torch.stack(all_acc, 0)
39 | avg_acc = all_acc.mean()
40 | else:
41 | avg_acc = all_acc[0]
42 | return avg_acc
43 |
44 |
45 | def corner_eval(targets, outputs):
46 | assert isinstance(targets, np.ndarray)
47 | assert isinstance(outputs, np.ndarray)
48 | output_to_gt = dict()
49 | gt_to_output = dict()
50 | for target_i, target in enumerate(targets):
51 | dist = (outputs - target) ** 2
52 | dist = np.sqrt(dist.sum(axis=-1))
53 | min_dist = dist.min()
54 | min_idx = dist.argmin()
55 | if min_dist < 5 and min_idx not in output_to_gt: # a positive match
56 | output_to_gt[min_idx] = target_i
57 | gt_to_output[target_i] = min_idx
58 | tp = len(output_to_gt)
59 | prec = tp / len(outputs)
60 | recall = tp / len(targets)
61 | return prec, recall
62 |
63 |
64 | def rectify_data(image, annot):
65 | rows, cols, ch = image.shape
66 |     bins = [0 for _ in range(180)] # 1 degree per bin
67 | # edges vote for directions
68 |
69 | gauss_weights = [0.1, 0.2, 0.5, 1, 0.5, 0.2, 0.1]
70 |
71 | for src, connections in annot.items():
72 | for end in connections:
73 | edge = [(end[0] - src[0]), -(end[1] - src[1])]
74 | edge_len = np.sqrt(edge[0] ** 2 + edge[1] ** 2)
75 | if edge_len <= 10: # skip too short edges
76 | continue
77 | if edge[0] == 0:
78 | bin_id = 90
79 | else:
80 | theta = np.arctan(edge[1] / edge[0]) / np.pi * 180
81 | if edge[0] * edge[1] < 0:
82 | theta += 180
83 | bin_id = int(theta.round())
84 | if bin_id == 180:
85 | bin_id = 0
86 | for offset in range(-3, 4):
87 | bin_idx = bin_id + offset
88 | if bin_idx >= 180:
89 | bin_idx -= 180
90 | bins[bin_idx] += np.sqrt(edge[1] ** 2 + edge[0] ** 2) * gauss_weights[offset + 2]
91 |
92 | bins = np.array(bins)
93 | sorted_ids = np.argsort(bins)[::-1]
94 | bin_1 = sorted_ids[0]
95 | remained_ids = [idx for idx in sorted_ids if angle_dist(bin_1, idx) >= 30]
96 | bin_2 = remained_ids[0]
97 | if bin_1 < bin_2:
98 | bin_1, bin_2 = bin_2, bin_1
99 |
100 | dir_1, dir_2 = bin_1, bin_2
101 | # compute the affine parameters, and apply affine transform to the image
102 | origin = [127, 127]
103 | p1_old = [127 + 100 * np.cos(dir_1 / 180 * np.pi), 127 - 100 * np.sin(dir_1 / 180 * np.pi)]
104 | p2_old = [127 + 100 * np.cos(dir_2 / 180 * np.pi), 127 - 100 * np.sin(dir_2 / 180 * np.pi)]
105 | pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
106 | p1_new = [127, 27] # y_axis
107 | p2_new = [227, 127] # x_axis
108 | pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
109 |
110 | M1 = cv2.getAffineTransform(pts1, pts2)
111 |
112 | all_corners = list(annot.keys())
113 | all_corners_ = np.array(all_corners)
114 | ones = np.ones([all_corners_.shape[0], 1])
115 | all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
116 | new_corners = np.matmul(M1, all_corners_.T).T
117 |
118 | M = np.concatenate([M1, np.array([[0, 0, 1]])], axis=0)
119 |
120 | x_max = new_corners[:, 0].max()
121 | x_min = new_corners[:, 0].min()
122 | y_max = new_corners[:, 1].max()
123 | y_min = new_corners[:, 1].min()
124 |
125 | side_x = (x_max - x_min) * 0.1
126 | side_y = (y_max - y_min) * 0.1
127 | right_border = x_max + side_x
128 | left_border = x_min - side_x
129 | bot_border = y_max + side_y
130 | top_border = y_min - side_y
131 | pts1 = np.array([[left_border, top_border], [right_border, top_border], [right_border, bot_border]]).astype(
132 | np.float32)
133 | pts2 = np.array([[5, 5], [250, 5], [250, 250]]).astype(np.float32)
134 | M_scale = cv2.getAffineTransform(pts1, pts2)
135 |
136 | M = np.matmul(np.concatenate([M_scale, np.array([[0, 0, 1]])], axis=0), M)
137 |
138 | new_image = cv2.warpAffine(image, M[:2, :], (cols, rows), borderValue=(255, 255, 255))
139 | all_corners_ = np.concatenate([all_corners, ones], axis=-1)
140 | new_corners = np.matmul(M[:2, :], all_corners_.T).T
141 |
142 | corner_mapping = dict()
143 | for idx, corner in enumerate(all_corners):
144 | corner_mapping[corner] = new_corners[idx]
145 |
146 | new_annot = dict()
147 | for corner, connections in annot.items():
148 | new_corner = corner_mapping[corner]
149 | tuple_new_corner = tuple(new_corner)
150 | new_annot[tuple_new_corner] = list()
151 | for to_corner in connections:
152 | new_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
153 |
154 | # do the affine transform
155 | return new_image, new_annot, M
156 |
157 |
158 | def angle_dist(a1, a2):
159 | if a1 > a2:
160 | a1, a2 = a2, a1
161 | d1 = a2 - a1
162 | d2 = a1 + 180 - a2
163 | dist = min(d1, d2)
164 | return dist
165 |
--------------------------------------------------------------------------------
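`corner_eval` greedily matches each ground-truth corner to its nearest still-unclaimed prediction within 5 pixels and reports precision and recall over those matches. A toy check (values made up; the import path is an assumption):

```python
import numpy as np
from utils.geometry_utils import corner_eval  # import path is an assumption

targets = np.array([[10.0, 10.0], [50.0, 50.0]])                 # ground-truth corners
outputs = np.array([[11.0, 9.0], [49.0, 52.0], [200.0, 200.0]])  # predicted corners

prec, recall = corner_eval(targets, outputs)
print(prec, recall)  # 2 of 3 predictions matched -> 0.666..., 2 of 2 targets matched -> 1.0
```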
/code/learn/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | from collections import defaultdict, deque
4 | import datetime
5 | from typing import Optional, List
6 | from torch import Tensor
7 |
8 |
9 | @torch.no_grad()
10 | def accuracy(output, target, topk=(1,)):
11 | """Computes the precision@k for the specified values of k"""
12 | if target.numel() == 0:
13 | return [torch.zeros([], device=output.device)]
14 | maxk = max(topk)
15 | batch_size = target.size(0)
16 |
17 | _, pred = output.topk(maxk, 1, True, True)
18 | pred = pred.t()
19 | correct = pred.eq(target.view(1, -1).expand_as(pred))
20 |
21 | res = []
22 | for k in topk:
23 | correct_k = correct[:k].view(-1).float().sum(0)
24 | res.append(correct_k.mul_(100.0 / batch_size))
25 | return res
26 |
27 |
28 | class SmoothedValue(object):
29 | """Track a series of values and provide access to smoothed values over a
30 | window or the global series average.
31 | """
32 |
33 | def __init__(self, window_size=100, fmt=None):
34 | if fmt is None:
35 | fmt = "{median:.3f} ({global_avg:.3f})"
36 | self.deque = deque(maxlen=window_size)
37 | self.total = 0.0
38 | self.count = 0
39 | self.fmt = fmt
40 |
41 | def update(self, value, n=1):
42 | self.deque.append(value)
43 | self.count += n
44 | self.total += value * n
45 |
46 | @property
47 | def median(self):
48 | d = torch.tensor(list(self.deque))
49 | return d.median().item()
50 |
51 | @property
52 | def avg(self):
53 | d = torch.tensor(list(self.deque), dtype=torch.float32)
54 | return d.mean().item()
55 |
56 | @property
57 | def global_avg(self):
58 | return self.total / self.count
59 |
60 | @property
61 | def max(self):
62 | return max(self.deque)
63 |
64 | @property
65 | def value(self):
66 | #return self.deque[-1]
67 | return self.avg
68 |
69 | def __str__(self):
70 | return self.fmt.format(
71 | median=self.median,
72 | avg=self.avg,
73 | global_avg=self.global_avg,
74 | max=self.max,
75 | value=self.value)
76 |
77 |
78 | class MetricLogger(object):
79 | def __init__(self, delimiter="\t"):
80 | self.meters = defaultdict(SmoothedValue)
81 | self.delimiter = delimiter
82 |
83 | def update(self, **kwargs):
84 | for k, v in kwargs.items():
85 | if isinstance(v, torch.Tensor):
86 | v = v.item()
87 | assert isinstance(v, (float, int))
88 | self.meters[k].update(v)
89 |
90 | def __getattr__(self, attr):
91 | if attr in self.meters:
92 | return self.meters[attr]
93 | if attr in self.__dict__:
94 | return self.__dict__[attr]
95 | raise AttributeError("'{}' object has no attribute '{}'".format(
96 | type(self).__name__, attr))
97 |
98 | def __str__(self):
99 | loss_str = []
100 | for name, meter in self.meters.items():
101 | loss_str.append(
102 | "{}: {}".format(name, str(meter))
103 | )
104 | return self.delimiter.join(loss_str)
105 |
106 | def synchronize_between_processes(self):
107 | for meter in self.meters.values():
108 | meter.synchronize_between_processes()
109 |
110 | def add_meter(self, name, meter):
111 | self.meters[name] = meter
112 |
113 | def log_every(self, iterable, print_freq, header=None, length_total=None):
114 | i = 0
115 | if length_total is None:
116 | length_total = len(iterable)
117 | if not header:
118 | header = ''
119 | start_time = time.time()
120 | end = time.time()
121 | iter_time = SmoothedValue(fmt='{avg:.4f}')
122 | data_time = SmoothedValue(fmt='{avg:.4f}')
123 | space_fmt = ':' + str(len(str(length_total))) + 'd'
124 | if torch.cuda.is_available():
125 | log_msg = self.delimiter.join([
126 | header,
127 | '[{0' + space_fmt + '}/{1}]',
128 | 'eta: {eta}',
129 | '{meters}',
130 | 'time: {time}',
131 | 'data: {data}',
132 | 'max mem: {memory:.0f}'
133 | ])
134 | else:
135 | log_msg = self.delimiter.join([
136 | header,
137 | '[{0' + space_fmt + '}/{1}]',
138 | 'eta: {eta}',
139 | '{meters}',
140 | 'time: {time}',
141 | 'data: {data}'
142 | ])
143 | MB = 1024.0 * 1024.0
144 | for obj in iterable:
145 | data_time.update(time.time() - end)
146 | yield obj
147 | iter_time.update(time.time() - end)
148 | if i % print_freq == 0 or i == length_total - 1:
149 | eta_seconds = iter_time.global_avg * (length_total - i)
150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
151 | if torch.cuda.is_available():
152 | try:
153 | print(log_msg.format(
154 | i, length_total, eta=eta_string,
155 | meters=str(self),
156 | time=str(iter_time), data=str(data_time),
157 | memory=torch.cuda.max_memory_allocated() / MB))
158 |                         except Exception:
159 |                             raise
160 | else:
161 | print(log_msg.format(
162 | i, length_total, eta=eta_string,
163 | meters=str(self),
164 | time=str(iter_time), data=str(data_time)))
165 | i += 1
166 | end = time.time()
167 | total_time = time.time() - start_time
168 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
169 | print('{} Total time: {} ({:.4f} s / it)'.format(
170 | header, total_time_str, total_time / length_total))
171 |
172 |
173 | class NestedTensor(object):
174 | def __init__(self, tensors, mask: Optional[Tensor]):
175 | self.tensors = tensors
176 | self.mask = mask
177 |
178 | def to(self, device, non_blocking=False):
179 | # type: (Device) -> NestedTensor # noqa
180 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
181 | mask = self.mask
182 | if mask is not None:
183 | assert mask is not None
184 | cast_mask = mask.to(device, non_blocking=non_blocking)
185 | else:
186 | cast_mask = None
187 | return NestedTensor(cast_tensor, cast_mask)
188 |
189 | def record_stream(self, *args, **kwargs):
190 | self.tensors.record_stream(*args, **kwargs)
191 | if self.mask is not None:
192 | self.mask.record_stream(*args, **kwargs)
193 |
194 | def decompose(self):
195 | return self.tensors, self.mask
196 |
197 | def __repr__(self):
198 | return str(self.tensors)
199 |
--------------------------------------------------------------------------------
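`SmoothedValue` keeps a sliding window of scalars and exposes the window median/average alongside a global running average; `MetricLogger` bundles a dict of such meters with a formatted `__str__` and the `log_every` iterator wrapper. A minimal sketch (the import path is an assumption):

```python
from utils.misc import MetricLogger  # import path is an assumption

logger = MetricLogger(delimiter="  ")
for step in range(5):
    logger.update(loss=0.5 / (step + 1), acc=0.80 + 0.01 * step)

# each meter prints as "median (global_avg)", e.g. "loss: 0.167 (0.228)  acc: 0.820 (0.820)"
print(str(logger))
```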
/code/learn/utils/nn_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 |
5 | def pos_encode_2d(x, y, d_model=128):
6 | assert len(x) == len(y)
7 | pe = torch.zeros(len(x), d_model, device=x.device)
8 |
9 | d_model = int(d_model / 2)
10 | div_term = torch.exp(torch.arange(0., d_model, 2) *
11 | -(math.log(10000.0) / d_model))
12 | div_term = div_term.to(x.device)
13 |
14 |     pos_w = x.clone().float().unsqueeze(1)
15 |     pos_h = y.clone().float().unsqueeze(1)
16 |
17 | pe[:, 0:d_model:2] = torch.sin(pos_w * div_term)
18 | pe[:, 1:d_model:2] = torch.cos(pos_w * div_term)
19 | pe[:, d_model::2] = torch.sin(pos_h * div_term)
20 | pe[:, d_model+1::2] = torch.cos(pos_h * div_term)
21 |
22 | return pe
23 |
24 |
25 | def positional_encoding_2d(d_model, height, width):
26 | """
27 | :param d_model: dimension of the model
28 | :param height: height of the positions
29 | :param width: width of the positions
30 | :return: d_model*height*width position matrix
31 | """
32 | if d_model % 4 != 0:
33 | raise ValueError("Cannot use sin/cos positional encoding with "
34 | "odd dimension (got dim={:d})".format(d_model))
35 | pe = torch.zeros(d_model, height, width)
36 | # Each dimension use half of d_model
37 | d_model = int(d_model / 2)
38 | div_term = torch.exp(torch.arange(0., d_model, 2) *
39 | -(math.log(10000.0) / d_model))
40 | pos_w = torch.arange(0., width).unsqueeze(1)
41 | pos_h = torch.arange(0., height).unsqueeze(1)
42 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
43 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
44 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
45 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
46 |
47 | return pe
48 |
49 |
50 | def positional_encoding_1d(d_model, length):
51 | """
52 | :param d_model: dimension of the model
53 | :param length: length of positions
54 | :return: length*d_model position matrix
55 | """
56 | if d_model % 2 != 0:
57 | raise ValueError("Cannot use sin/cos positional encoding with "
58 | "odd dim (got dim={:d})".format(d_model))
59 | pe = torch.zeros(length, d_model)
60 | position = torch.arange(0, length).unsqueeze(1)
61 | div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
62 | -(math.log(10000.0) / d_model)))
63 | pe[:, 0::2] = torch.sin(position.float() * div_term)
64 | pe[:, 1::2] = torch.cos(position.float() * div_term)
65 |
66 | return pe
67 |
--------------------------------------------------------------------------------
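All three helpers implement the standard sinusoidal positional encoding; the `div_term` computed above is exactly the $10000^{-2i/d}$ factor. For a scalar position $pos$ and channel pair index $i$:

$$
PE(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad
PE(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)
$$

where $d$ is `d_model` in the 1-D version and `d_model/2` per axis in the 2-D versions (the first half of the channels encodes the x/width coordinate, the second half y/height).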
/requirements.txt:
--------------------------------------------------------------------------------
1 | Cython==0.29.22
2 | defusedxml==0.6.0
3 | einops==0.4.1
4 | future==0.18.2
5 | imageio==2.16.1
6 | numpy==1.20.1
7 | opencv-python==4.4.0.44
8 | matplotlib==3.3.4
9 | packaging==20.9
10 | Pillow==9.0.1
11 | prometheus-client==0.9.0
12 | prompt-toolkit==3.0.16
13 | ptyprocess==0.7.0
14 | pycparser==2.20
15 | Pygments==2.8.0
16 | python-dateutil==2.8.1
17 | scikit-image==0.19.2
18 | scikit-learn==1.0
19 | scipy==1.6.1
20 | six==1.15.0
21 | cairosvg==2.5.2
22 | svgwrite==1.4.2
23 | shapely==1.8.2
--------------------------------------------------------------------------------
/resources/addin_warning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/resources/addin_warning.png
--------------------------------------------------------------------------------
/resources/crop.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/resources/crop.gif
--------------------------------------------------------------------------------
/resources/rotate.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/resources/rotate.gif
--------------------------------------------------------------------------------
/resources/translate.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/resources/translate.gif
--------------------------------------------------------------------------------
/setup_env.sh:
--------------------------------------------------------------------------------
1 | # for automatic conda environment setup
2 | . "$HOME/miniconda3/etc/profile.d/conda.sh"
3 | conda remove -y -n bim --all
4 | conda create -y -n bim python=3.8
5 | conda activate bim
6 |
7 | conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
8 |
9 | pip install -r requirements.txt
10 |
11 | cd code/learn/models/ops
12 | python setup.py build install
13 |
14 | conda install -y tqdm shapely
15 | pip install tensorboard rtree shapely pytorch-metric-learning laspy[lazrs] open3d typer[all]
--------------------------------------------------------------------------------
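`setup_env.sh` is meant to be run from the repository root, e.g. `. setup_env.sh` (sourcing it leaves the `bim` environment activated in the current shell; `bash setup_env.sh` also works for the install steps). It recreates the `bim` conda environment with Python 3.8 and PyTorch 1.12.1 / cudatoolkit 11.3, installs the pinned requirements, and builds the deformable-attention ops under `code/learn/models/ops`; that last step assumes a CUDA-capable build toolchain is available.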