├── .gitignore ├── QUICKSTART.md ├── README.md ├── code ├── learn │ ├── backend.py │ ├── datasets │ │ ├── __init__.py │ │ ├── _init_paths.py │ │ ├── building_ae.py │ │ ├── building_corners.py │ │ ├── building_corners_full.py │ │ ├── building_full.py │ │ ├── building_order_enum.py │ │ └── data_util.py │ ├── eval_order.py │ ├── metrics │ │ ├── get_metric.py │ │ └── new_utils.py │ ├── models │ │ ├── __init__.py │ │ ├── _init_paths.py │ │ ├── corner_models.py │ │ ├── corner_to_edge.py │ │ ├── dat_blocks.py │ │ ├── deformable_transformer.py │ │ ├── deformable_transformer_full.py │ │ ├── deformable_transformer_original.py │ │ ├── drn.py │ │ ├── edge_full_models.py │ │ ├── edge_models.py │ │ ├── full_models.py │ │ ├── loss.py │ │ ├── mlp.py │ │ ├── ops │ │ │ ├── .gitignore │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn_func.py │ │ │ ├── make.sh │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn.py │ │ │ ├── setup.py │ │ │ ├── src │ │ │ │ ├── cpu │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ ├── cuda │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ ├── ms_deform_attn.h │ │ │ │ └── vision.cpp │ │ │ └── test.py │ │ ├── order_class_models.py │ │ ├── order_metric_models.py │ │ ├── stacked_hg.py │ │ ├── tcn.py │ │ ├── unet.py │ │ └── utils.py │ ├── my_utils.py │ ├── timer.py │ ├── train_corner.py │ ├── train_edge.py │ ├── train_metric.py │ ├── train_order_class.py │ ├── train_order_metric.py │ └── utils │ │ ├── __init__.py │ │ ├── _init_paths.py │ │ ├── geometry_utils.py │ │ ├── misc.py │ │ └── nn_utils.py └── preprocess │ └── data_gen.py ├── requirements.txt ├── resources ├── addin_warning.png ├── crop.gif ├── rotate.gif └── translate.gif └── setup_env.sh /.gitignore: -------------------------------------------------------------------------------- 1 | ./data/ 2 | ./ckpts/ 3 | data 4 | data_small/ 5 | ckpts 6 | *_full 7 | __pycache__/ 8 | .vscode/ 9 | *.swp 10 | *.zip 11 | fid_csv/ 12 | -------------------------------------------------------------------------------- /QUICKSTART.md: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | ## Plugin installation (Windows machine) 4 | 5 | Download the quickstart package [here](https://www.dropbox.com/scl/fi/jljkehuddx3df6hf6ptau/quickstart.zip?rlkey=bzxi1b13r00s6u29drkziazgv&dl=0), and unzip to somewhere on your computer. Yours should look like this: 6 | ``` 7 | . 8 | ├── ckpts/ 9 | ├── code/ 10 | ├── data/ 11 | ├── plugins/ 12 | ├── plugins_built/ 13 | └── ... 14 | ``` 15 | 16 | Install the Revit plugin by copying the content of `plugins_built/` to: 17 | ``` 18 | %APPDATA%\Autodesk\Revit\Addins\2022 19 | ``` 20 | You can navigate to that folder by copying the path above and paste into File Explorer's address bar. Change the year number at the end if you have a different Revit version. 21 | 22 | Start up Revit. 23 | You should see the warning below when you start up Revit for the first time after installing the plugin. 24 | Click "Always Load" if you don't want to see this warning again. 25 | If you do not see the warning, then the plugin has not been installed correctly. 26 | 27 | ![Addin warning](resources/addin_warning.png) 28 | 29 | ## Server setup (Windows or Linux) 30 | 31 | We use Miniconda 3 to manage the python environment, installed at `$HOME` directory. 
32 | Installation instructions for Miniconda can be found [here](https://docs.conda.io/projects/miniconda/en/latest/miniconda-install.html). 33 | 34 | Once you have Miniconda installed, run the provided script: `sh ./setup_env.sh`. 35 | 36 | ## Using our provided example point cloud 37 | 38 | If you would like to use your own point cloud, skip to the section below. 39 | 40 | Open up `data/revit_projects/32_ShortOffice_05_F2.rvt` inside Revit. 41 | 42 | Now you need to link the provided ReCap project, which contains the point cloud. 43 | 44 | From the top ribbon menu, click on `Insert -> Manage Links -> Point Clouds -> 32_ShortOffice... -> Reload From...`. 45 | 46 | Open `data/recap_projects/32_ShortOffice_05_F2/32_ShortOffice_05_F2_s0p01m.rcp`. 47 | 48 | Finally, click "OK" and you should see the point cloud. 49 | 50 | Switch to the server machine and run the following commands to start the server: 51 | ``` 52 | cd code/learn 53 | conda activate bim 54 | python backend.py demo-floor 55 | ``` 56 | 57 | You are now ready to use the assistive system. 58 | 59 | ## Using your own point cloud 60 | 61 | Using your own point cloud involves a few more steps, as the backend server needs to understand 62 | the relationship between its internal coordinate system and the Revit coordinate system. 63 | 64 | As prerequisites, your point cloud should be in LAZ format and also in a ReCap project. 65 | Please prepare the ReCap project before proceeding to the next step. 66 | 67 | Open up `data/revit_projects/example.rvt` inside Revit. 68 | 69 | Import the point cloud from the top ribbon menu: `Insert -> Point Clouds`. 70 | 71 | Transform the point cloud inside Revit such that the ground plane is at Level 1 and that the majority of the walls are axis-aligned. Please see the GIFs below for how to do the transforms. 72 | 73 |
74 | Translate ground plane to level 1 75 | 76 | ![Translate ground plane](resources/translate.gif) 77 | 78 |
79 | 80 |
81 | 82 | Rotate walls so they are axis-aligned 83 | 84 | ![Axis-aligned rotation](resources/rotate.gif) 85 | 86 |
87 | 88 |
89 | 90 | Use a section box to define the rough bounding box of the point cloud. The section box parameters 91 | are used by the backend server for processing later on. 92 | 93 | ![Section box](resources/crop.gif) 94 | 95 | Finally, save out the transforms by clicking from the top ribbon menu: `Add-Ins -> Save Transform`. You should see a file named `transform.txt` inside the same folder as your Revit project. If your server machine is different from your client machine, please upload the transform text file to somewhere on your server. 96 | 97 | Switch to the server machine and run the following commands to start the server: 98 | ``` 99 | cd code/learn 100 | conda activate bim 101 | python backend.py demo-user \ 102 | --floor-name <floor-name> \ 103 | --laz-f <path-to-laz-file> \ 104 | --laz-transform-f <path-to-transform-file> \ 105 | --corner-ckpt-f ../../ckpts/corner/11/checkpoint.pth \ 106 | --edge-ckpt-f ../../ckpts/edge_sample_16/11/checkpoint.pth \ 107 | --metric-ckpt-f ../../ckpts/order_metric/11/checkpoint_latest.pth 108 | ``` 109 | 110 | Replace `<...>` with your own parameters. Wait until all preprocessing is done; it may take 10 minutes depending on your computer hardware. Once it says "Server listening", you are now ready to use the assistive system. 111 | 112 | ## Using the assistive system 113 | 114 | To make the plugin easier to use, you should set up some keyboard shortcuts. 115 | 116 | From the top-left, click on `File -> Options -> User Interface -> Keyboard Shortcuts: Customize -> Filter: Add-Ins Tab`. 117 | Bind the following commands to the corresponding keys: 118 | 119 | | Command | Shortcuts | 120 | |----------------:|:---| 121 | |Obtain Prediction| F2 | 122 | |Add Corner | F3 | 123 | |Send Corners | F4 | 124 | |Next / Accept Green | 4 | 125 | |Reject / Accept Yellow | 5 | 126 | |Accept Red | 6 | 127 | 128 | To enable assistance, first go to the `Add-Ins` tab and click on `Autocomplete`. 129 | The icon should change to a pause logo. 130 | 131 | Now manually draw one wall, hit the Escape key twice to exit the "Modify Wall" mode, and you should see the next three suggested walls in solid (next) and dashed (subsequent) pink lines. 132 | 133 | Run the `Next / Accept Green` command to accept the solid pink line as the next wall to add. 134 | You may interleave manual drawing and the accept commands however you like. 135 | 136 | You may also run `Reject / Accept Yellow` to choose one of three candidate walls to add next. 137 | Run the corresponding command to accept the colored suggestion. 138 | 139 | To simplify wall drawing, you may also provide wall junctions and query the backend to automatically infer the relevant walls. 140 | This has the benefit of adding multiple walls at once, especially around higher-degree junctions. 141 | 142 | To do so: 143 | 1. Hover the mouse over the junction in the point cloud, and hit the `Add Corner` shortcut. 144 | 2. (Optional) Drag the ring to modify its location. 145 | 3. Once the desired junctions are added, hit `Send Corners`. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A-Scan2BIM 2 | 3 | Official implementation of the paper [A-Scan2BIM: Assistive Scan to Building Information Modeling](https://drive.google.com/file/d/1zvGfdlLYbd_oAp7Oc-1vF2Czhl1A7q75/view) (__BMVC 2023, oral__) 4 | 5 | Please also visit our [project website](https://a-scan2bim.github.io/) for a video demo of our assistive system.
6 | 7 | # Updates 8 | 9 | [06/14/2024] 10 | Updated our assistive demo so it accepts a user-provided point cloud. Please visit [here](QUICKSTART.md) for instructions. 11 | 12 | [12/08/2023] 13 | We are unable to release data for two floors, as the point clouds are unfortunately not available for download. 14 | 15 | [11/30/2023] 16 | Due to a bug in our heuristic baseline method, we have updated our order evaluations. Please see the updated arXiv paper for the latest results. 17 | 18 | 19 | # Table of contents (also TODOs) 20 | 21 | - [x] [Prerequisites](#prerequisites) 22 | - [x] [Quickstart](#quickstart) 23 | - [x] [Training and evaluation](#training-and-evaluation) 24 | - [ ] [Collecting your own data](#collecting-your-own-data) 25 | - [ ] [Building and developing the plugin](#building-and-developing-the-plugin) 26 | - [x] [Contact](#contact) 27 | - [x] [Bibtex](#bibtex) 28 | - [x] [Acknowledgment](#acknowledgment) 29 | 30 | # Prerequisites 31 | 32 | For more flexibility, our system is designed as a client-server model: a Revit plugin written in C# serves as the front end, and a Python server performs all the neural network computation. 33 | 34 | To run our system, there are two options: 35 | 36 | 1. Run both the plugin and the server on the same machine 37 | 2. Run the server on a separate machine, and port-forward to the client via the local network or SSH 38 | 39 | The advantage of option 2 over 1 is that the server can run on either Linux or Windows machines, and can potentially serve multiple clients at the same time. 40 | Of course, the client machine needs to be Windows as Revit is Windows-only. 41 | 42 | The code has been tested with Revit 2022. 43 | Other versions will likely work but are not verified. 44 | If you are a student or an educator, you can get it for free [here](https://www.autodesk.com/education/edu-software/overview?sorting=featured&filters=individual). 45 | 46 | # Quickstart 47 | 48 | To try out our assistive system, please visit [here](QUICKSTART.md) for instructions. 49 | 50 | # Training and evaluation 51 | 52 | At a high level, our system consists of two components: 53 | a candidate wall enumeration network, 54 | and a next wall prediction network. 55 | They are trained separately; we describe the training process step by step below. 56 | 57 | ## Data/environment preparation 58 | 59 | First, fill out the data request form [here](https://forms.gle/Apg86MauTep2KTxx8). 60 | We will be in contact with you to provide the data download links. 61 | Once you have the data, unzip it in the root directory of the repository. 62 | 63 | Since we do not own the point clouds, please download them from the [workshop website](https://cv4aec.github.io/). You would need the data from the [3D challenge](https://codalab.lisn.upsaclay.fr/competitions/12405), both the train and test data. Rename and move all the LAZ files into the data folder as shown below: 64 | 65 | ``` 66 | data/ 67 | ├── history/ 68 | ├── transforms/ 69 | ├── all_floors.txt 70 | └── laz/ 71 | ├── 05_MedOffice_01_F2.laz 72 | ├── 06_MedOffice_02_F1.laz 73 | ├── 07_MedOffice_03_F3.laz 74 | ├── 08_ShortOffice_01_F2.laz 75 | ├── 11_MedOffice_05_F2.laz 76 | ├── 11_MedOffice_05_F4.laz 77 | ├── 19_MedOffice_07_F4.laz 78 | ├── 25_Parking_01_F1.laz 79 | ├── 32_ShortOffice_05_F1.laz 80 | ├── 32_ShortOffice_05_F2.laz 81 | ├── 32_ShortOffice_05_F3.laz 82 | ├── 33_SmallBuilding_03_F1.laz 83 | ├── 35_Lab_02_F1.laz 84 | └── 35_Lab_02_F2.laz 85 | ``` 86 | 87 | To install all Python dependencies, run the provided script: `sh ./setup_env.sh`.
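Before running the preprocessing step below, it can help to verify that the data folder matches the layout above. Here is a minimal, optional sanity check (a sketch, not part of the repository; it assumes you run it from the repository root and that the file names match the tree shown above):
```
from pathlib import Path

# Expected floor scans, matching the laz/ listing above.
EXPECTED_LAZ = [
    "05_MedOffice_01_F2.laz", "06_MedOffice_02_F1.laz", "07_MedOffice_03_F3.laz",
    "08_ShortOffice_01_F2.laz", "11_MedOffice_05_F2.laz", "11_MedOffice_05_F4.laz",
    "19_MedOffice_07_F4.laz", "25_Parking_01_F1.laz", "32_ShortOffice_05_F1.laz",
    "32_ShortOffice_05_F2.laz", "32_ShortOffice_05_F3.laz", "33_SmallBuilding_03_F1.laz",
    "35_Lab_02_F1.laz", "35_Lab_02_F2.laz",
]

root = Path("data")

# Check the top-level entries from the tree above.
for entry in ["history", "transforms", "all_floors.txt", "laz"]:
    if not (root / entry).exists():
        print(f"Missing: {root / entry}")

# Check that every expected LAZ file is in place.
missing = [name for name in EXPECTED_LAZ if not (root / "laz" / name).exists()]
print("Missing LAZ files:", missing if missing else "none")
```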
88 | 89 | Then run the following command to preprocess the data: 90 | ``` 91 | cd code/preprocess 92 | python data_gen.py 93 | ``` 94 | 95 | Please also download the pretrained models from [here](https://www.dropbox.com/scl/fi/cwhgu92a6ndl212nls59i/ckpts_full.zip?rlkey=pabethcn0w0rxqk0k5x1da0dv&dl=0) and extract them to the root directory of the repository. 96 | Some weights are necessary for network initialization (`pretrained/`) or evaluation (`ae/`), but others you may delete if training from scratch. 97 | 98 | ## Training candidate wall enumerator 99 | 100 | We borrow the HEAT architecture, which consists of two components: a corner detector and an edge classifier. 101 | However, in our case we do not train end-to-end due to the large input size. 102 | 103 | ``` 104 | cd code/learn 105 | 106 | # Step 1: train corner detector 107 | for i in {0..15} 108 | do 109 | python train_corner.py --test_idx $i 110 | done 111 | 112 | # Step 2: cache detected corners to disk 113 | python backend.py export_corners 114 | 115 | # Step 3: train edge classifier 116 | for i in {0..15} 117 | do 118 | python train_edge.py --test_idx $i 119 | done 120 | 121 | # Step 4: cache detected edges to disk 122 | python backend.py save_edge_preds 123 | ``` 124 | 125 | Note that we do leave-one-out (per-floor) cross-validation, hence the for loops in steps 1 and 3. 126 | 127 | ## Training next wall predictor 128 | 129 | Our next wall predictor and the classifier baseline method are both trained on GT walls and order. 130 | 131 | Training is initialized from the pretrained weights of the previous candidate wall enumeration task (the variant where only one reference point is used). 132 | 133 | We have provided the pretrained weights for your convenience (see above for the link to the pretrained checkpoints), but they may not be necessary if training for your own task. 134 | 135 | ``` 136 | cd code/learn 137 | 138 | # Training our method (metric-learning based) 139 | for i in {0..15} 140 | do 141 | python train_order_metric.py --test_idx $i 142 | done 143 | 144 | # Training classifier baseline 145 | for i in {0..15} 146 | do 147 | python train_order_class.py --test_idx $i 148 | done 149 | ``` 150 | 151 | ## Evaluation 152 | 153 | To evaluate reconstruction metrics: 154 | ``` 155 | python backend.py compute-metrics 156 | ``` 157 | 158 | To evaluate order metrics: 159 | ``` 160 | python eval_order.py eval-all-floors 161 | python eval_order.py plot-seq-FID 162 | ``` 163 | 164 | To evaluate entropy and accuracy of next wall prediction: 165 | ``` 166 | python eval_order.py eval-entropy 167 | python eval_order.py eval-all-acc-wrt-history 168 | ``` 169 | 170 | # Collecting your own data 171 | 172 | Coming soon... 173 | 174 | # Building and developing the plugin 175 | 176 | Coming soon... 177 | 178 | # Contact 179 | 180 | Weilian Song, weilians@sfu.ca 181 | 182 | # Bibtex 183 | ``` 184 | @article{weilian2023ascan2bim, 185 | author = {Song, Weilian and Luo, Jieliang and Zhao, Dale and Fu, Yan and Cheng, Chin-Yi and Furukawa, Yasutaka}, 186 | title = {A-Scan2BIM: Assistive Scan to Building Information Modeling}, 187 | journal = {British Machine Vision Conference (BMVC)}, 188 | year = {2023}, 189 | } 190 | ``` 191 | 192 | # Acknowledgment 193 | 194 | This research is partially supported by NSERC Discovery Grants with Accelerator Supplements and the DND/NSERC Discovery Grant Supplement, NSERC Alliance Grants, and the John R. Evans Leaders Fund (JELF).
195 | 196 | We are also grateful to the [CV4AEC CVPR workshop](https://cv4aec.github.io/) for providing the point clouds. 197 | 198 | And finally, much of the plugin code was borrowed from Jeremy Tammik's sample add-in [WinTooltip](https://github.com/jeremytammik/WinTooltip). 199 | Hats off :tophat: to Jeremy for all the wonderful tutorials he has written over the years. -------------------------------------------------------------------------------- /code/learn/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/code/learn/datasets/__init__.py -------------------------------------------------------------------------------- /code/learn/datasets/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | 5 | def add_path(path): 6 | if path not in sys.path: sys.path.insert(0, path) 7 | 8 | 9 | this_dir = osp.dirname(__file__) 10 | 11 | project_path = osp.abspath(osp.join(this_dir, '..')) 12 | add_path(project_path) -------------------------------------------------------------------------------- /code/learn/datasets/data_util.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageFilter 2 | from torchvision import transforms 3 | 4 | 5 | def RandomBlur(radius=2.): 6 | blur = GaussianBlur(radius=radius) 7 | full_transform = transforms.RandomApply([blur], p=.3) 8 | return full_transform 9 | 10 | 11 | class ImageFilterTransform(object): 12 | 13 | def __init__(self): 14 | raise NotImplementedError 15 | 16 | def __call__(self, img): 17 | return img.filter(self.filter) 18 | 19 | 20 | class GaussianBlur(ImageFilterTransform): 21 | 22 | def __init__(self, radius=2.): 23 | self.filter = ImageFilter.GaussianBlur(radius=radius) 24 | -------------------------------------------------------------------------------- /code/learn/metrics/get_metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | import cv2 5 | from metrics.new_utils import * 6 | 7 | 8 | class Metric: 9 | def calc(self, gt_data, conv_data, thresh=8.0, iou_thresh=0.7): 10 | ### compute corners precision/recall 11 | gts = gt_data["corners"] 12 | dets = conv_data["corners"] 13 | 14 | per_sample_corner_tp = 0.0 15 | per_sample_corner_fp = 0.0 16 | per_sample_corner_length = gts.shape[0] 17 | found = [False] * gts.shape[0] 18 | c_det_annot = {} 19 | 20 | # for each corner detection 21 | for i, det in enumerate(dets): 22 | # get closest gt 23 | near_gt = [0, 999999.0, (0.0, 0.0)] 24 | for k, gt in enumerate(gts): 25 | dist = np.linalg.norm(gt - det) 26 | if dist < near_gt[1]: 27 | near_gt = [k, dist, gt] 28 | if near_gt[1] <= thresh and not found[near_gt[0]]: 29 | per_sample_corner_tp += 1.0 30 | found[near_gt[0]] = True 31 | c_det_annot[i] = near_gt[0] 32 | else: 33 | per_sample_corner_fp += 1.0 34 | 35 | per_corner_score = { 36 | "recall": per_sample_corner_tp / gts.shape[0], 37 | "precision": per_sample_corner_tp 38 | / (per_sample_corner_tp + per_sample_corner_fp + 1e-8), 39 | } 40 | 41 | ### compute edges precision/recall 42 | per_sample_edge_tp = 0.0 43 | per_sample_edge_fp = 0.0 44 | edge_corner_annots = gt_data["edges"] 45 | per_sample_edge_length = edge_corner_annots.shape[0] 46 | 47 | false_edge_ids = [] 48 | match_pred_ids = set() 49 | 50 | for l, e_det in 
enumerate(conv_data["edges"]): 51 | c1, c2 = e_det 52 | 53 | # check if corners are mapped 54 | if (c1 not in c_det_annot.keys()) or (c2 not in c_det_annot.keys()): 55 | per_sample_edge_fp += 1.0 56 | false_edge_ids.append(l) 57 | continue 58 | # check hit 59 | c1_prime = c_det_annot[c1] 60 | c2_prime = c_det_annot[c2] 61 | is_hit = False 62 | 63 | for k, e_annot in enumerate(edge_corner_annots): 64 | c3, c4 = e_annot 65 | if ((c1_prime == c3) and (c2_prime == c4)) or ( 66 | (c1_prime == c4) and (c2_prime == c3) 67 | ): 68 | is_hit = True 69 | match_pred_ids.add(l) 70 | break 71 | 72 | # hit 73 | if is_hit: 74 | per_sample_edge_tp += 1.0 75 | else: 76 | per_sample_edge_fp += 1.0 77 | false_edge_ids.append(l) 78 | 79 | per_edge_score = { 80 | "recall": per_sample_edge_tp / edge_corner_annots.shape[0], 81 | "precision": per_sample_edge_tp 82 | / (per_sample_edge_tp + per_sample_edge_fp + 1e-8), 83 | } 84 | 85 | # computer regions precision/recall 86 | conv_mask = render( 87 | corners=conv_data["corners"], 88 | edges=conv_data["edges"], 89 | render_pad=0, 90 | edge_linewidth=1, 91 | )[0] 92 | conv_mask = 1 - conv_mask 93 | conv_mask = conv_mask.astype(np.uint8) 94 | labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4) 95 | 96 | # cv2.imwrite('mask-pred.png', region_mask.astype(np.uint8) * 20) 97 | 98 | background_label = region_mask[0, 0] 99 | all_conv_masks = [] 100 | for region_i in range(1, labels): 101 | if region_i == background_label: 102 | continue 103 | the_region = region_mask == region_i 104 | if the_region.sum() < 20: 105 | continue 106 | all_conv_masks.append(the_region) 107 | 108 | gt_mask = render( 109 | corners=gt_data["corners"], 110 | edges=gt_data["edges"], 111 | render_pad=0, 112 | edge_linewidth=1, 113 | )[0] 114 | gt_mask = 1 - gt_mask 115 | gt_mask = gt_mask.astype(np.uint8) 116 | labels, region_mask = cv2.connectedComponents(gt_mask, connectivity=4) 117 | 118 | # cv2.imwrite('mask-gt.png', region_mask.astype(np.uint8) * 20) 119 | 120 | background_label = region_mask[0, 0] 121 | all_gt_masks = [] 122 | for region_i in range(1, labels): 123 | if region_i == background_label: 124 | continue 125 | the_region = region_mask == region_i 126 | if the_region.sum() < 20: 127 | continue 128 | all_gt_masks.append(the_region) 129 | 130 | per_sample_region_tp = 0.0 131 | per_sample_region_fp = 0.0 132 | per_sample_region_length = len(all_gt_masks) 133 | found = [False] * len(all_gt_masks) 134 | for i, r_det in enumerate(all_conv_masks): 135 | # gt closest gt 136 | near_gt = [0, 0, None] 137 | for k, r_gt in enumerate(all_gt_masks): 138 | iou = np.logical_and(r_gt, r_det).sum() / float( 139 | np.logical_or(r_gt, r_det).sum() 140 | ) 141 | if iou > near_gt[1]: 142 | near_gt = [k, iou, r_gt] 143 | if near_gt[1] >= iou_thresh and not found[near_gt[0]]: 144 | per_sample_region_tp += 1.0 145 | found[near_gt[0]] = True 146 | else: 147 | per_sample_region_fp += 1.0 148 | 149 | # per_region_score = { 150 | # 'recall': per_sample_region_tp / len(all_gt_masks), 151 | # 'precision': per_sample_region_tp / (per_sample_region_tp + per_sample_region_fp + 1e-8) 152 | # } 153 | per_region_score = {"recall": 1.0, "precision": 1.0} 154 | 155 | return { 156 | "corner_tp": per_sample_corner_tp, 157 | "corner_fp": per_sample_corner_fp, 158 | "corner_length": per_sample_corner_length, 159 | "edge_tp": per_sample_edge_tp, 160 | "edge_fp": per_sample_edge_fp, 161 | "edge_length": per_sample_edge_length, 162 | "region_tp": per_sample_region_tp, 163 | "region_fp": per_sample_region_fp, 164 
| "region_length": per_sample_region_length, 165 | "corner": per_corner_score, 166 | "edge": per_edge_score, 167 | "region": per_region_score, 168 | # for visualizing edges 169 | "matched_pred_edge_ids": list(match_pred_ids), 170 | "false_pred_edge_ids": false_edge_ids, 171 | } 172 | 173 | def calc_corner(self, gt_data, conv_data, thresh=8.0, iou_thresh=0.7): 174 | ### compute corners precision/recall 175 | gts = gt_data["corners"] 176 | dets = conv_data["corners"] 177 | 178 | per_sample_corner_tp = 0.0 179 | per_sample_corner_fp = 0.0 180 | per_sample_corner_length = gts.shape[0] 181 | found = [False] * gts.shape[0] 182 | c_det_annot = {} 183 | 184 | # for each corner detection 185 | for i, det in enumerate(dets): 186 | # get closest gt 187 | near_gt = [0, 999999.0, (0.0, 0.0)] 188 | for k, gt in enumerate(gts): 189 | dist = np.linalg.norm(gt - det) 190 | if dist < near_gt[1]: 191 | near_gt = [k, dist, gt] 192 | if near_gt[1] <= thresh and not found[near_gt[0]]: 193 | per_sample_corner_tp += 1.0 194 | found[near_gt[0]] = True 195 | c_det_annot[i] = near_gt[0] 196 | else: 197 | per_sample_corner_fp += 1.0 198 | 199 | per_corner_score = { 200 | "recall": per_sample_corner_tp / gts.shape[0], 201 | "precision": per_sample_corner_tp 202 | / (per_sample_corner_tp + per_sample_corner_fp + 1e-8), 203 | } 204 | 205 | return { 206 | "corner_tp": per_sample_corner_tp, 207 | "corner_fp": per_sample_corner_fp, 208 | "corner_length": per_sample_corner_length, 209 | "corner": per_corner_score, 210 | } 211 | 212 | 213 | def compute_metrics(gt_data, pred_data, thresh=8.0): 214 | metric = Metric() 215 | score = metric.calc(gt_data, pred_data, thresh=thresh) 216 | return score 217 | 218 | 219 | def get_recall_and_precision(tp, fp, length): 220 | recall = tp / (length + 1e-8) 221 | precision = tp / (tp + fp + 1e-8) 222 | return recall, precision 223 | 224 | 225 | if __name__ == "__main__": 226 | base_path = "./" 227 | gt_datapath = "../data/cities_dataset/annot" 228 | metric = Metric() 229 | corner_tp = 0.0 230 | corner_fp = 0.0 231 | corner_length = 0.0 232 | edge_tp = 0.0 233 | edge_fp = 0.0 234 | edge_length = 0.0 235 | region_tp = 0.0 236 | region_fp = 0.0 237 | region_length = 0.0 238 | for file_name in os.listdir(base_path): 239 | if len(file_name) < 10: 240 | continue 241 | f = open(os.path.join(base_path, file_name), "rb") 242 | gt_data = np.load( 243 | os.path.join(gt_datapath, file_name + ".npy"), allow_pickle=True 244 | ).tolist() 245 | candidate = pickle.load(f) 246 | conv_corners = candidate.graph.getCornersArray() 247 | conv_edges = candidate.graph.getEdgesArray() 248 | conv_data = {"corners": conv_corners, "edges": conv_edges} 249 | score = metric.calc(gt_data, conv_data) 250 | corner_tp += score["corner_tp"] 251 | corner_fp += score["corner_fp"] 252 | corner_length += score["corner_length"] 253 | edge_tp += score["edge_tp"] 254 | edge_fp += score["edge_fp"] 255 | edge_length += score["edge_length"] 256 | region_tp += score["region_tp"] 257 | region_fp += score["region_fp"] 258 | region_length += score["region_length"] 259 | 260 | f = open(os.path.join(base_path, "score.txt"), "w") 261 | # corner 262 | recall, precision = get_recall_and_precision(corner_tp, corner_fp, corner_length) 263 | f_score = 2.0 * precision * recall / (recall + precision + 1e-8) 264 | print( 265 | "corners - precision: %.3f recall: %.3f f_score: %.3f" 266 | % (precision, recall, f_score) 267 | ) 268 | f.write( 269 | "corners - precision: %.3f recall: %.3f f_score: %.3f\n" 270 | % (precision, recall, f_score) 271 | ) 272 
| 273 | # edge 274 | recall, precision = get_recall_and_precision(edge_tp, edge_fp, edge_length) 275 | f_score = 2.0 * precision * recall / (recall + precision + 1e-8) 276 | print( 277 | "edges - precision: %.3f recall: %.3f f_score: %.3f" 278 | % (precision, recall, f_score) 279 | ) 280 | f.write( 281 | "edges - precision: %.3f recall: %.3f f_score: %.3f\n" 282 | % (precision, recall, f_score) 283 | ) 284 | 285 | # region 286 | recall, precision = get_recall_and_precision(region_tp, region_fp, region_length) 287 | f_score = 2.0 * precision * recall / (recall + precision + 1e-8) 288 | print( 289 | "regions - precision: %.3f recall: %.3f f_score: %.3f" 290 | % (precision, recall, f_score) 291 | ) 292 | f.write( 293 | "regions - precision: %.3f recall: %.3f f_score: %.3f\n" 294 | % (precision, recall, f_score) 295 | ) 296 | 297 | f.close() 298 | -------------------------------------------------------------------------------- /code/learn/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/code/learn/models/__init__.py -------------------------------------------------------------------------------- /code/learn/models/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | 5 | def add_path(path): 6 | if path not in sys.path: sys.path.insert(0, path) 7 | 8 | 9 | this_dir = osp.dirname(__file__) 10 | 11 | project_path = osp.abspath(osp.join(this_dir, '..')) 12 | add_path(project_path) -------------------------------------------------------------------------------- /code/learn/models/corner_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | from models.deformable_transformer_original import ( 7 | DeformableTransformerEncoderLayer, 8 | DeformableTransformerEncoder, 9 | DeformableTransformerDecoder, 10 | DeformableTransformerDecoderLayer, 11 | DeformableAttnDecoderLayer, 12 | ) 13 | from models.ops.modules import MSDeformAttn 14 | from models.mlp import MLP 15 | from models.unet import convrelu 16 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ 17 | from einops import rearrange, repeat 18 | from einops.layers.torch import Rearrange 19 | from utils.misc import NestedTensor 20 | import scipy.ndimage.filters as filters 21 | 22 | 23 | class CornerEnum(nn.Module): 24 | def __init__( 25 | self, 26 | input_dim, 27 | hidden_dim, 28 | num_feature_levels, 29 | backbone_strides, 30 | backbone_num_channels, 31 | ): 32 | super(CornerEnum, self).__init__() 33 | self.input_dim = input_dim 34 | self.hidden_dim = hidden_dim 35 | self.num_feature_levels = num_feature_levels 36 | 37 | if num_feature_levels > 1: 38 | num_backbone_outs = len(backbone_strides) 39 | input_proj_list = [] 40 | for _ in range(num_backbone_outs): 41 | in_channels = backbone_num_channels[_] 42 | input_proj_list.append( 43 | nn.Sequential( 44 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1), 45 | nn.GroupNorm(32, hidden_dim), 46 | ) 47 | ) 48 | for _ in range(num_feature_levels - num_backbone_outs): 49 | input_proj_list.append( 50 | nn.Sequential( 51 | nn.Conv2d( 52 | in_channels, hidden_dim, kernel_size=3, stride=2, padding=1 53 | ), 54 | nn.GroupNorm(32, hidden_dim), 55 | ) 56 | ) 57 | in_channels = hidden_dim 58 | 
self.input_proj = nn.ModuleList(input_proj_list) 59 | else: 60 | self.input_proj = nn.ModuleList( 61 | [ 62 | nn.Sequential( 63 | nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1), 64 | nn.GroupNorm(32, hidden_dim), 65 | ) 66 | ] 67 | ) 68 | 69 | self.patch_size = 4 70 | patch_dim = (self.patch_size**2) * input_dim 71 | self.to_patch_embedding = nn.Sequential( 72 | Rearrange( 73 | "b (h p1) (w p2) c -> b (h w) (p1 p2 c)", 74 | p1=self.patch_size, 75 | p2=self.patch_size, 76 | ), 77 | nn.Linear(patch_dim, input_dim), 78 | nn.Linear(input_dim, hidden_dim), 79 | ) 80 | 81 | self.pixel_pe_fc = nn.Linear(input_dim, hidden_dim) 82 | self.transformer = CornerTransformer( 83 | d_model=hidden_dim, 84 | nhead=8, 85 | num_encoder_layers=1, 86 | dim_feedforward=1024, 87 | dropout=0.1, 88 | ) 89 | 90 | self.img_pos = PositionEmbeddingSine(hidden_dim // 2) 91 | 92 | @staticmethod 93 | def get_ms_feat(xs, img_mask): 94 | out: Dict[str, NestedTensor] = {} 95 | # out = list() 96 | for name, x in sorted(xs.items()): 97 | m = img_mask 98 | assert m is not None 99 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 100 | out[name] = NestedTensor(x, mask) 101 | # out.append(NestedTensor(x, mask)) 102 | return out 103 | 104 | @staticmethod 105 | def get_decoder_reference_points(height, width, device): 106 | ref_y, ref_x = torch.meshgrid( 107 | torch.linspace( 108 | 0.5, height - 0.5, height, dtype=torch.float32, device=device 109 | ), 110 | torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device), 111 | ) 112 | ref_y = ref_y.reshape(-1)[None] / height 113 | ref_x = ref_x.reshape(-1)[None] / width 114 | ref = torch.stack((ref_x, ref_y), -1) 115 | return ref 116 | 117 | def forward(self, image_feats, feat_mask, pixels_feat, pixels, all_image_feats): 118 | # process image features 119 | features = self.get_ms_feat(image_feats, feat_mask) 120 | 121 | srcs = [] 122 | masks = [] 123 | all_pos = [] 124 | 125 | new_features = list() 126 | for name, x in sorted(features.items()): 127 | new_features.append(x) 128 | features = new_features 129 | 130 | for l, feat in enumerate(features): 131 | src, mask = feat.decompose() 132 | mask = mask.to(src.device) 133 | srcs.append(self.input_proj[l](src)) 134 | pos = self.img_pos(src).to(src.dtype) 135 | all_pos.append(pos) 136 | masks.append(mask) 137 | assert mask is not None 138 | 139 | if self.num_feature_levels > len(srcs): 140 | _len_srcs = len(srcs) 141 | for l in range(_len_srcs, self.num_feature_levels): 142 | if l == _len_srcs: 143 | src = self.input_proj[l](features[-1].tensors) 144 | else: 145 | src = self.input_proj[l](srcs[-1]) 146 | m = feat_mask 147 | mask = ( 148 | F.interpolate(m[None].float(), size=src.shape[-2:]) 149 | .to(torch.bool)[0] 150 | .to(src.device) 151 | ) 152 | pos_l = self.img_pos(src).to(src.dtype) 153 | srcs.append(src) 154 | masks.append(mask) 155 | all_pos.append(pos_l) 156 | 157 | sp_inputs = self.to_patch_embedding(pixels_feat) 158 | 159 | # compute the reference points 160 | H_tgt = W_tgt = int(np.sqrt(sp_inputs.shape[1])) 161 | reference_points_s1 = self.get_decoder_reference_points( 162 | H_tgt, W_tgt, sp_inputs.device 163 | ) 164 | rp_all_pixels = self.get_decoder_reference_points(256, 256, sp_inputs.device) 165 | 166 | corner_logits = self.transformer( 167 | srcs, masks, all_pos, sp_inputs, reference_points_s1, all_image_feats 168 | ) 169 | return corner_logits 170 | 171 | 172 | class PositionEmbeddingSine(nn.Module): 173 | """ 174 | This is a more standard version of the position 
embedding, very similar to the one 175 | used by the Attention is all you need paper, generalized to work on images. 176 | """ 177 | 178 | def __init__( 179 | self, num_pos_feats=64, temperature=10000, normalize=False, scale=None 180 | ): 181 | super().__init__() 182 | self.num_pos_feats = num_pos_feats 183 | self.temperature = temperature 184 | self.normalize = normalize 185 | if scale is not None and normalize is False: 186 | raise ValueError("normalize should be True if scale is passed") 187 | if scale is None: 188 | scale = 2 * math.pi 189 | self.scale = scale 190 | 191 | def forward(self, x): 192 | mask = torch.zeros([x.shape[0], x.shape[2], x.shape[3]]).bool().to(x.device) 193 | not_mask = ~mask 194 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 195 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 196 | if self.normalize: 197 | eps = 1e-6 198 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 199 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 200 | 201 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 202 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 203 | 204 | pos_x = x_embed[:, :, :, None] / dim_t 205 | pos_y = y_embed[:, :, :, None] / dim_t 206 | pos_x = torch.stack( 207 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 208 | ).flatten(3) 209 | pos_y = torch.stack( 210 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 211 | ).flatten(3) 212 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 213 | return pos 214 | 215 | 216 | class CornerTransformer(nn.Module): 217 | def __init__( 218 | self, 219 | d_model=512, 220 | nhead=8, 221 | num_encoder_layers=6, 222 | dim_feedforward=1024, 223 | dropout=0.1, 224 | activation="relu", 225 | return_intermediate_dec=False, 226 | num_feature_levels=4, 227 | dec_n_points=4, 228 | enc_n_points=4, 229 | ): 230 | super(CornerTransformer, self).__init__() 231 | 232 | encoder_layer = DeformableTransformerEncoderLayer( 233 | d_model, 234 | dim_feedforward, 235 | dropout, 236 | activation, 237 | num_feature_levels, 238 | nhead, 239 | enc_n_points, 240 | ) 241 | self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) 242 | 243 | decoder_attn_layer = DeformableAttnDecoderLayer( 244 | d_model, 245 | dim_feedforward, 246 | dropout, 247 | activation, 248 | num_feature_levels, 249 | nhead, 250 | dec_n_points, 251 | ) 252 | self.per_edge_decoder = DeformableTransformerDecoder( 253 | decoder_attn_layer, 1, False, with_sa=False 254 | ) 255 | 256 | self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) 257 | 258 | # upconv layers 259 | self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 260 | self.conv_up1 = convrelu(256 + 256, 256, 3, 1) 261 | self.conv_up0 = convrelu(64 + 256, 128, 3, 1) 262 | self.conv_original_size2 = convrelu(64 + 128, d_model, 3, 1) 263 | self.output_fc_1 = nn.Linear(d_model, 1) 264 | self.output_fc_2 = nn.Linear(d_model, 1) 265 | 266 | self._reset_parameters() 267 | 268 | def _reset_parameters(self): 269 | for p in self.parameters(): 270 | if p.dim() > 1: 271 | nn.init.xavier_uniform_(p) 272 | for m in self.modules(): 273 | if isinstance(m, MSDeformAttn): 274 | m._reset_parameters() 275 | # xavier_uniform_(self.reference_points.weight.data, gain=1.0) 276 | # constant_(self.reference_points.bias.data, 0.) 
277 | normal_(self.level_embed) 278 | 279 | def get_valid_ratio(self, mask): 280 | _, H, W = mask.shape 281 | valid_H = torch.sum(~mask[:, :, 0], 1) 282 | valid_W = torch.sum(~mask[:, 0, :], 1) 283 | valid_ratio_h = valid_H.float() / H 284 | valid_ratio_w = valid_W.float() / W 285 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 286 | return valid_ratio 287 | 288 | def forward( 289 | self, srcs, masks, pos_embeds, query_embed, reference_points, all_image_feats 290 | ): 291 | # prepare input for encoder 292 | src_flatten = [] 293 | mask_flatten = [] 294 | lvl_pos_embed_flatten = [] 295 | spatial_shapes = [] 296 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): 297 | bs, c, h, w = src.shape 298 | spatial_shape = (h, w) 299 | spatial_shapes.append(spatial_shape) 300 | src = src.flatten(2).transpose(1, 2) 301 | mask = mask.flatten(1) 302 | pos_embed = pos_embed.flatten(2).transpose(1, 2) 303 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 304 | lvl_pos_embed_flatten.append(lvl_pos_embed) 305 | src_flatten.append(src) 306 | mask_flatten.append(mask) 307 | src_flatten = torch.cat(src_flatten, 1) 308 | mask_flatten = torch.cat(mask_flatten, 1) 309 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 310 | spatial_shapes = torch.as_tensor( 311 | spatial_shapes, dtype=torch.long, device=src_flatten.device 312 | ) 313 | level_start_index = torch.cat( 314 | (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) 315 | ) 316 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 317 | 318 | # encoder 319 | memory = self.encoder( 320 | src_flatten, 321 | spatial_shapes, 322 | level_start_index, 323 | valid_ratios, 324 | lvl_pos_embed_flatten, 325 | mask_flatten, 326 | ) 327 | 328 | # prepare input for decoder 329 | bs, _, c = memory.shape 330 | 331 | tgt = query_embed 332 | 333 | # relational decoder 334 | hs_pixels_s1, _ = self.per_edge_decoder( 335 | tgt, 336 | reference_points, 337 | memory, 338 | spatial_shapes, 339 | level_start_index, 340 | valid_ratios, 341 | query_embed, 342 | mask_flatten, 343 | ) 344 | 345 | feats_s1, preds_s1 = self.generate_corner_preds(hs_pixels_s1, all_image_feats) 346 | 347 | return preds_s1 348 | 349 | def generate_corner_preds(self, outputs, conv_outputs): 350 | B, L, C = outputs.shape 351 | side = int(np.sqrt(L)) 352 | outputs = outputs.view(B, side, side, C) 353 | outputs = outputs.permute(0, 3, 1, 2) 354 | outputs = torch.cat([outputs, conv_outputs["layer1"]], dim=1) 355 | x = self.conv_up1(outputs) 356 | 357 | x = self.upsample(x) 358 | x = torch.cat([x, conv_outputs["layer0"]], dim=1) 359 | x = self.conv_up0(x) 360 | 361 | x = self.upsample(x) 362 | x = torch.cat([x, conv_outputs["x_original"]], dim=1) 363 | x = self.conv_original_size2(x) 364 | 365 | logits = x.permute(0, 2, 3, 1) 366 | preds = self.output_fc_1(logits) 367 | preds = preds.squeeze(-1).sigmoid() 368 | return logits, preds 369 | -------------------------------------------------------------------------------- /code/learn/models/corner_to_edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import scipy.ndimage.filters as filters 5 | import cv2 6 | import itertools 7 | 8 | NEIGHBOUR_SIZE = 5 9 | MATCH_THRESH = 5 10 | LOCAL_MAX_THRESH = 0.01 11 | viz_count = 0 12 | 13 | all_combibations = dict() 14 | for length in range(2, 351): 15 | ids = np.arange(length) 16 | combs = 
np.array(list(itertools.combinations(ids, 2))) 17 | all_combibations[length] = combs 18 | 19 | def prepare_edge_data(c_outputs, annots, images): 20 | bs = c_outputs.shape[0] 21 | # prepares parameters for each sample of the batch 22 | all_results = list() 23 | 24 | for b_i in range(bs): 25 | annot = annots[b_i] 26 | output = c_outputs[b_i] 27 | results = process_each_sample({'annot': annot, 'output': output, 'viz_img': images[b_i]}) 28 | all_results.append(results) 29 | 30 | processed_corners = [item['corners'] for item in all_results] 31 | edge_coords = [item['edges'] for item in all_results] 32 | edge_labels = [item['labels'] for item in all_results] 33 | 34 | edge_info = { 35 | 'edge_coords': edge_coords, 36 | 'edge_labels': edge_labels, 37 | 'processed_corners': processed_corners 38 | } 39 | 40 | edge_data = collate_edge_info(edge_info) 41 | return edge_data 42 | 43 | 44 | def process_annot(annot, do_round=True): 45 | corners = np.array(list(annot.keys())) 46 | ind = np.lexsort(corners.T) # sort the g.t. corners to fix the order for the matching later 47 | corners = corners[ind] # sorted by y, then x 48 | corner_mapping = {tuple(k): v for v, k in enumerate(corners)} 49 | 50 | edges = list() 51 | for c, connections in annot.items(): 52 | for other_c in connections: 53 | edge_pair = (corner_mapping[c], corner_mapping[tuple(other_c)]) 54 | edges.append(edge_pair) 55 | corner_degrees = [len(annot[tuple(c)]) for c in corners] 56 | if do_round: 57 | corners = corners.round() 58 | return corners, edges, corner_degrees 59 | 60 | 61 | def process_each_sample(data): 62 | annot = data['annot'] 63 | output = data['output'] 64 | viz_img = data['viz_img'] 65 | 66 | preds = output.detach().cpu().numpy() 67 | 68 | data_max = filters.maximum_filter(preds, NEIGHBOUR_SIZE) 69 | maxima = (preds == data_max) 70 | data_min = filters.minimum_filter(preds, NEIGHBOUR_SIZE) 71 | diff = ((data_max - data_min) > 0) 72 | maxima[diff == 0] = 0 73 | local_maximas = np.where((maxima > 0) & (preds > LOCAL_MAX_THRESH)) 74 | pred_corners = np.stack(local_maximas, axis=-1)[:, [1, 0]] # to (x, y format) 75 | 76 | # produce edge labels labels from pred corners here 77 | #global viz_count 78 | #rand_num = np.random.rand() 79 | #if rand_num > 0.7 and len(pred_corners) > 0: 80 | #processed_corners, edges, labels = get_edge_label_det_only(pred_corners, annot) # only using det corners 81 | #output_path = './viz_edge_data/{}_example_det.png'.format(viz_count) 82 | #_visualize_edge_training_data(processed_corners, edges, labels, viz_img, output_path) 83 | #else: 84 | # # use g.t. 
corners, but mix with neg pred corners 85 | processed_corners, edges, labels = get_edge_label_mix_gt(pred_corners, annot) 86 | #output_path = './viz_training/{}_example_gt.png'.format(viz_count) 87 | #_visualize_edge_training_data(processed_corners, edges, labels, viz_img, output_path) 88 | #viz_count += 1 89 | 90 | results = { 91 | 'corners': processed_corners, 92 | 'edges': edges, 93 | 'labels': labels, 94 | } 95 | return results 96 | 97 | 98 | def get_edge_label_mix_gt(pred_corners, annot): 99 | ind = np.lexsort(pred_corners.T) # sort the pred corners to fix the order for matching 100 | pred_corners = pred_corners[ind] # sorted by y, then x 101 | gt_corners, edge_pairs, corner_degrees = process_annot(annot) 102 | 103 | output_to_gt = dict() 104 | gt_to_output = dict() 105 | diff = np.sqrt(((pred_corners[:, None] - gt_corners) ** 2).sum(-1)) 106 | diff = diff.T 107 | 108 | if len(pred_corners) > 0: 109 | for target_i, target in enumerate(gt_corners): 110 | dist = diff[target_i] 111 | if len(output_to_gt) > 0: 112 | dist[list(output_to_gt.keys())] = 1000 # ignore already matched pred corners 113 | min_dist = dist.min() 114 | min_idx = dist.argmin() 115 | if min_dist < MATCH_THRESH and min_idx not in output_to_gt: # a positive match 116 | output_to_gt[min_idx] = (target_i, min_dist) 117 | gt_to_output[target_i] = min_idx 118 | 119 | all_corners = gt_corners.copy() 120 | 121 | # replace matched g.t. corners with pred corners 122 | for gt_i in range(len(gt_corners)): 123 | if gt_i in gt_to_output: 124 | all_corners[gt_i] = pred_corners[gt_to_output[gt_i]] 125 | 126 | nm_pred_ids = [i for i in range(len(pred_corners)) if i not in output_to_gt] 127 | nm_pred_ids = np.random.permutation(nm_pred_ids) 128 | if len(nm_pred_ids) > 0: 129 | nm_pred_corners = pred_corners[nm_pred_ids] 130 | if len(nm_pred_ids) + len(all_corners) <= 150: 131 | all_corners = np.concatenate([all_corners, nm_pred_corners], axis=0) 132 | else: 133 | all_corners = np.concatenate([all_corners, nm_pred_corners[:(150 - len(gt_corners)), :]], axis=0) 134 | 135 | processed_corners, edges, edge_ids, labels = _get_edges(all_corners, edge_pairs) 136 | 137 | return processed_corners, edges, labels 138 | 139 | 140 | def _get_edges(corners, edge_pairs): 141 | ind = np.lexsort(corners.T) 142 | corners = corners[ind] # sorted by y, then x 143 | corners = corners.round() 144 | id_mapping = {old: new for new, old in enumerate(ind)} 145 | 146 | all_ids = all_combibations[len(corners)] 147 | edges = corners[all_ids] 148 | labels = np.zeros(edges.shape[0]) 149 | 150 | N = len(corners) 151 | edge_pairs = [(id_mapping[p[0]], id_mapping[p[1]]) for p in edge_pairs] 152 | edge_pairs = [p for p in edge_pairs if p[0] < p[1]] 153 | pos_ids = [int((2 * N - 1 - p[0]) * p[0] / 2 + p[1] - p[0] - 1) for p in edge_pairs] 154 | labels[pos_ids] = 1 155 | 156 | edge_ids = np.array(all_ids) 157 | return corners, edges, edge_ids, labels 158 | 159 | 160 | def collate_edge_info(data): 161 | batched_data = {} 162 | lengths_info = {} 163 | for field in data.keys(): 164 | batch_values = data[field] 165 | all_lens = [len(value) for value in batch_values] 166 | max_len = max(all_lens) 167 | pad_value = 0 168 | batch_values = [pad_sequence(value, max_len, pad_value) for value in batch_values] 169 | batch_values = np.stack(batch_values, axis=0) 170 | 171 | if field in ['edge_coords', 'edge_labels', 'gt_values']: 172 | batch_values = torch.Tensor(batch_values).long() 173 | if field in ['processed_corners', 'edge_coords']: 174 | lengths_info[field] = all_lens 175 | 
batched_data[field] = batch_values 176 | 177 | # Add length and mask into the data, the mask if for Transformers' input format, True means padding 178 | for field, lengths in lengths_info.items(): 179 | lengths_str = field + '_lengths' 180 | batched_data[lengths_str] = torch.Tensor(lengths).long() 181 | mask = torch.arange(max(lengths)) 182 | mask = mask.unsqueeze(0).repeat(batched_data[field].shape[0], 1) 183 | mask = mask >= batched_data[lengths_str].unsqueeze(-1) 184 | mask_str = field + '_mask' 185 | batched_data[mask_str] = mask 186 | 187 | return batched_data 188 | 189 | 190 | def pad_sequence(seq, length, pad_value=0): 191 | if len(seq) == length: 192 | return seq 193 | else: 194 | pad_len = length - len(seq) 195 | if len(seq.shape) == 1: 196 | if pad_value == 0: 197 | paddings = np.zeros([pad_len, ]) 198 | else: 199 | paddings = np.ones([pad_len, ]) * pad_value 200 | else: 201 | if pad_value == 0: 202 | paddings = np.zeros([pad_len, ] + list(seq.shape[1:])) 203 | else: 204 | paddings = np.ones([pad_len, ] + list(seq.shape[1:])) * pad_value 205 | padded_seq = np.concatenate([seq, paddings], axis=0) 206 | return padded_seq 207 | 208 | 209 | def get_infer_edge_pairs(corners, confs): 210 | ind = np.lexsort(corners.T) 211 | corners = corners[ind] # sorted by y, then x 212 | confs = confs[ind] 213 | 214 | edge_ids = all_combibations[len(corners)] 215 | edge_coords = corners[edge_ids] 216 | 217 | edge_coords = torch.tensor(np.array(edge_coords)).unsqueeze(0).long() 218 | mask = torch.zeros([edge_coords.shape[0], edge_coords.shape[1]]).bool() 219 | edge_ids = torch.tensor(np.array(edge_ids)) 220 | return corners, confs, edge_coords, mask, edge_ids 221 | 222 | 223 | 224 | def get_mlm_info(labels, mlm=True): 225 | """ 226 | :param labels: original edge labels 227 | :param mlm: whether enable mlm 228 | :return: g.t. 
values: 0-known false, 1-known true, 2-unknown 229 | """ 230 | if mlm: # For training / evaluting with MLM 231 | rand_ratio = np.random.rand() * 0.5 + 0.5 232 | labels = torch.Tensor(labels) 233 | gt_rand = torch.rand(labels.size()) 234 | gt_flag = torch.zeros_like(labels) 235 | gt_value = torch.zeros_like(labels) 236 | gt_flag[torch.where(gt_rand >= rand_ratio)] = 1 237 | gt_idx = torch.where(gt_flag == 1) 238 | pred_idx = torch.where(gt_flag == 0) 239 | gt_value[gt_idx] = labels[gt_idx] 240 | gt_value[pred_idx] = 2 # use 2 to represent unknown value, need to predict 241 | else: 242 | labels = torch.Tensor(labels) 243 | gt_flag = torch.zeros_like(labels) 244 | gt_value = torch.zeros_like(labels) 245 | gt_flag[:] = 0 246 | gt_value[:] = 2 247 | return gt_value 248 | 249 | 250 | def _get_rand_midpoint(edge): 251 | c1, c2 = edge 252 | _rand = np.random.rand() * 0.4 + 0.3 253 | mid_x = int(np.round(c1[0] + (c2[0] - c1[0]) * _rand)) 254 | mid_y = int(np.round(c1[1] + (c2[1] - c1[1]) * _rand)) 255 | mid_point = (mid_x, mid_y) 256 | return mid_point 257 | 258 | 259 | def _visualize_edge_training_data(corners, edges, edge_labels, image, save_path): 260 | image = image.transpose([1, 2, 0]) 261 | image = (image * 255).astype(np.uint8) 262 | image = np.ascontiguousarray(image) 263 | 264 | for edge, label in zip(edges, edge_labels): 265 | if label == 1: 266 | cv2.line(image, tuple(edge[0].astype(np.int)), tuple(edge[1].astype(np.int)), (255, 255, 0), 2) 267 | 268 | for c in corners: 269 | cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1) 270 | 271 | cv2.imwrite(save_path, image) 272 | -------------------------------------------------------------------------------- /code/learn/models/deformable_transformer_original.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch import nn, Tensor 4 | from models.ops.modules import MSDeformAttn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DeformableTransformerEncoderLayer(nn.Module): 9 | def __init__( 10 | self, 11 | d_model=256, 12 | d_ffn=1024, 13 | dropout=0.1, 14 | activation="relu", 15 | n_levels=4, 16 | n_heads=8, 17 | n_points=4, 18 | ): 19 | super().__init__() 20 | 21 | # self attention 22 | self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) 23 | self.dropout1 = nn.Dropout(dropout) 24 | self.norm1 = nn.LayerNorm(d_model) 25 | 26 | # ffn 27 | self.linear1 = nn.Linear(d_model, d_ffn) 28 | self.activation = _get_activation_fn(activation) 29 | self.dropout2 = nn.Dropout(dropout) 30 | self.linear2 = nn.Linear(d_ffn, d_model) 31 | self.dropout3 = nn.Dropout(dropout) 32 | self.norm2 = nn.LayerNorm(d_model) 33 | 34 | @staticmethod 35 | def with_pos_embed(tensor, pos): 36 | return tensor if pos is None else tensor + pos 37 | 38 | def forward_ffn(self, src): 39 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 40 | src = src + self.dropout3(src2) 41 | src = self.norm2(src) 42 | return src 43 | 44 | def forward( 45 | self, 46 | src, 47 | pos, 48 | reference_points, 49 | spatial_shapes, 50 | level_start_index, 51 | padding_mask=None, 52 | ): 53 | # self attention 54 | src2 = self.self_attn( 55 | self.with_pos_embed(src, pos), 56 | reference_points, 57 | src, 58 | spatial_shapes, 59 | level_start_index, 60 | padding_mask, 61 | ) 62 | src = src + self.dropout1(src2) 63 | src = self.norm1(src) 64 | 65 | # ffn 66 | src = self.forward_ffn(src) 67 | 68 | return src 69 | 70 | 71 | class DeformableTransformerEncoder(nn.Module): 72 | def 
__init__(self, encoder_layer, num_layers): 73 | super().__init__() 74 | self.layers = _get_clones(encoder_layer, num_layers) 75 | self.num_layers = num_layers 76 | 77 | @staticmethod 78 | def get_reference_points(spatial_shapes, valid_ratios, device): 79 | reference_points_list = [] 80 | for lvl, (H_, W_) in enumerate(spatial_shapes): 81 | ref_y, ref_x = torch.meshgrid( 82 | torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 83 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), 84 | ) 85 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) 86 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) 87 | ref = torch.stack((ref_x, ref_y), -1) 88 | reference_points_list.append(ref) 89 | reference_points = torch.cat(reference_points_list, 1) 90 | reference_points = reference_points[:, :, None] * valid_ratios[:, None] 91 | return reference_points 92 | 93 | def forward( 94 | self, 95 | src, 96 | spatial_shapes, 97 | level_start_index, 98 | valid_ratios, 99 | pos=None, 100 | padding_mask=None, 101 | ): 102 | output = src 103 | reference_points = self.get_reference_points( 104 | spatial_shapes, valid_ratios, device=src.device 105 | ) 106 | for _, layer in enumerate(self.layers): 107 | output = layer( 108 | output, 109 | pos, 110 | reference_points, 111 | spatial_shapes, 112 | level_start_index, 113 | padding_mask, 114 | ) 115 | 116 | return output 117 | 118 | 119 | class DeformableAttnDecoderLayer(nn.Module): 120 | def __init__( 121 | self, 122 | d_model=256, 123 | d_ffn=1024, 124 | dropout=0.1, 125 | activation="relu", 126 | n_levels=4, 127 | n_heads=8, 128 | n_points=4, 129 | ): 130 | super().__init__() 131 | # cross attention 132 | self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) 133 | self.dropout1 = nn.Dropout(dropout) 134 | self.norm1 = nn.LayerNorm(d_model) 135 | 136 | # ffn 137 | self.linear1 = nn.Linear(d_model, d_ffn) 138 | self.activation = _get_activation_fn(activation) 139 | self.dropout3 = nn.Dropout(dropout) 140 | self.linear2 = nn.Linear(d_ffn, d_model) 141 | self.dropout4 = nn.Dropout(dropout) 142 | self.norm3 = nn.LayerNorm(d_model) 143 | 144 | @staticmethod 145 | def with_pos_embed(tensor, pos): 146 | return tensor if pos is None else tensor + pos 147 | 148 | def forward_ffn(self, tgt): 149 | tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) 150 | tgt = tgt + self.dropout4(tgt2) 151 | tgt = self.norm3(tgt) 152 | return tgt 153 | 154 | def forward( 155 | self, 156 | tgt, 157 | query_pos, 158 | reference_points, 159 | src, 160 | src_spatial_shapes, 161 | level_start_index, 162 | src_padding_mask=None, 163 | key_padding_mask=None, 164 | ): 165 | # cross attention 166 | tgt2 = self.cross_attn( 167 | self.with_pos_embed(tgt, query_pos), 168 | reference_points, 169 | src, 170 | src_spatial_shapes, 171 | level_start_index, 172 | src_padding_mask, 173 | ) 174 | tgt = tgt + self.dropout1(tgt2) 175 | tgt = self.norm1(tgt) 176 | 177 | # ffn 178 | tgt = self.forward_ffn(tgt) 179 | 180 | return tgt 181 | 182 | 183 | class DeformableTransformerDecoderLayer(nn.Module): 184 | def __init__( 185 | self, 186 | d_model=256, 187 | d_ffn=1024, 188 | dropout=0.1, 189 | activation="relu", 190 | n_levels=4, 191 | n_heads=8, 192 | n_points=4, 193 | ): 194 | super().__init__() 195 | # cross attention 196 | self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) 197 | self.dropout1 = nn.Dropout(dropout) 198 | self.norm1 = nn.LayerNorm(d_model) 199 | 200 | # self attention 201 
| self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) 202 | self.dropout2 = nn.Dropout(dropout) 203 | self.norm2 = nn.LayerNorm(d_model) 204 | 205 | # ffn 206 | self.linear1 = nn.Linear(d_model, d_ffn) 207 | self.activation = _get_activation_fn(activation) 208 | self.dropout3 = nn.Dropout(dropout) 209 | self.linear2 = nn.Linear(d_ffn, d_model) 210 | self.dropout4 = nn.Dropout(dropout) 211 | self.norm3 = nn.LayerNorm(d_model) 212 | 213 | @staticmethod 214 | def with_pos_embed(tensor, pos): 215 | return tensor if pos is None else tensor + pos 216 | 217 | def forward_ffn(self, tgt): 218 | tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) 219 | tgt = tgt + self.dropout4(tgt2) 220 | tgt = self.norm3(tgt) 221 | return tgt 222 | 223 | def forward( 224 | self, 225 | tgt, 226 | query_pos, 227 | reference_points, 228 | src, 229 | src_spatial_shapes, 230 | level_start_index, 231 | src_padding_mask=None, 232 | key_padding_mask=None, 233 | get_image_feat=True, 234 | ): 235 | # self attention 236 | q = k = self.with_pos_embed(tgt, query_pos) 237 | tgt2 = self.self_attn( 238 | q.transpose(0, 1), 239 | k.transpose(0, 1), 240 | tgt.transpose(0, 1), 241 | key_padding_mask=key_padding_mask, 242 | )[0].transpose(0, 1) 243 | tgt = tgt + self.dropout2(tgt2) 244 | tgt = self.norm2(tgt) 245 | 246 | if get_image_feat: 247 | # cross attention 248 | tgt2 = self.cross_attn( 249 | self.with_pos_embed(tgt, query_pos), 250 | reference_points, 251 | src, 252 | src_spatial_shapes, 253 | level_start_index, 254 | src_padding_mask, 255 | ) 256 | tgt = tgt + self.dropout1(tgt2) 257 | tgt = self.norm1(tgt) 258 | 259 | # ffn 260 | tgt = self.forward_ffn(tgt) 261 | 262 | return tgt 263 | 264 | 265 | class DeformableTransformerDecoder(nn.Module): 266 | def __init__( 267 | self, decoder_layer, num_layers, return_intermediate=False, with_sa=True 268 | ): 269 | super().__init__() 270 | self.layers = _get_clones(decoder_layer, num_layers) 271 | self.num_layers = num_layers 272 | self.return_intermediate = return_intermediate 273 | # hack implementation for iterative bounding box refinement and two-stage Deformable DETR 274 | self.with_sa = with_sa 275 | 276 | def forward( 277 | self, 278 | tgt, 279 | reference_points, 280 | src, 281 | src_spatial_shapes, 282 | src_level_start_index, 283 | src_valid_ratios, 284 | query_pos=None, 285 | src_padding_mask=None, 286 | key_padding_mask=None, 287 | get_image_feat=True, 288 | ): 289 | output = tgt 290 | 291 | intermediate = [] 292 | intermediate_reference_points = [] 293 | for lid, layer in enumerate(self.layers): 294 | if reference_points.shape[-1] == 4: 295 | reference_points_input = ( 296 | reference_points[:, :, None] 297 | * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] 298 | ) 299 | else: 300 | assert reference_points.shape[-1] == 2 301 | reference_points_input = ( 302 | reference_points[:, :, None] * src_valid_ratios[:, None] 303 | ) 304 | if self.with_sa: 305 | output = layer( 306 | output, 307 | query_pos, 308 | reference_points_input, 309 | src, 310 | src_spatial_shapes, 311 | src_level_start_index, 312 | src_padding_mask, 313 | key_padding_mask, 314 | get_image_feat, 315 | ) 316 | else: 317 | output = layer( 318 | output, 319 | query_pos, 320 | reference_points_input, 321 | src, 322 | src_spatial_shapes, 323 | src_level_start_index, 324 | src_padding_mask, 325 | key_padding_mask, 326 | ) 327 | 328 | if self.return_intermediate: 329 | intermediate.append(output) 330 | intermediate_reference_points.append(reference_points) 331 
| 332 | if self.return_intermediate: 333 | return torch.stack(intermediate), torch.stack(intermediate_reference_points) 334 | 335 | return output, reference_points 336 | 337 | 338 | def _get_clones(module, N): 339 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 340 | 341 | 342 | def _get_activation_fn(activation): 343 | """Return an activation function given a string""" 344 | if activation == "relu": 345 | return F.relu 346 | if activation == "gelu": 347 | return F.gelu 348 | if activation == "glu": 349 | return F.glu 350 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 351 | -------------------------------------------------------------------------------- /code/learn/models/edge_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from models.mlp import MLP 6 | from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \ 7 | DeformableTransformerDecoder, DeformableTransformerDecoderLayer, DeformableAttnDecoderLayer 8 | from models.ops.modules import MSDeformAttn 9 | from models.corner_models import PositionEmbeddingSine 10 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ 11 | import torch.nn.functional as F 12 | from utils.misc import NestedTensor 13 | 14 | 15 | class EdgeEnum(nn.Module): 16 | def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels, ): 17 | super(EdgeEnum, self).__init__() 18 | self.input_dim = input_dim 19 | self.hidden_dim = hidden_dim 20 | self.num_feature_levels = num_feature_levels 21 | 22 | if num_feature_levels > 1: 23 | num_backbone_outs = len(backbone_strides) 24 | input_proj_list = [] 25 | for _ in range(num_backbone_outs): 26 | in_channels = backbone_num_channels[_] 27 | input_proj_list.append(nn.Sequential( 28 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1), 29 | nn.GroupNorm(32, hidden_dim), 30 | )) 31 | for _ in range(num_feature_levels - num_backbone_outs): 32 | input_proj_list.append(nn.Sequential( 33 | nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), 34 | nn.GroupNorm(32, hidden_dim), 35 | )) 36 | in_channels = hidden_dim 37 | self.input_proj = nn.ModuleList(input_proj_list) 38 | else: 39 | self.input_proj = nn.ModuleList([ 40 | nn.Sequential( 41 | nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1), 42 | nn.GroupNorm(32, hidden_dim), 43 | )]) 44 | 45 | self.img_pos = PositionEmbeddingSine(hidden_dim // 2) 46 | 47 | self.edge_input_fc = nn.Linear(input_dim * 2, hidden_dim) 48 | self.output_fc = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim // 2, output_dim=2, num_layers=2) 49 | 50 | # self.mlm_embedding = nn.Embedding(3, input_dim) 51 | self.transformer = EdgeTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1, 52 | num_decoder_layers=6, dim_feedforward=1024, dropout=0.1) 53 | 54 | @staticmethod 55 | def get_ms_feat(xs, img_mask): 56 | out: Dict[str, NestedTensor] = {} 57 | # out = list() 58 | for name, x in sorted(xs.items()): 59 | m = img_mask 60 | assert m is not None 61 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 62 | out[name] = NestedTensor(x, mask) 63 | # out.append(NestedTensor(x, mask)) 64 | return out 65 | 66 | def forward(self, image_feats, feat_mask, corner_outputs, edge_coords, edge_masks, gt_values, corner_nums, 67 | max_candidates, do_inference=False, get_hs=False): 68 | # 
Prepare ConvNet features 69 | features = self.get_ms_feat(image_feats, feat_mask) 70 | 71 | srcs = [] 72 | masks = [] 73 | all_pos = [] 74 | 75 | new_features = list() 76 | for name, x in sorted(features.items()): 77 | new_features.append(x) 78 | features = new_features 79 | 80 | for l, feat in enumerate(features): 81 | src, mask = feat.decompose() 82 | mask = mask.to(src.device) 83 | srcs.append(self.input_proj[l](src)) 84 | pos = self.img_pos(src).to(src.dtype) 85 | all_pos.append(pos) 86 | masks.append(mask) 87 | assert mask is not None 88 | 89 | if self.num_feature_levels > len(srcs): 90 | _len_srcs = len(srcs) 91 | for l in range(_len_srcs, self.num_feature_levels): 92 | if l == _len_srcs: 93 | src = self.input_proj[l](features[-1].tensors) 94 | else: 95 | src = self.input_proj[l](srcs[-1]) 96 | m = feat_mask 97 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device) 98 | pos_l = self.img_pos(src).to(src.dtype) 99 | srcs.append(src) 100 | masks.append(mask) 101 | all_pos.append(pos_l) 102 | 103 | bs = edge_masks.size(0) 104 | num_edges = edge_masks.size(1) 105 | 106 | corner_feats = corner_outputs 107 | edge_feats = list() 108 | for b_i in range(bs): 109 | feats = corner_feats[b_i, edge_coords[b_i, :, :, 1], edge_coords[b_i, :, :, 0], :] 110 | edge_feats.append(feats) 111 | edge_feats = torch.stack(edge_feats, dim=0) 112 | edge_feats = edge_feats.view(bs, num_edges, -1) 113 | 114 | edge_inputs = self.edge_input_fc(edge_feats.view(bs * num_edges, -1)) 115 | edge_inputs = edge_inputs.view(bs, num_edges, -1) 116 | 117 | edge_center = (edge_coords[:, :, 0, :].float() + edge_coords[:, :, 1, :].float()) / 2 118 | edge_center = edge_center / feat_mask.shape[1] 119 | 120 | return self.transformer(srcs, masks, all_pos, edge_inputs, edge_center, gt_values, 121 | edge_masks, corner_nums, max_candidates, do_inference, get_hs) 122 | 123 | 124 | class EdgeTransformer(nn.Module): 125 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 126 | num_decoder_layers=6, dim_feedforward=1024, dropout=0.1, 127 | activation="relu", return_intermediate_dec=False, 128 | num_feature_levels=4, dec_n_points=4, enc_n_points=4, 129 | ): 130 | super(EdgeTransformer, self).__init__() 131 | 132 | encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, 133 | dropout, activation, 134 | num_feature_levels, nhead, enc_n_points) 135 | self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) 136 | 137 | decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward, 138 | dropout, activation, 139 | num_feature_levels, nhead, dec_n_points) 140 | self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False) 141 | 142 | decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward, 143 | dropout, activation, 144 | num_feature_levels, nhead, dec_n_points) 145 | self.relational_decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, 146 | return_intermediate_dec, with_sa=True) 147 | 148 | self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) 149 | # self.reference_points = nn.Linear(d_model, 2) 150 | 151 | self.gt_label_embed = nn.Embedding(3, d_model) 152 | 153 | self.input_fc_hb = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2) 154 | self.input_fc_rel = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2) 155 | 156 | self.output_fc_1 = MLP(input_dim=d_model, hidden_dim=d_model // 
2, output_dim=2, num_layers=2) 157 | self.output_fc_2 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2) 158 | self.output_fc_3 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2) 159 | self._reset_parameters() 160 | 161 | def _reset_parameters(self): 162 | for p in self.parameters(): 163 | if p.dim() > 1: 164 | nn.init.xavier_uniform_(p) 165 | for m in self.modules(): 166 | if isinstance(m, MSDeformAttn): 167 | m._reset_parameters() 168 | # xavier_uniform_(self.reference_points.weight.data, gain=1.0) 169 | # constant_(self.reference_points.bias.data, 0.) 170 | normal_(self.level_embed) 171 | 172 | def get_valid_ratio(self, mask): 173 | _, H, W = mask.shape 174 | valid_H = torch.sum(~mask[:, :, 0], 1) 175 | valid_W = torch.sum(~mask[:, 0, :], 1) 176 | valid_ratio_h = valid_H.float() / H 177 | valid_ratio_w = valid_W.float() / W 178 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 179 | return valid_ratio 180 | 181 | def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, labels, key_padding_mask, corner_nums, 182 | max_candidates, do_inference=False, get_hs=False): 183 | # prepare input for encoder 184 | src_flatten = [] 185 | mask_flatten = [] 186 | lvl_pos_embed_flatten = [] 187 | spatial_shapes = [] 188 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): 189 | bs, c, h, w = src.shape 190 | spatial_shape = (h, w) 191 | spatial_shapes.append(spatial_shape) 192 | src = src.flatten(2).transpose(1, 2) 193 | mask = mask.flatten(1) 194 | pos_embed = pos_embed.flatten(2).transpose(1, 2) 195 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 196 | lvl_pos_embed_flatten.append(lvl_pos_embed) 197 | src_flatten.append(src) 198 | mask_flatten.append(mask) 199 | src_flatten = torch.cat(src_flatten, 1) 200 | mask_flatten = torch.cat(mask_flatten, 1) 201 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 202 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) 203 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) 204 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 205 | 206 | # encoder 207 | memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, 208 | mask_flatten) 209 | 210 | # prepare input for decoder 211 | bs, _, c = memory.shape 212 | 213 | tgt = query_embed 214 | 215 | # reference_points = self.reference_points(query_embed).sigmoid() 216 | init_reference_out = reference_points 217 | 218 | # relational decoder 219 | hs_per_edge, _ = self.per_edge_decoder(tgt, reference_points, memory, 220 | spatial_shapes, level_start_index, valid_ratios, query_embed, 221 | mask_flatten) 222 | 223 | logits_per_edge = self.output_fc_1(hs_per_edge).permute(0, 2, 1) 224 | 225 | if get_hs: 226 | return hs_per_edge, logits_per_edge 227 | 228 | filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids = self.candidate_filtering( 229 | logits_per_edge, 230 | hs_per_edge, query_embed, reference_points, 231 | labels, 232 | key_padding_mask, corner_nums, max_candidates) 233 | 234 | if not do_inference: 235 | filtered_gt_values = self.generate_gt_masking(filtered_labels, filtered_mask) 236 | else: 237 | filtered_gt_values = filtered_labels 238 | 239 | # relational decoder with image feature 240 | gt_info = self.gt_label_embed(filtered_gt_values) 241 | hybrid_prim_hs = 
self.input_fc_hb(torch.cat([filtered_hs, gt_info], dim=-1)) 242 | 243 | hs, inter_references = self.relational_decoder(hybrid_prim_hs, filtered_rp, memory, 244 | spatial_shapes, level_start_index, valid_ratios, filtered_query, 245 | mask_flatten, 246 | key_padding_mask=filtered_mask, get_image_feat=True) 247 | 248 | logits_final_hb = self.output_fc_2(hs).permute(0, 2, 1) 249 | 250 | # relational decoder without image feature 251 | rel_prim_hs = self.input_fc_rel(torch.cat([filtered_query, gt_info], dim=-1)) 252 | 253 | hs_rel, _ = self.relational_decoder(rel_prim_hs, filtered_rp, memory, 254 | spatial_shapes, level_start_index, valid_ratios, filtered_query, 255 | mask_flatten, 256 | key_padding_mask=filtered_mask, get_image_feat=False) 257 | 258 | logits_final_rel = self.output_fc_3(hs_rel).permute(0, 2, 1) 259 | 260 | return logits_per_edge, logits_final_hb, logits_final_rel, selected_ids, filtered_mask, filtered_gt_values 261 | 262 | @staticmethod 263 | def candidate_filtering(logits, hs, query, rp, labels, key_padding_mask, corner_nums, max_candidates): 264 | B, L, _ = hs.shape 265 | preds = logits.detach().softmax(1)[:, 1, :] # BxL 266 | preds[key_padding_mask == True] = -1 # ignore the masking parts 267 | sorted_ids = torch.argsort(preds, dim=-1, descending=True) 268 | filtered_hs = list() 269 | filtered_mask = list() 270 | filtered_query = list() 271 | filtered_rp = list() 272 | filtered_labels = list() 273 | selected_ids = list() 274 | for b_i in range(B): 275 | num_candidates = corner_nums[b_i] * 3 276 | ids = sorted_ids[b_i, :max_candidates[b_i]] 277 | filtered_hs.append(hs[b_i][ids]) 278 | new_mask = key_padding_mask[b_i][ids] 279 | new_mask[num_candidates:] = True 280 | filtered_mask.append(new_mask) 281 | filtered_query.append(query[b_i][ids]) 282 | filtered_rp.append(rp[b_i][ids]) 283 | filtered_labels.append(labels[b_i][ids]) 284 | selected_ids.append(ids) 285 | filtered_hs = torch.stack(filtered_hs, dim=0) 286 | filtered_mask = torch.stack(filtered_mask, dim=0) 287 | filtered_query = torch.stack(filtered_query, dim=0) 288 | filtered_rp = torch.stack(filtered_rp, dim=0) 289 | filtered_labels = torch.stack(filtered_labels, dim=0) 290 | selected_ids = torch.stack(selected_ids, dim=0) 291 | 292 | return filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids 293 | 294 | 295 | @staticmethod 296 | def generate_gt_masking(labels, mask): 297 | bs = labels.shape[0] 298 | gt_values = torch.zeros_like(mask).long() 299 | for b_i in range(bs): 300 | edge_length = (mask[b_i] == 0).sum() 301 | rand_ratio = np.random.rand() * 0.5 + 0.5 302 | gt_rand = torch.rand(edge_length) 303 | gt_flag = torch.zeros(edge_length) 304 | gt_flag[torch.where(gt_rand >= rand_ratio)] = 1 305 | gt_idx = torch.where(gt_flag == 1) 306 | pred_idx = torch.where(gt_flag == 0) 307 | gt_values[b_i, gt_idx[0]] = labels[b_i, gt_idx[0]] 308 | gt_values[b_i, pred_idx[0]] = 2 # use 2 to represent unknown value, need to predict 309 | return gt_values 310 | -------------------------------------------------------------------------------- /code/learn/models/full_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from models.mlp import MLP 6 | from models.deformable_transformer_full import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \ 7 | DeformableTransformerDecoder, DeformableTransformerDecoderLayer, DeformableAttnDecoderLayer 8 | from 
models.ops.modules import MSDeformAttn 9 | from models.corner_models import PositionEmbeddingSine 10 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ 11 | import torch.nn.functional as F 12 | from utils.misc import NestedTensor 13 | from utils.nn_utils import pos_encode_2d 14 | 15 | 16 | class EdgeEnum(nn.Module): 17 | 18 | def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels): 19 | super(EdgeEnum, self).__init__() 20 | self.input_dim = input_dim 21 | self.hidden_dim = hidden_dim 22 | self.num_feature_levels = num_feature_levels 23 | 24 | if num_feature_levels > 1: 25 | num_backbone_outs = len(backbone_strides) 26 | input_proj_list = [] 27 | for _ in range(num_backbone_outs): 28 | in_channels = backbone_num_channels[_] 29 | input_proj_list.append(nn.Sequential( 30 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1), 31 | nn.GroupNorm(32, hidden_dim), 32 | )) 33 | for _ in range(num_feature_levels - num_backbone_outs): 34 | input_proj_list.append(nn.Sequential( 35 | nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), 36 | nn.GroupNorm(32, hidden_dim), 37 | )) 38 | in_channels = hidden_dim 39 | self.input_proj = nn.ModuleList(input_proj_list) 40 | else: 41 | self.input_proj = nn.ModuleList([ 42 | nn.Sequential( 43 | nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1), 44 | nn.GroupNorm(32, hidden_dim), 45 | )]) 46 | 47 | self.img_pos = PositionEmbeddingSine(hidden_dim // 2) 48 | self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, hidden_dim)) 49 | self.edge_input_fc = nn.Linear(input_dim * 2, hidden_dim) 50 | 51 | self.output_fc = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim // 2, output_dim=2, num_layers=2) 52 | 53 | # self.mlm_embedding = nn.Embedding(3, input_dim) 54 | self.transformer = EdgeTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1, 55 | num_decoder_layers=6, dim_feedforward=1024, dropout=0.1) 56 | 57 | self._reset_parameters() 58 | 59 | 60 | def _reset_parameters(self): 61 | # for p in self.parameters(): 62 | # if p.dim() > 1: 63 | # nn.init.xavier_uniform_(p) 64 | # for m in self.modules(): 65 | # if isinstance(m, MSDeformAttn): 66 | # m._reset_parameters() 67 | 68 | # xavier_uniform_(self.reference_points.weight.data, gain=1.0) 69 | # constant_(self.reference_points.bias.data, 0.) 
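# Only the multi-scale level embedding is re-initialized here; the Xavier / MSDeformAttn
# resets above are left commented out in this variant, since the EdgeTransformer sub-module
# runs its own _reset_parameters() when it is constructed.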
70 | normal_(self.level_embed) 71 | 72 | 73 | @staticmethod 74 | def get_ms_feat(xs, img_mask): 75 | out: Dict[str, NestedTensor] = {} 76 | # out = list() 77 | for name, x in sorted(xs.items()): 78 | m = img_mask 79 | assert m is not None 80 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 81 | out[name] = NestedTensor(x, mask) 82 | # out.append(NestedTensor(x, mask)) 83 | return out 84 | 85 | 86 | def get_valid_ratio(self, mask): 87 | _, H, W = mask.shape 88 | valid_H = torch.sum(~mask[:, :, 0], 1) 89 | valid_W = torch.sum(~mask[:, 0, :], 1) 90 | valid_ratio_h = valid_H.float() / H 91 | valid_ratio_w = valid_W.float() / W 92 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 93 | return valid_ratio 94 | 95 | 96 | def prepare_image_feat(self, crop, backbone): 97 | image = torch.tensor(crop['image']).cuda().unsqueeze(0) 98 | with torch.no_grad(): 99 | image_feats, feat_mask, _ = backbone(image) 100 | features = self.get_ms_feat(image_feats, feat_mask) 101 | 102 | srcs = [] 103 | masks = [] 104 | all_pos = [] 105 | 106 | new_features = list() 107 | for name, x in sorted(features.items()): 108 | new_features.append(x) 109 | features = new_features 110 | 111 | for l, feat in enumerate(features): 112 | src, mask = feat.decompose() 113 | mask = mask.to(src.device) 114 | srcs.append(self.input_proj[l](src)) 115 | pos = self.img_pos(src).to(src.dtype) 116 | all_pos.append(pos) 117 | masks.append(mask) 118 | assert mask is not None 119 | 120 | if self.num_feature_levels > len(srcs): 121 | _len_srcs = len(srcs) 122 | for l in range(_len_srcs, self.num_feature_levels): 123 | if l == _len_srcs: 124 | src = self.input_proj[l](features[-1].tensors) 125 | else: 126 | src = self.input_proj[l](srcs[-1]) 127 | m = feat_mask 128 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device) 129 | pos_l = self.img_pos(src).to(src.dtype) 130 | srcs.append(src) 131 | masks.append(mask) 132 | all_pos.append(pos_l) 133 | 134 | # prepare input for encoder 135 | pos_embeds = all_pos 136 | 137 | src_flatten = [] 138 | mask_flatten = [] 139 | lvl_pos_embed_flatten = [] 140 | spatial_shapes = [] 141 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): 142 | bs, c, h, w = src.shape 143 | spatial_shape = (h, w) 144 | spatial_shapes.append(spatial_shape) 145 | src = src.flatten(2).transpose(1, 2) 146 | mask = mask.flatten(1) 147 | pos_embed = pos_embed.flatten(2).transpose(1, 2) 148 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 149 | lvl_pos_embed_flatten.append(lvl_pos_embed) 150 | src_flatten.append(src) 151 | mask_flatten.append(mask) 152 | src_flatten = torch.cat(src_flatten, 1) 153 | mask_flatten = torch.cat(mask_flatten, 1) 154 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 155 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) 156 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) 157 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 158 | 159 | image_data = { 160 | 'src_flatten': src_flatten, 161 | 'spatial_shapes': spatial_shapes, 162 | 'level_start_index': level_start_index, 163 | 'valid_ratios': valid_ratios, 164 | 'lvl_pos_embed_flatten': lvl_pos_embed_flatten, 165 | 'mask_flatten': mask_flatten 166 | } 167 | return image_data 168 | 169 | 170 | def forward(self, data, backbone): 171 | # encode edges 172 | edge_coords = data['edge'] 173 | edge_feats_a = 
pos_encode_2d(x=edge_coords[:,0], y=edge_coords[:,1]) 174 | edge_feats_b = pos_encode_2d(x=edge_coords[:,2], y=edge_coords[:,3]) 175 | edge_feats = torch.cat([edge_feats_a, edge_feats_b], dim=-1) 176 | edge_feats = self.edge_input_fc(edge_feats) 177 | 178 | # prepare image features 179 | for batch_crops in data['crops']: 180 | for crop in batch_crops: 181 | crop['image_data'] = self.prepare_image_feat(crop, backbone) 182 | 183 | # edge_center = (edge_coords[:, :, 0, :].float() + edge_coords[:, :, 1, :].float()) / 2 184 | # edge_center = edge_center / feat_mask.shape[1] 185 | 186 | # logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values = self.transformer(srcs, 187 | # masks, 188 | # all_pos, 189 | # edge_inputs, 190 | # edge_center, 191 | # gt_values, 192 | # edge_masks, 193 | # corner_nums, 194 | # max_candidates, 195 | # do_inference) 196 | 197 | self.transformer(edge_feats, data) 198 | 199 | return logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values 200 | 201 | 202 | class EdgeTransformer(nn.Module): 203 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 204 | num_decoder_layers=6, dim_feedforward=1024, dropout=0.1, 205 | activation="relu", return_intermediate_dec=False, 206 | num_feature_levels=4, dec_n_points=4, enc_n_points=4, 207 | ): 208 | super(EdgeTransformer, self).__init__() 209 | 210 | encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, 211 | dropout, activation, 212 | num_feature_levels, nhead, enc_n_points) 213 | self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) 214 | 215 | decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward, 216 | dropout, activation, 217 | num_feature_levels, nhead, dec_n_points) 218 | self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False) 219 | 220 | decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward, 221 | dropout, activation, 222 | num_feature_levels, nhead, dec_n_points) 223 | self.relational_decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, 224 | return_intermediate_dec, with_sa=True) 225 | 226 | # self.reference_points = nn.Linear(d_model, 2) 227 | 228 | self.gt_label_embed = nn.Embedding(3, d_model) 229 | 230 | self.input_fc_hb = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2) 231 | self.input_fc_rel = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2) 232 | 233 | self.output_fc_1 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2) 234 | self.output_fc_2 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2) 235 | self.output_fc_3 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2) 236 | 237 | self._reset_parameters() 238 | 239 | def _reset_parameters(self): 240 | for p in self.parameters(): 241 | if p.dim() > 1: 242 | nn.init.xavier_uniform_(p) 243 | for m in self.modules(): 244 | if isinstance(m, MSDeformAttn): 245 | m._reset_parameters() 246 | 247 | def forward(self, edge_feats, data): 248 | src_flatten = [] 249 | mask_flatten = [] 250 | lvl_pos_embed_flatten = [] 251 | 252 | for batch_crops in data['crops']: 253 | for crop in batch_crops: 254 | image_data = crop['image_data'] 255 | with torch.no_grad(): 256 | memory = self.encoder(image_data['src_flatten'], 257 | image_data['spatial_shapes'], 258 | image_data['level_start_index'], 259 | image_data['valid_ratios'], 260 | 
image_data['lvl_pos_embed_flatten'], 261 | image_data['mask_flatten']) 262 | image_data['memory'] = memory 263 | 264 | hs_per_edge, _ = self.per_edge_decoder(edge_feats, data) 265 | 266 | logits_per_edge = self.output_fc_1(hs_per_edge).permute(0, 2, 1) 267 | 268 | # relational decoder with image feature 269 | gt_info = self.gt_label_embed(filtered_gt_values) 270 | hybrid_prim_hs = self.input_fc_hb(torch.cat([filtered_hs, gt_info], dim=-1)) 271 | 272 | hs, inter_references = self.relational_decoder(hybrid_prim_hs, filtered_rp, memory, 273 | spatial_shapes, level_start_index, valid_ratios, filtered_query, 274 | mask_flatten, 275 | key_padding_mask=filtered_mask, get_image_feat=True) 276 | 277 | logits_final_hb = self.output_fc_2(hs).permute(0, 2, 1) 278 | 279 | # relational decoder without image feature 280 | rel_prim_hs = self.input_fc_rel(torch.cat([filtered_query, gt_info], dim=-1)) 281 | 282 | hs_rel, _ = self.relational_decoder(rel_prim_hs, filtered_rp, memory, 283 | spatial_shapes, level_start_index, valid_ratios, filtered_query, 284 | mask_flatten, 285 | key_padding_mask=filtered_mask, get_image_feat=False) 286 | 287 | logits_final_rel = self.output_fc_3(hs_rel).permute(0, 2, 1) 288 | 289 | return logits_per_edge, logits_final_hb, logits_final_rel, selected_ids, filtered_mask, filtered_gt_values -------------------------------------------------------------------------------- /code/learn/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.optimize import linear_sum_assignment 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import scipy.ndimage.filters as filters 6 | import numpy as np 7 | from utils.geometry_utils import edge_acc 8 | 9 | 10 | class CornerCriterion(nn.Module): 11 | def __init__(self, image_size): 12 | super().__init__() 13 | self.gamma = 1 14 | self.loss_rate = 9 15 | 16 | # def forward(self, outputs_s1, outputs_s2, s2_mask, s2_candidates, targets, gauss_targets, epoch=0): 17 | def forward(self, outputs_s1, targets, gauss_targets, epoch=0): 18 | # Compute the acc first, use the acc to guide the setup of loss weight 19 | preds_s1 = (outputs_s1 >= 0.5).float() 20 | pos_target_ids = torch.where(targets == 1) 21 | correct = (preds_s1[pos_target_ids] == targets[pos_target_ids]).float().sum() 22 | # acc = correct / (preds.shape[0] * preds.shape[1] * preds.shape[2]) 23 | # num_pos_preds = (preds == 1).sum() 24 | recall_s1 = correct / len(pos_target_ids[0]) 25 | # prec = correct / num_pos_preds if num_pos_preds > 0 else torch.tensor(0).to(correct.device) 26 | # f_score = 2.0 * prec * recall / (recall + prec + 1e-8) 27 | 28 | rate = self.loss_rate 29 | 30 | loss_weight = (gauss_targets > 0.5).float() * rate + 1 31 | loss_s1 = F.binary_cross_entropy( 32 | outputs_s1, gauss_targets, weight=loss_weight, reduction="none" 33 | ) 34 | loss_s1 = loss_s1.sum(-1).sum(-1).mean() 35 | 36 | # loss for stage-2 37 | # B, H, W = gauss_targets.shape 38 | # gauss_targets_1d = gauss_targets.view(B, H*W) 39 | # s2_ids = s2_candidates[:, :, 1] * H + s2_candidates[:, :, 0] 40 | # s2_labels = torch.gather(gauss_targets_1d, 1, s2_ids) 41 | # # try an aggressive labeling for s2 42 | # s2_th = 0.1 43 | # s2_labels = (s2_labels > s2_th).float() 44 | # loss_weight = (s2_labels > 0.5).float() * rate + 1 45 | # loss_s2 = F.binary_cross_entropy(outputs_s2, s2_labels, weight=loss_weight, reduction='none') 46 | # loss_s2[torch.where(s2_mask == True)] = 0 47 | # loss_s2 = loss_s2.sum(-1).sum(-1).mean() 48 | 49 | return 
loss_s1, recall_s1 50 | # , loss_s2, recall_s1 51 | 52 | 53 | class CornerCriterion4D(nn.Module): 54 | def __init__(self): 55 | super().__init__() 56 | 57 | def forward(self, outputs, corner_labels, offset_labels): 58 | loss_dict = {"loss_jloc": 0.0, "loss_joff": 0.0} 59 | for _, output in enumerate(outputs): 60 | loss_dict["loss_jloc"] += cross_entropy_loss_for_junction( 61 | output[:, :2], corner_labels 62 | ) 63 | loss_dict["loss_joff"] += sigmoid_l1_loss( 64 | output[:, 2:4], offset_labels, -0.5, corner_labels.float() 65 | ) 66 | return loss_dict 67 | 68 | 69 | def cross_entropy_loss_for_junction(logits, positive): 70 | nlogp = -F.log_softmax(logits, dim=1) 71 | loss = positive * nlogp[:, None, 1] + (1 - positive) * nlogp[:, None, 0] 72 | pos_rate = 4 73 | weights = (positive == 1) * pos_rate + 1 74 | loss = loss * weights 75 | 76 | loss = loss.mean() 77 | return loss 78 | 79 | 80 | def sigmoid_l1_loss(logits, targets, offset=0.0, mask=None): 81 | logp = torch.sigmoid(logits) + offset 82 | loss = torch.abs(logp - targets) 83 | 84 | if mask is not None: 85 | w = mask.mean(3, True).mean(2, True) 86 | w[w == 0] = 1 87 | loss = loss * (mask / w) 88 | 89 | loss = loss.mean() # avg over batch dim 90 | return loss 91 | 92 | 93 | class EdgeCriterion(nn.Module): 94 | def __init__(self): 95 | super().__init__() 96 | self.edge_loss = nn.CrossEntropyLoss( 97 | weight=torch.tensor([0.33, 1.0]).cuda(), reduction="none" 98 | ) 99 | self.width_loss = nn.CrossEntropyLoss(reduction="none", ignore_index=0) 100 | self.gamma = 1.0 # used by focal loss (if enabled) 101 | 102 | def forward( 103 | self, 104 | logits_s1, 105 | logits_edge_hybrid, 106 | logits_edge_rel, 107 | logits_width_hybrid, 108 | logits_width_rel, 109 | s2_ids, 110 | s2_edge_mask, 111 | edge_labels, 112 | width_labels, 113 | edge_lengths, 114 | edge_mask, 115 | s2_gt_values, 116 | ): 117 | # ensure batch size of 1 and no padding 118 | assert len(logits_s1) == 1 119 | assert not s2_edge_mask.any() 120 | 121 | # loss for stage-1: edge filtering 122 | s1_losses = self.edge_loss(logits_s1, edge_labels) 123 | s1_losses[torch.where(edge_mask == True)] = 0 124 | s1_losses = s1_losses[torch.where(s1_losses > 0)].sum() / edge_mask.shape[0] 125 | gt_values = torch.ones_like(edge_mask).long() * 2 126 | s1_acc = edge_acc(logits_s1, edge_labels, edge_lengths, gt_values) 127 | 128 | # loss for stage-2 129 | s2_edge_labels = torch.gather(edge_labels, 1, s2_ids) 130 | s2_width_labels = torch.gather(width_labels, 1, s2_ids) 131 | 132 | ### Edge losses ### 133 | 134 | s2_losses_e_hybrid = self.edge_loss(logits_edge_hybrid, s2_edge_labels) 135 | s2_losses_e_hybrid[ 136 | torch.where((s2_edge_mask == True) | (s2_gt_values != 2)) 137 | ] = 0 138 | # aggregate the loss into the final scalar 139 | s2_losses_e_hybrid = ( 140 | s2_losses_e_hybrid[torch.where(s2_losses_e_hybrid > 0)].sum() 141 | / s2_edge_mask.shape[0] 142 | ) 143 | s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1) 144 | # compute edge-level f1-score 145 | s2_acc_e_hybrid = edge_acc( 146 | logits_edge_hybrid, s2_edge_labels, s2_edge_lengths, s2_gt_values 147 | ) 148 | 149 | s2_losses_e_rel = self.edge_loss(logits_edge_rel, s2_edge_labels) 150 | s2_losses_e_rel[torch.where((s2_edge_mask == True) | (s2_gt_values != 2))] = 0 151 | # aggregate the loss into the final scalar 152 | s2_losses_e_rel = ( 153 | s2_losses_e_rel[torch.where(s2_losses_e_rel > 0)].sum() 154 | / s2_edge_mask.shape[0] 155 | ) 156 | s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1) 157 | # compute edge-level f1-score 158 | 
s2_acc_e_rel = edge_acc( 159 | logits_edge_rel, s2_edge_labels, s2_edge_lengths, s2_gt_values 160 | ) 161 | 162 | ### Width losses ### 163 | 164 | s2_losses_w_hybrid = self.width_loss(logits_width_hybrid, s2_width_labels) 165 | s2_losses_w_hybrid[torch.where(s2_edge_mask == True)] = 0 166 | # aggregate the loss into the final scalar 167 | s2_losses_w_hybrid = ( 168 | s2_losses_w_hybrid[torch.where(s2_losses_w_hybrid > 0)].sum() 169 | / s2_edge_mask.shape[0] 170 | ) 171 | s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1) 172 | # compute edge-level f1-score 173 | preds_width_hybrid = logits_width_hybrid.argmax(1) 174 | s2_acc_w_hybrid = ( 175 | ( 176 | preds_width_hybrid[s2_width_labels != 0] 177 | == s2_width_labels[s2_width_labels != 0] 178 | ) 179 | .float() 180 | .mean() 181 | ) 182 | 183 | s2_losses_w_rel = self.width_loss(logits_width_rel, s2_width_labels) 184 | s2_losses_w_rel[torch.where(s2_edge_mask == True)] = 0 185 | # aggregate the loss into the final scalar 186 | s2_losses_w_rel = ( 187 | s2_losses_w_rel[torch.where(s2_losses_w_rel > 0)].sum() 188 | / s2_edge_mask.shape[0] 189 | ) 190 | s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1) 191 | # compute edge-level f1-score 192 | preds_width_rel = logits_width_rel.argmax(1) 193 | s2_acc_w_rel = ( 194 | ( 195 | preds_width_rel[s2_width_labels != 0] 196 | == s2_width_labels[s2_width_labels != 0] 197 | ) 198 | .float() 199 | .mean() 200 | ) 201 | 202 | return ( 203 | s1_losses, 204 | s1_acc, 205 | s2_losses_e_hybrid, 206 | s2_acc_e_hybrid, 207 | s2_losses_e_rel, 208 | s2_acc_e_rel, 209 | s2_losses_w_hybrid, 210 | s2_acc_w_hybrid, 211 | s2_losses_w_rel, 212 | s2_acc_w_rel, 213 | ) 214 | -------------------------------------------------------------------------------- /code/learn/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class MLP(nn.Module): 6 | """ Very simple multi-layer perceptron (also called FFN)""" 7 | 8 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 9 | super(MLP, self).__init__() 10 | self.output_dim = output_dim 11 | self.num_layers = num_layers 12 | h = [hidden_dim] * (num_layers - 1) 13 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 14 | 15 | def forward(self, x): 16 | B, N, D = x.size() 17 | x = x.reshape(B*N, D) 18 | for i, layer in enumerate(self.layers): 19 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 20 | x = x.view(B, N, self.output_dim) 21 | return x 22 | -------------------------------------------------------------------------------- /code/learn/models/ops/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | MultiScale*/ -------------------------------------------------------------------------------- /code/learn/models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /code/learn/models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | 
mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | -------------------------------------------------------------------------------- /code/learn/models/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python3 setup.py build install 11 | -------------------------------------------------------------------------------- /code/learn/models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /code/learn/models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError( 27 | "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)) 28 | ) 29 | return (n & (n - 1) == 0) and n != 0 30 | 31 | 32 | class MSDeformAttn(nn.Module): 33 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 34 | """ 35 | Multi-Scale Deformable Attention Module 36 | :param d_model hidden dimension 37 | :param n_levels number of feature levels 38 | :param n_heads number of attention heads 39 | :param n_points number of sampling points per attention head per feature level 40 | """ 41 | super().__init__() 42 | if d_model % n_heads != 0: 43 | raise ValueError( 44 | "d_model must be divisible by n_heads, but got {} and {}".format( 45 | d_model, n_heads 46 | ) 47 | ) 48 | _d_per_head = d_model // n_heads 49 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 50 | if not _is_power_of_2(_d_per_head): 51 | warnings.warn( 52 | "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 53 | "which is more efficient in our CUDA implementation." 
54 | ) 55 | 56 | self.im2col_step = 64 57 | 58 | self.d_model = d_model 59 | self.n_levels = n_levels 60 | self.n_heads = n_heads 61 | self.n_points = n_points 62 | 63 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 64 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 65 | self.value_proj = nn.Linear(d_model, d_model) 66 | self.output_proj = nn.Linear(d_model, d_model) 67 | 68 | self._reset_parameters() 69 | 70 | def _reset_parameters(self): 71 | constant_(self.sampling_offsets.weight.data, 0.0) 72 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * ( 73 | 2.0 * math.pi / self.n_heads 74 | ) 75 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 76 | grid_init = ( 77 | (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) 78 | .view(self.n_heads, 1, 1, 2) 79 | .repeat(1, self.n_levels, self.n_points, 1) 80 | ) 81 | for i in range(self.n_points): 82 | grid_init[:, :, i, :] *= i + 1 83 | with torch.no_grad(): 84 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 85 | constant_(self.attention_weights.weight.data, 0.0) 86 | constant_(self.attention_weights.bias.data, 0.0) 87 | xavier_uniform_(self.value_proj.weight.data) 88 | constant_(self.value_proj.bias.data, 0.0) 89 | xavier_uniform_(self.output_proj.weight.data) 90 | constant_(self.output_proj.bias.data, 0.0) 91 | 92 | def forward( 93 | self, 94 | query, 95 | reference_points, 96 | input_flatten, 97 | input_spatial_shapes, 98 | input_level_start_index, 99 | input_padding_mask=None, 100 | return_sample_locs=False, 101 | ): 102 | """ 103 | :param query (N, Length_{query}, C) 104 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 105 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 106 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 107 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 108 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 109 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 110 | 111 | :return output (N, Length_{query}, C) 112 | """ 113 | N, Len_q, _ = query.shape 114 | N, Len_in, _ = input_flatten.shape 115 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 116 | 117 | value = self.value_proj(input_flatten) 118 | if input_padding_mask is not None: 119 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 120 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 121 | sampling_offsets = self.sampling_offsets(query).view( 122 | N, Len_q, self.n_heads, self.n_levels, self.n_points, 2 123 | ) 124 | attention_weights = self.attention_weights(query).view( 125 | N, Len_q, self.n_heads, self.n_levels * self.n_points 126 | ) 127 | attention_weights = F.softmax(attention_weights, -1).view( 128 | N, Len_q, self.n_heads, self.n_levels, self.n_points 129 | ) 130 | # N, Len_q, n_heads, n_levels, n_points, 2 131 | if reference_points.shape[-1] == 2: 132 | offset_normalizer = torch.stack( 133 | [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1 134 | ) 135 | sampling_locations = ( 136 | reference_points[:, :, None, :, None, :] 137 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 138 | ) 139 | elif 
reference_points.shape[-1] == 4: 140 | sampling_locations = ( 141 | reference_points[:, :, None, :, None, :2] 142 | + sampling_offsets 143 | / self.n_points 144 | * reference_points[:, :, None, :, None, 2:] 145 | * 0.5 146 | ) 147 | else: 148 | raise ValueError( 149 | "Last dim of reference_points must be 2 or 4, but get {} instead.".format( 150 | reference_points.shape[-1] 151 | ) 152 | ) 153 | output = MSDeformAttnFunction.apply( 154 | value, 155 | input_spatial_shapes, 156 | input_level_start_index, 157 | sampling_locations, 158 | attention_weights, 159 | self.im2col_step, 160 | ) 161 | output = self.output_proj(output) 162 | 163 | if return_sample_locs: 164 | return output, sampling_locations.detach().cpu() 165 | else: 166 | return output 167 | -------------------------------------------------------------------------------- /code/learn/models/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=1", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": 
torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /code/learn/models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /code/learn/models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /code/learn/models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 | 
11 | #include <vector>
12 | #include "cuda/ms_deform_im2col_cuda.cuh"
13 | 
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 | 
19 | 
20 | at::Tensor ms_deform_attn_cuda_forward(
21 |     const at::Tensor &value, 
22 |     const at::Tensor &spatial_shapes,
23 |     const at::Tensor &level_start_index,
24 |     const at::Tensor &sampling_loc,
25 |     const at::Tensor &attn_weight,
26 |     const int im2col_step)
27 | {
28 |     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29 |     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30 |     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31 |     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32 |     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33 | 
34 |     AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35 |     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36 |     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37 |     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38 |     AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39 | 
40 |     const int batch = value.size(0);
41 |     const int spatial_size = value.size(1);
42 |     const int num_heads = value.size(2);
43 |     const int channels = value.size(3);
44 | 
45 |     const int num_levels = spatial_shapes.size(0);
46 | 
47 |     const int num_query = sampling_loc.size(1);
48 |     const int num_point = sampling_loc.size(4);
49 | 
50 |     const int im2col_step_ = std::min(batch, im2col_step);
51 | 
52 |     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53 | 
54 |     auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55 | 
56 |     const int batch_n = im2col_step_;
57 |     auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58 |     auto per_value_size = spatial_size * num_heads * channels;
59 |     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60 |     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61 |     for (int n = 0; n < batch/im2col_step_; ++n)
62 |     {
63 |         auto columns = output_n.select(0, n);
64 |         AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65 |             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67 |                 spatial_shapes.data<int64_t>(),
68 |                 level_start_index.data<int64_t>(),
69 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72 |                 columns.data<scalar_t>());
73 | 
74 |         }));
75 |     }
76 | 
77 |     output = output.view({batch, num_query, num_heads*channels});
78 | 
79 |     return output;
80 | }
81 | 
82 | 
83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84 |     const at::Tensor &value, 
85 |     const at::Tensor &spatial_shapes,
86 |     const at::Tensor &level_start_index,
87 |     const at::Tensor &sampling_loc,
88 |     const at::Tensor &attn_weight,
89 |     const at::Tensor &grad_output,
90 |     const int im2col_step)
91 | {
92 | 
93 |     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94 |     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95 |     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96 |     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97 |     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98 |     AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99 | 
100 |     AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101 |     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102 |     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103 |     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104 |     AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105 |     AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106 | 
107 |     const int batch = value.size(0);
108 |     const int spatial_size = value.size(1);
109 |     const int num_heads = value.size(2);
110 |     const int channels = value.size(3);
111 | 
112 |     const int num_levels = spatial_shapes.size(0);
113 | 
114 |     const int num_query = sampling_loc.size(1);
115 |     const int num_point = sampling_loc.size(4);
116 | 
117 |     const int im2col_step_ = std::min(batch, im2col_step);
118 | 
119 |     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120 | 
121 |     auto grad_value = at::zeros_like(value);
122 |     auto grad_sampling_loc = at::zeros_like(sampling_loc);
123 |     auto grad_attn_weight = at::zeros_like(attn_weight);
124 | 
125 |     const int batch_n = im2col_step_;
126 |     auto per_value_size = spatial_size * num_heads * channels;
127 |     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128 |     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129 |     auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130 | 
131 |     for (int n = 0; n < batch/im2col_step_; ++n)
132 |     {
133 |         auto grad_output_g = grad_output_n.select(0, n);
134 |         AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135 |             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136 |                 grad_output_g.data<scalar_t>(),
137 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
138 |                 spatial_shapes.data<int64_t>(),
139 |                 level_start_index.data<int64_t>(),
140 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143 |                 grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
144 |                 grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145 |                 grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146 | 
147 |         }));
148 |     }
149 | 
150 |     return {
151 |         grad_value, grad_sampling_loc, grad_attn_weight
152 |     };
153 | }
--------------------------------------------------------------------------------
/code/learn/models/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 | 
11 | #pragma once
12 | #include <torch/extension.h>
13 | 
14 | at::Tensor ms_deform_attn_cuda_forward(
15 |     const at::Tensor &value, 
16 |     const at::Tensor &spatial_shapes,
17 |     const at::Tensor &level_start_index,
18 |     const at::Tensor &sampling_loc,
19 |     const at::Tensor &attn_weight,
20 |     const int im2col_step);
21 | 
22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23 |     const at::Tensor &value, 
24 |     const at::Tensor &spatial_shapes,
25 |     const at::Tensor &level_start_index,
26 |     const at::Tensor &sampling_loc,
27 |     const at::Tensor &attn_weight,
28 |     const at::Tensor &grad_output,
29 |     const int im2col_step);
30 | 
31 | 
--------------------------------------------------------------------------------
/code/learn/models/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 | 
11 | #pragma once
12 | 
13 | #include "cpu/ms_deform_attn_cpu.h"
14 | 
15 | #ifdef WITH_CUDA
16 | #include "cuda/ms_deform_attn_cuda.h"
17 | #endif
18 | 
19 | 
20 | at::Tensor
21 | ms_deform_attn_forward(
22 |     const at::Tensor &value, 
23 |     const at::Tensor &spatial_shapes,
24 |     const at::Tensor &level_start_index,
25 |     const at::Tensor &sampling_loc,
26 |     const at::Tensor &attn_weight,
27 |     const int im2col_step)
28 | {
29 |     if (value.type().is_cuda())
30 |     {
31 | #ifdef WITH_CUDA
32 |         return ms_deform_attn_cuda_forward(
33 |             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34 | #else
35 |         AT_ERROR("Not compiled with GPU support");
36 | #endif
37 |     }
38 |     AT_ERROR("Not implemented on the CPU");
39 | }
40 | 
41 | std::vector<at::Tensor>
42 | ms_deform_attn_backward(
43 |     const at::Tensor &value, 
44 |     const at::Tensor &spatial_shapes,
45 |     const at::Tensor &level_start_index,
46 |     const at::Tensor &sampling_loc,
47 |     const at::Tensor &attn_weight,
48 |     const at::Tensor &grad_output,
49 |     const int im2col_step)
50 | {
51 |     if (value.type().is_cuda())
52 |     {
53 | #ifdef WITH_CUDA
54 |         return ms_deform_attn_cuda_backward(
55 |             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56 | #else
57 |         AT_ERROR("Not compiled with GPU support");
58 | #endif
59 |     }
60 |     AT_ERROR("Not implemented on the CPU");
61 | }
62 | 
63 | 
--------------------------------------------------------------------------------
/code/learn/models/ops/src/vision.cpp:
-------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /code/learn/models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = 
torch.rand(N, Lq, M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /code/learn/models/order_class_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from json import encoder 3 | from sys import flags 4 | import cv2 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from tqdm import tqdm 9 | from models.mlp import MLP 10 | 11 | from models.deformable_transformer_full import ( 12 | DeformableTransformerEncoderLayer, 13 | DeformableTransformerEncoder, 14 | DeformableTransformerDecoder, 15 | DeformableTransformerDecoderLayer, 16 | DeformableAttnDecoderLayer, 17 | ) 18 | from models.ops.modules import MSDeformAttn 19 | from models.corner_models import PositionEmbeddingSine 20 | from models.unet import ResNetBackbone 21 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ 22 | import torch.nn.functional as F 23 | 24 | # from utils.misc import NestedTensor 25 | from models.utils import pos_encode_1d, pos_encode_2d 26 | 27 | from shapely.geometry import Point, LineString, box, MultiLineString 28 | from shapely import affinity 29 | 30 | 31 | def unnormalize_edges(edges, normalize_param): 32 | lines = [((x0, y0), (x1, y1)) for (x0, y0, x1, y1) in edges] 33 | lines = MultiLineString(lines) 34 | 35 | # normalize so longest edge is 1000, and to not change aspect ratio 36 | (xfact, yfact) = normalize_param["scale"] 37 | xfact = 1 / xfact 38 | yfact = 1 / yfact 39 | lines = affinity.scale(lines, xfact=xfact, yfact=yfact, origin=(0, 0)) 40 | 41 | # center edges around 0 42 | 
(xoff, yoff) = normalize_param["translation"] 43 | lines = affinity.translate(lines, xoff=-xoff, yoff=-yoff) 44 | 45 | # rotation 46 | if "rotation" in normalize_param.keys(): 47 | angle = normalize_param["rotation"] 48 | lines = affinity.rotate(lines, -angle, origin=(0, 0)) 49 | 50 | new_coords = [list(line.coords) for line in lines.geoms] 51 | new_coords = np.array(new_coords).reshape(-1, 4) 52 | 53 | return new_coords 54 | 55 | 56 | def vis_data(data): 57 | import matplotlib.pyplot as plt 58 | 59 | bs = len(data["floor_name"]) 60 | 61 | for b_i in range(bs): 62 | edge_coords = data["edge_coords"][b_i].cpu().numpy() 63 | edge_mask = data["edge_coords_mask"][b_i].cpu().numpy() 64 | edge_order = data["edge_order"][b_i].cpu().numpy() 65 | label = data["label"][b_i].cpu().numpy() 66 | 67 | max_order = edge_order.max() 68 | 69 | for edge_i, (x0, y0, x1, y1) in enumerate(edge_coords): 70 | if not edge_mask[edge_i]: 71 | if edge_order[edge_i] == max_order: 72 | plt.plot([x0, x1], [y0, y1], "-or") 73 | else: 74 | plt.plot([x0, x1], [y0, y1], "-oy") 75 | plt.text( 76 | (x0 + x1) / 2, (y0 + y1) / 2, str(edge_order[edge_i]), color="c" 77 | ) 78 | 79 | plt.title("%d" % label) 80 | plt.tight_layout() 81 | plt.show() 82 | plt.close() 83 | 84 | 85 | class EdgeTransformer(nn.Module): 86 | def __init__( 87 | self, 88 | d_model=512, 89 | nhead=8, 90 | num_encoder_layers=6, 91 | num_decoder_layers=6, 92 | dim_feedforward=1024, 93 | dropout=0.1, 94 | activation="relu", 95 | return_intermediate_dec=False, 96 | num_feature_levels=4, 97 | dec_n_points=4, 98 | enc_n_points=4, 99 | ): 100 | super(EdgeTransformer, self).__init__() 101 | self.d_model = d_model 102 | 103 | decoder_1_layer = DeformableTransformerDecoderLayer( 104 | d_model, 105 | dim_feedforward, 106 | dropout, 107 | activation, 108 | num_feature_levels, 109 | nhead, 110 | dec_n_points, 111 | ) 112 | self.decoder_1 = DeformableTransformerDecoder( 113 | decoder_1_layer, num_decoder_layers, return_intermediate_dec, with_sa=True 114 | ) 115 | 116 | self.type_embed = nn.Embedding(3, 128) 117 | self.input_head = nn.Linear(128 * 4, d_model) 118 | 119 | # self.final_head = MLP( 120 | # input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2 121 | # ) 122 | self.final_head = nn.Linear(d_model, 2) 123 | 124 | self._reset_parameters() 125 | 126 | def _reset_parameters(self): 127 | for p in self.parameters(): 128 | if p.dim() > 1: 129 | nn.init.xavier_uniform_(p) 130 | for m in self.modules(): 131 | if isinstance(m, MSDeformAttn): 132 | m._reset_parameters() 133 | 134 | def get_geom_feats(self, coords): 135 | (bs, N, _) = coords.shape 136 | _coords = coords.reshape(bs * N, -1) 137 | _geom_enc_a = pos_encode_2d(x=_coords[:, 0], y=_coords[:, 1]) 138 | _geom_enc_b = pos_encode_2d(x=_coords[:, 2], y=_coords[:, 3]) 139 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1) 140 | geom_feats = _geom_feats.reshape(bs, N, -1) 141 | 142 | return geom_feats 143 | 144 | def forward(self, data): 145 | # vis_data(data) 146 | 147 | # obtain edge positional features 148 | edge_coords = data["edge_coords"] 149 | edge_mask = data["edge_coords_mask"] 150 | edge_order = data["edge_order"] 151 | 152 | dtype = edge_coords.dtype 153 | device = edge_coords.device 154 | 155 | # three types of nodes 156 | # 1. dummy node 157 | # 2. modelled node 158 | # 3. 
sequence node 159 | 160 | bs = len(edge_coords) 161 | 162 | # geom 163 | geom_feats = self.get_geom_feats(edge_coords) 164 | dummy_geom = torch.zeros( 165 | [bs, 1, geom_feats.shape[-1]], dtype=dtype, device=device 166 | ) 167 | geom_feats = torch.cat([dummy_geom, geom_feats], dim=1) 168 | 169 | # node type (NEED TO BE BEFORE ORDER) 170 | max_order = 10 171 | node_type = torch.zeros_like(edge_order) 172 | node_type[edge_order < max_order] = 1 173 | node_type[edge_order == max_order] = 2 174 | assert not (node_type == 0).any() 175 | 176 | dummy_type = torch.zeros([bs, 1], dtype=dtype, device=device) 177 | node_type = torch.cat([dummy_type, node_type], dim=1) 178 | type_feats = self.type_embed(node_type.long()) 179 | 180 | # order 181 | dummy_order = torch.zeros([bs, 1], dtype=dtype, device=device) 182 | edge_order = torch.cat([dummy_order, edge_order], dim=1) 183 | order_feats = pos_encode_1d(edge_order.flatten()) 184 | order_feats = order_feats.reshape(edge_order.shape[0], edge_order.shape[1], -1) 185 | 186 | # need to pad the mask 187 | dummy_mask = torch.zeros([bs, 1], dtype=edge_mask.dtype, device=device) 188 | edge_mask = torch.cat([dummy_mask, edge_mask], dim=-1) 189 | 190 | # combine features 191 | edge_feats = torch.cat([geom_feats, type_feats, order_feats], dim=-1) 192 | edge_feats = self.input_head(edge_feats) 193 | 194 | # first do self-attention without any flags 195 | hs = self.decoder_1( 196 | edge_feats=edge_feats, 197 | geom_feats=geom_feats, 198 | image_feats=None, 199 | key_padding_mask=edge_mask, 200 | get_image_feat=False, 201 | ) 202 | 203 | hs = self.final_head(hs) 204 | 205 | # return only the dummy node 206 | return hs[:, 0] 207 | 208 | 209 | class EdgeTransformer2(nn.Module): 210 | def __init__( 211 | self, 212 | d_model=256, 213 | nhead=8, 214 | num_encoder_layers=3, 215 | num_decoder_layers=3, 216 | dim_feedforward=1024, 217 | dropout=0.1, 218 | activation="relu", 219 | return_intermediate_dec=False, 220 | num_feature_levels=4, 221 | dec_n_points=4, 222 | enc_n_points=4, 223 | ): 224 | super(EdgeTransformer2, self).__init__() 225 | self.d_model = d_model 226 | 227 | encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead) 228 | self.encoder = nn.TransformerEncoder( 229 | encoder_layer, num_layers=num_encoder_layers 230 | ) 231 | 232 | decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead) 233 | self.decoder = nn.TransformerDecoder( 234 | decoder_layer, num_layers=num_decoder_layers 235 | ) 236 | 237 | self.type_embed = nn.Embedding(3, 128) 238 | self.input_head = nn.Linear(128 * 3 + 2, d_model) 239 | 240 | # self.final_head = MLP( 241 | # input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2 242 | # ) 243 | self.final_head = nn.Linear(d_model, 2) 244 | 245 | self._reset_parameters() 246 | 247 | def _reset_parameters(self): 248 | for p in self.parameters(): 249 | if p.dim() > 1: 250 | nn.init.xavier_uniform_(p) 251 | # for m in self.modules(): 252 | # if isinstance(m, MSDeformAttn): 253 | # m._reset_parameters() 254 | 255 | def forward(self, data): 256 | # vis_data(data) 257 | 258 | # obtain edge positional features 259 | edge_coords = data["edge_coords"] 260 | edge_mask = data["edge_coords_mask"] 261 | modelled_coords = data["modelled_coords"] 262 | modelled_mask = data["modelled_coords_mask"] 263 | 264 | dtype = edge_coords.dtype 265 | device = edge_coords.device 266 | 267 | # first self-attention among the modelled edges 268 | (bs, N, _) = modelled_coords.shape 269 | _modelled_coords = modelled_coords.reshape(bs 
* N, -1) 270 | _geom_enc_a = pos_encode_2d(x=_modelled_coords[:, 0], y=_modelled_coords[:, 1]) 271 | _geom_enc_b = pos_encode_2d(x=_modelled_coords[:, 2], y=_modelled_coords[:, 3]) 272 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1) 273 | modelled_geom = _geom_feats.reshape(bs, N, -1) 274 | modelled_geom = modelled_geom.permute([1, 0, 2]) # bs,S,E -> S,bs,E 275 | 276 | memory = self.encoder(modelled_geom, src_key_padding_mask=modelled_mask) 277 | 278 | # then SA and CA among the sequence edges 279 | (bs, N, _) = edge_coords.shape 280 | _edge_coords = edge_coords.reshape(bs * N, -1) 281 | _geom_enc_a = pos_encode_2d(x=_edge_coords[:, 0], y=_edge_coords[:, 1]) 282 | _geom_enc_b = pos_encode_2d(x=_edge_coords[:, 2], y=_edge_coords[:, 3]) 283 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1) 284 | seq_geom = _geom_feats.reshape(bs, N, -1) 285 | seq_geom = seq_geom.permute([1, 0, 2]) # bs,S,E -> S,bs,E 286 | 287 | dummy_geom = torch.zeros( 288 | [1, bs, seq_geom.shape[-1]], dtype=dtype, device=device 289 | ) 290 | seq_geom = torch.cat([dummy_geom, seq_geom], dim=0) 291 | 292 | # for edge positions, dummy node is 0, modelled edges are 0, sequence start at 1 293 | pos_inds = torch.arange(N + 1, dtype=dtype, device=device) 294 | pos_feats = pos_encode_1d(pos_inds) 295 | pos_feats = pos_feats.unsqueeze(1).repeat(1, bs, 1) 296 | 297 | # for edge type, dummy node is 0, modelled edges are 1, sequence start at 2 298 | flag_feats = torch.zeros([N + 1, bs], dtype=torch.int64, device=device) 299 | flag_feats[1:, :] = 1 300 | flag_feats = nn.functional.one_hot(flag_feats, num_classes=2) 301 | 302 | edge_feats = torch.cat([seq_geom, pos_feats, flag_feats], dim=-1) 303 | edge_feats = self.input_head(edge_feats) 304 | 305 | # need to pad the mask 306 | dummy_mask = torch.zeros([bs, 1], dtype=edge_mask.dtype, device=device) 307 | edge_mask = torch.cat([dummy_mask, edge_mask], dim=-1) 308 | 309 | # first do self-attention without any flags 310 | hs = self.decoder( 311 | tgt=edge_feats, 312 | memory=memory, 313 | tgt_key_padding_mask=edge_mask, 314 | memory_key_padding_mask=modelled_mask, 315 | ) 316 | return self.final_head(hs[0]) 317 | 318 | # hs = self.final_head(hs) 319 | 320 | # # return only the dummy node 321 | # return hs[:, 0] 322 | -------------------------------------------------------------------------------- /code/learn/models/order_metric_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from json import encoder 3 | from sys import flags 4 | import cv2 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from tqdm import tqdm 9 | from models.mlp import MLP 10 | 11 | from models.deformable_transformer_full import ( 12 | DeformableTransformerEncoderLayer, 13 | DeformableTransformerEncoder, 14 | DeformableTransformerDecoder, 15 | DeformableTransformerDecoderLayer, 16 | DeformableAttnDecoderLayer, 17 | ) 18 | from models.ops.modules import MSDeformAttn 19 | from models.corner_models import PositionEmbeddingSine 20 | from models.unet import ResNetBackbone 21 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ 22 | import torch.nn.functional as F 23 | 24 | # from utils.misc import NestedTensor 25 | from models.utils import pos_encode_1d, pos_encode_2d 26 | 27 | from shapely.geometry import Point, LineString, box, MultiLineString 28 | from shapely import affinity 29 | 30 | 31 | def unnormalize_edges(edges, normalize_param): 32 | lines = [((x0, y0), (x1, y1)) for (x0, y0, x1, 
y1) in edges] 33 | lines = MultiLineString(lines) 34 | 35 | # normalize so longest edge is 1000, and to not change aspect ratio 36 | (xfact, yfact) = normalize_param["scale"] 37 | xfact = 1 / xfact 38 | yfact = 1 / yfact 39 | lines = affinity.scale(lines, xfact=xfact, yfact=yfact, origin=(0, 0)) 40 | 41 | # center edges around 0 42 | (xoff, yoff) = normalize_param["translation"] 43 | lines = affinity.translate(lines, xoff=-xoff, yoff=-yoff) 44 | 45 | # rotation 46 | if "rotation" in normalize_param.keys(): 47 | angle = normalize_param["rotation"] 48 | lines = affinity.rotate(lines, -angle, origin=(0, 0)) 49 | 50 | new_coords = [list(line.coords) for line in lines.geoms] 51 | new_coords = np.array(new_coords).reshape(-1, 4) 52 | 53 | return new_coords 54 | 55 | 56 | def vis_data(data): 57 | import matplotlib.pyplot as plt 58 | 59 | bs = len(data["floor_name"]) 60 | 61 | for b_i in range(bs): 62 | edge_coords = data["edge_coords"][b_i].cpu().numpy() 63 | edge_mask = data["edge_coords_mask"][b_i].cpu().numpy() 64 | edge_order = data["edge_order"][b_i].cpu().numpy() 65 | label = data["label"][b_i].cpu().numpy() 66 | 67 | for edge_i, (x0, y0, x1, y1) in enumerate(edge_coords): 68 | if not edge_mask[edge_i]: 69 | if label[edge_i] == 0: 70 | assert edge_order[edge_i] == 0 71 | plt.plot([x0, x1], [y0, y1], "-or") 72 | elif label[edge_i] == 3: 73 | assert edge_order[edge_i] > 0 74 | plt.plot([x0, x1], [y0, y1], "--oc") 75 | plt.text( 76 | (x0 + x1) / 2, (y0 + y1) / 2, str(edge_order[edge_i]), color="c" 77 | ) 78 | elif label[edge_i] == 1: 79 | assert edge_order[edge_i] == 1 80 | plt.plot([x0, x1], [y0, y1], "-oy") 81 | plt.text( 82 | (x0 + x1) / 2, (y0 + y1) / 2, str(edge_order[edge_i]), color="c" 83 | ) 84 | elif label[edge_i] == 2: 85 | assert edge_order[edge_i] == 0 86 | plt.plot([x0, x1], [y0, y1], "-og") 87 | plt.text( 88 | (x0 + x1) / 2, (y0 + y1) / 2, str(edge_order[edge_i]), color="c" 89 | ) 90 | else: 91 | raise Exception 92 | 93 | # plt.title("%d" % label) 94 | plt.tight_layout() 95 | plt.show() 96 | plt.close() 97 | 98 | 99 | class EdgeTransformer(nn.Module): 100 | def __init__( 101 | self, 102 | d_model=256, 103 | nhead=8, 104 | num_encoder_layers=6, 105 | num_decoder_layers=6, 106 | dim_feedforward=1024, 107 | dropout=0.1, 108 | activation="relu", 109 | return_intermediate_dec=False, 110 | num_feature_levels=4, 111 | dec_n_points=4, 112 | enc_n_points=4, 113 | ): 114 | super(EdgeTransformer, self).__init__() 115 | self.d_model = d_model 116 | 117 | decoder_1_layer = DeformableTransformerDecoderLayer( 118 | d_model, 119 | dim_feedforward, 120 | dropout, 121 | activation, 122 | num_feature_levels, 123 | nhead, 124 | dec_n_points, 125 | ) 126 | self.decoder_1 = DeformableTransformerDecoder( 127 | decoder_1_layer, num_decoder_layers, return_intermediate_dec, with_sa=True 128 | ) 129 | 130 | self.type_embed = nn.Embedding(3, 128) 131 | self.input_head = nn.Linear(128 * 4, d_model) 132 | 133 | # self.final_head = MLP( 134 | # input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2 135 | # ) 136 | self.final_head = nn.Linear(d_model, d_model) 137 | 138 | self._reset_parameters() 139 | 140 | def _reset_parameters(self): 141 | for p in self.parameters(): 142 | if p.dim() > 1: 143 | nn.init.xavier_uniform_(p) 144 | for m in self.modules(): 145 | if isinstance(m, MSDeformAttn): 146 | m._reset_parameters() 147 | 148 | def get_geom_feats(self, coords): 149 | (bs, N, _) = coords.shape 150 | _coords = coords.reshape(bs * N, -1) 151 | _geom_enc_a = pos_encode_2d(x=_coords[:, 0], 
y=_coords[:, 1]) 152 | _geom_enc_b = pos_encode_2d(x=_coords[:, 2], y=_coords[:, 3]) 153 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1) 154 | geom_feats = _geom_feats.reshape(bs, N, -1) 155 | 156 | return geom_feats 157 | 158 | def forward(self, data): 159 | # vis_data(data) 160 | 161 | # obtain edge positional features 162 | edge_coords = data["edge_coords"] 163 | edge_mask = data["edge_coords_mask"] 164 | edge_order = data["edge_order"] 165 | 166 | dtype = edge_coords.dtype 167 | device = edge_coords.device 168 | 169 | # three types of nodes 170 | # 1. dummy node 171 | # 2. modelled node 172 | # 3. sequence node 173 | 174 | bs = len(edge_coords) 175 | 176 | # geom 177 | geom_feats = self.get_geom_feats(edge_coords) 178 | 179 | # node type (NEED TO BE BEFORE ORDER) 180 | max_order = 10 181 | node_type = torch.ones_like(edge_order) 182 | node_type[edge_order == 0] = 0 183 | node_type[edge_order == max_order] = 2 184 | type_feats = self.type_embed(node_type.long()) 185 | 186 | # order 187 | order_feats = pos_encode_1d(edge_order.flatten()) 188 | order_feats = order_feats.reshape(edge_order.shape[0], edge_order.shape[1], -1) 189 | 190 | # combine features and also add a dummy node 191 | edge_feats = torch.cat([geom_feats, type_feats, order_feats], dim=-1) 192 | edge_feats = self.input_head(edge_feats) 193 | 194 | # first do self-attention without any flags 195 | hs = self.decoder_1( 196 | edge_feats=edge_feats, 197 | geom_feats=geom_feats, 198 | image_feats=None, 199 | key_padding_mask=edge_mask, 200 | get_image_feat=False, 201 | ) 202 | 203 | hs = self.final_head(hs) 204 | 205 | # return only the dummy node 206 | return hs 207 | 208 | 209 | class EdgeTransformer2(nn.Module): 210 | def __init__( 211 | self, 212 | d_model=256, 213 | nhead=8, 214 | num_encoder_layers=3, 215 | num_decoder_layers=3, 216 | dim_feedforward=1024, 217 | dropout=0.1, 218 | activation="relu", 219 | return_intermediate_dec=False, 220 | num_feature_levels=4, 221 | dec_n_points=4, 222 | enc_n_points=4, 223 | ): 224 | super(EdgeTransformer2, self).__init__() 225 | self.d_model = d_model 226 | 227 | encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead) 228 | self.encoder = nn.TransformerEncoder( 229 | encoder_layer, num_layers=num_encoder_layers 230 | ) 231 | 232 | decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead) 233 | self.decoder = nn.TransformerDecoder( 234 | decoder_layer, num_layers=num_decoder_layers 235 | ) 236 | 237 | self.type_embed = nn.Embedding(3, 128) 238 | self.input_head = nn.Linear(128 * 3 + 2, d_model) 239 | 240 | # self.final_head = MLP( 241 | # input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2 242 | # ) 243 | self.final_head = nn.Linear(d_model, 2) 244 | 245 | self._reset_parameters() 246 | 247 | def _reset_parameters(self): 248 | for p in self.parameters(): 249 | if p.dim() > 1: 250 | nn.init.xavier_uniform_(p) 251 | # for m in self.modules(): 252 | # if isinstance(m, MSDeformAttn): 253 | # m._reset_parameters() 254 | 255 | def forward(self, data): 256 | # vis_data(data) 257 | 258 | # obtain edge positional features 259 | edge_coords = data["edge_coords"] 260 | edge_mask = data["edge_coords_mask"] 261 | modelled_coords = data["modelled_coords"] 262 | modelled_mask = data["modelled_coords_mask"] 263 | 264 | dtype = edge_coords.dtype 265 | device = edge_coords.device 266 | 267 | # first self-attention among the modelled edges 268 | (bs, N, _) = modelled_coords.shape 269 | _modelled_coords = modelled_coords.reshape(bs * N, -1) 
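        # NOTE: each endpoint of a modelled edge gets its own 2-D sinusoidal embedding via
        # pos_encode_2d (128-d by default); the two embeddings are concatenated into a 256-d
        # per-edge feature, reshaped back to (bs, N, 256), and then permuted to (S, bs, E)
        # because nn.TransformerEncoder expects sequence-first input.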
270 | _geom_enc_a = pos_encode_2d(x=_modelled_coords[:, 0], y=_modelled_coords[:, 1]) 271 | _geom_enc_b = pos_encode_2d(x=_modelled_coords[:, 2], y=_modelled_coords[:, 3]) 272 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1) 273 | modelled_geom = _geom_feats.reshape(bs, N, -1) 274 | modelled_geom = modelled_geom.permute([1, 0, 2]) # bs,S,E -> S,bs,E 275 | 276 | memory = self.encoder(modelled_geom, src_key_padding_mask=modelled_mask) 277 | 278 | # then SA and CA among the sequence edges 279 | (bs, N, _) = edge_coords.shape 280 | _edge_coords = edge_coords.reshape(bs * N, -1) 281 | _geom_enc_a = pos_encode_2d(x=_edge_coords[:, 0], y=_edge_coords[:, 1]) 282 | _geom_enc_b = pos_encode_2d(x=_edge_coords[:, 2], y=_edge_coords[:, 3]) 283 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1) 284 | seq_geom = _geom_feats.reshape(bs, N, -1) 285 | seq_geom = seq_geom.permute([1, 0, 2]) # bs,S,E -> S,bs,E 286 | 287 | dummy_geom = torch.zeros( 288 | [1, bs, seq_geom.shape[-1]], dtype=dtype, device=device 289 | ) 290 | seq_geom = torch.cat([dummy_geom, seq_geom], dim=0) 291 | 292 | # for edge positions, dummy node is 0, modelled edges are 0, sequence start at 1 293 | pos_inds = torch.arange(N + 1, dtype=dtype, device=device) 294 | pos_feats = pos_encode_1d(pos_inds) 295 | pos_feats = pos_feats.unsqueeze(1).repeat(1, bs, 1) 296 | 297 | # for edge type, dummy node is 0, modelled edges are 1, sequence start at 2 298 | flag_feats = torch.zeros([N + 1, bs], dtype=torch.int64, device=device) 299 | flag_feats[1:, :] = 1 300 | flag_feats = nn.functional.one_hot(flag_feats, num_classes=2) 301 | 302 | edge_feats = torch.cat([seq_geom, pos_feats, flag_feats], dim=-1) 303 | edge_feats = self.input_head(edge_feats) 304 | 305 | # need to pad the mask 306 | dummy_mask = torch.zeros([bs, 1], dtype=edge_mask.dtype, device=device) 307 | edge_mask = torch.cat([dummy_mask, edge_mask], dim=-1) 308 | 309 | # first do self-attention without any flags 310 | hs = self.decoder( 311 | tgt=edge_feats, 312 | memory=memory, 313 | tgt_key_padding_mask=edge_mask, 314 | memory_key_padding_mask=modelled_mask, 315 | ) 316 | return self.final_head(hs[0]) 317 | 318 | # hs = self.final_head(hs) 319 | 320 | # # return only the dummy node 321 | # return hs[:, 0] 322 | -------------------------------------------------------------------------------- /code/learn/models/stacked_hg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hourglass network inserted in the pre-activated Resnet 3 | Use lr=0.01 for current version 4 | (c) Nan Xue (HAWP) 5 | (c) Yichao Zhou (LCNN) 6 | (c) YANG, Wei 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | __all__ = ["HourglassNet", "hg"] 13 | 14 | 15 | class Bottleneck2D(nn.Module): 16 | expansion = 2 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(Bottleneck2D, self).__init__() 20 | 21 | self.bn1 = nn.BatchNorm2d(inplanes) 22 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1) 25 | self.bn3 = nn.BatchNorm2d(planes) 26 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | residual = x 33 | 34 | out = self.bn1(x) 35 | out = self.relu(out) 36 | out = self.conv1(out) 37 | 38 | out = self.bn2(out) 39 | 
out = self.relu(out) 40 | out = self.conv2(out) 41 | 42 | out = self.bn3(out) 43 | out = self.relu(out) 44 | out = self.conv3(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | 51 | return out 52 | 53 | 54 | class Hourglass(nn.Module): 55 | def __init__(self, block, num_blocks, planes, depth): 56 | super(Hourglass, self).__init__() 57 | self.depth = depth 58 | self.block = block 59 | self.hg = self._make_hour_glass(block, num_blocks, planes, depth) 60 | 61 | def _make_residual(self, block, num_blocks, planes): 62 | layers = [] 63 | for i in range(0, num_blocks): 64 | layers.append(block(planes * block.expansion, planes)) 65 | return nn.Sequential(*layers) 66 | 67 | def _make_hour_glass(self, block, num_blocks, planes, depth): 68 | hg = [] 69 | for i in range(depth): 70 | res = [] 71 | for j in range(3): 72 | res.append(self._make_residual(block, num_blocks, planes)) 73 | if i == 0: 74 | res.append(self._make_residual(block, num_blocks, planes)) 75 | hg.append(nn.ModuleList(res)) 76 | return nn.ModuleList(hg) 77 | 78 | def _hour_glass_forward(self, n, x): 79 | up1 = self.hg[n - 1][0](x) 80 | low1 = F.max_pool2d(x, 2, stride=2) 81 | low1 = self.hg[n - 1][1](low1) 82 | 83 | if n > 1: 84 | low2 = self._hour_glass_forward(n - 1, low1) 85 | else: 86 | low2 = self.hg[n - 1][3](low1) 87 | low3 = self.hg[n - 1][2](low2) 88 | up2 = F.interpolate(low3, scale_factor=2) 89 | out = up1 + up2 90 | return out 91 | 92 | def forward(self, x): 93 | return self._hour_glass_forward(self.depth, x) 94 | 95 | 96 | class HourglassNet(nn.Module): 97 | """Hourglass model from Newell et al ECCV 2016""" 98 | 99 | def __init__(self, inplanes, num_feats, block, head, depth, num_stacks, num_blocks, num_classes): 100 | super(HourglassNet, self).__init__() 101 | 102 | self.inplanes = inplanes 103 | self.num_feats = num_feats 104 | self.num_stacks = num_stacks 105 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3) 106 | self.bn1 = nn.BatchNorm2d(self.inplanes) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.layer1 = self._make_residual(block, self.inplanes, 1) 109 | self.layer2 = self._make_residual(block, self.inplanes, 1) 110 | self.layer3 = self._make_residual(block, self.num_feats, 1) 111 | self.maxpool = nn.MaxPool2d(2, stride=2) 112 | 113 | # build hourglass modules 114 | ch = self.num_feats * block.expansion 115 | # vpts = [] 116 | hg, res, fc, score, fc_, score_ = [], [], [], [], [], [] 117 | for i in range(num_stacks): 118 | hg.append(Hourglass(block, num_blocks, self.num_feats, depth)) 119 | res.append(self._make_residual(block, self.num_feats, num_blocks)) 120 | fc.append(self._make_fc(ch, ch)) 121 | score.append(head(ch, num_classes)) 122 | # vpts.append(VptsHead(ch)) 123 | # vpts.append(nn.Linear(ch, 9)) 124 | # score.append(nn.Conv2d(ch, num_classes, kernel_size=1)) 125 | # score[i].bias.data[0] += 4.6 126 | # score[i].bias.data[2] += 4.6 127 | if i < num_stacks - 1: 128 | fc_.append(nn.Conv2d(ch, ch, kernel_size=1)) 129 | score_.append(nn.Conv2d(num_classes, ch, kernel_size=1)) 130 | self.hg = nn.ModuleList(hg) 131 | self.res = nn.ModuleList(res) 132 | self.fc = nn.ModuleList(fc) 133 | self.score = nn.ModuleList(score) 134 | # self.vpts = nn.ModuleList(vpts) 135 | self.fc_ = nn.ModuleList(fc_) 136 | self.score_ = nn.ModuleList(score_) 137 | 138 | def _make_residual(self, block, planes, blocks, stride=1): 139 | downsample = None 140 | if stride != 1 or self.inplanes != planes * block.expansion: 141 | downsample = 
nn.Sequential( 142 | nn.Conv2d( 143 | self.inplanes, 144 | planes * block.expansion, 145 | kernel_size=1, 146 | stride=stride, 147 | ) 148 | ) 149 | 150 | layers = [] 151 | layers.append(block(self.inplanes, planes, stride, downsample)) 152 | self.inplanes = planes * block.expansion 153 | for i in range(1, blocks): 154 | layers.append(block(self.inplanes, planes)) 155 | 156 | return nn.Sequential(*layers) 157 | 158 | def _make_fc(self, inplanes, outplanes): 159 | bn = nn.BatchNorm2d(inplanes) 160 | conv = nn.Conv2d(inplanes, outplanes, kernel_size=1) 161 | return nn.Sequential(conv, bn, self.relu) 162 | 163 | def forward(self, x): 164 | out = [] 165 | x = self.conv1(x) 166 | x = self.bn1(x) 167 | x = self.relu(x) 168 | 169 | x = self.layer1(x) 170 | x = self.maxpool(x) 171 | x = self.layer2(x) 172 | x = self.layer3(x) 173 | 174 | for i in range(self.num_stacks): 175 | y = self.hg[i](x) 176 | y = self.res[i](y) 177 | y = self.fc[i](y) 178 | score = self.score[i](y) 179 | out.append(score) 180 | 181 | if i < self.num_stacks - 1: 182 | fc_ = self.fc_[i](y) 183 | score_ = self.score_[i](score) 184 | x = x + fc_ + score_ 185 | 186 | return out[::-1], y 187 | 188 | def train(self, mode=True): 189 | # Override train so that the training mode is set as we want 190 | nn.Module.train(self, mode) 191 | if mode: 192 | # fix all bn layers 193 | def set_bn_eval(m): 194 | classname = m.__class__.__name__ 195 | if classname.find('BatchNorm') != -1: 196 | m.eval() 197 | 198 | self.apply(set_bn_eval) 199 | 200 | 201 | class MultitaskHead(nn.Module): 202 | def __init__(self, input_channels, num_class, head_size): 203 | super(MultitaskHead, self).__init__() 204 | 205 | m = int(input_channels / 4) 206 | heads = [] 207 | for output_channels in sum(head_size, []): 208 | heads.append( 209 | nn.Sequential( 210 | nn.Conv2d(input_channels, m, kernel_size=3, padding=1), 211 | nn.ReLU(inplace=True), 212 | nn.Conv2d(m, output_channels, kernel_size=1), 213 | ) 214 | ) 215 | self.heads = nn.ModuleList(heads) 216 | assert num_class == sum(sum(head_size, [])) 217 | 218 | def forward(self, x): 219 | return torch.cat([head(x) for head in self.heads], dim=1) 220 | 221 | 222 | def build_hg(): 223 | inplanes = 64 224 | num_feats = 256 //2 225 | depth = 4 226 | num_stacks = 2 227 | num_blocks = 1 228 | head_size = [[2], [2]] 229 | 230 | out_feature_channels = 256 231 | 232 | num_class = sum(sum(head_size, [])) 233 | model = HourglassNet( 234 | block=Bottleneck2D, 235 | inplanes = inplanes, 236 | num_feats= num_feats, 237 | depth=depth, 238 | head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size), 239 | num_stacks = num_stacks, 240 | num_blocks = num_blocks, 241 | num_classes = num_class) 242 | 243 | model.out_feature_channels = out_feature_channels 244 | 245 | return model 246 | 247 | -------------------------------------------------------------------------------- /code/learn/models/tcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils import weight_norm 4 | 5 | 6 | def get_activation(name): 7 | def hook(model, input, output): 8 | activation[name] = output.detach() 9 | 10 | return hook 11 | 12 | 13 | class Chomp1d(nn.Module): 14 | def __init__(self, chomp_size): 15 | super(Chomp1d, self).__init__() 16 | self.chomp_size = chomp_size 17 | 18 | def forward(self, x): 19 | return x[:, :, : -self.chomp_size].contiguous() 20 | 21 | 22 | class TemporalBlock(nn.Module): 23 | def __init__( 24 | self, n_inputs, n_outputs, 
kernel_size, stride, dilation, padding, dropout=0.2 25 | ): 26 | super(TemporalBlock, self).__init__() 27 | self.conv1 = weight_norm( 28 | nn.Conv1d( 29 | n_inputs, 30 | n_outputs, 31 | kernel_size, 32 | stride=stride, 33 | padding=padding, 34 | dilation=dilation, 35 | ) 36 | ) 37 | self.chomp1 = Chomp1d(padding) 38 | self.relu1 = nn.ReLU() 39 | self.dropout1 = nn.Dropout(dropout) 40 | 41 | self.conv2 = weight_norm( 42 | nn.Conv1d( 43 | n_outputs, 44 | n_outputs, 45 | kernel_size, 46 | stride=stride, 47 | padding=padding, 48 | dilation=dilation, 49 | ) 50 | ) 51 | self.chomp2 = Chomp1d(padding) 52 | self.relu2 = nn.ReLU() 53 | self.dropout2 = nn.Dropout(dropout) 54 | 55 | self.net = nn.Sequential( 56 | self.conv1, 57 | self.chomp1, 58 | self.relu1, 59 | self.dropout1, 60 | self.conv2, 61 | self.chomp2, 62 | self.relu2, 63 | self.dropout2, 64 | ) 65 | self.downsample = ( 66 | nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None 67 | ) 68 | self.relu = nn.ReLU() 69 | self.init_weights() 70 | 71 | def init_weights(self): 72 | self.conv1.weight.data.normal_(0, 0.01) 73 | self.conv2.weight.data.normal_(0, 0.01) 74 | if self.downsample is not None: 75 | self.downsample.weight.data.normal_(0, 0.01) 76 | 77 | def forward(self, x): 78 | out = self.net(x) 79 | res = x if self.downsample is None else self.downsample(x) 80 | return self.relu(out + res) 81 | 82 | 83 | class TemporalConvNet(nn.Module): 84 | def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): 85 | super(TemporalConvNet, self).__init__() 86 | layers = [] 87 | num_levels = len(num_channels) 88 | for i in range(num_levels): 89 | dilation_size = 2**i 90 | in_channels = num_inputs if i == 0 else num_channels[i - 1] 91 | out_channels = num_channels[i] 92 | layers += [ 93 | TemporalBlock( 94 | in_channels, 95 | out_channels, 96 | kernel_size, 97 | stride=1, 98 | dilation=dilation_size, 99 | padding=(kernel_size - 1) * dilation_size, 100 | dropout=dropout, 101 | ) 102 | ] 103 | 104 | self.network = nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | hidden = {} 108 | 109 | x_ = x 110 | for layer_i, layer in enumerate(self.network.children()): 111 | x_ = layer(x_) 112 | hidden[layer_i] = x_ 113 | 114 | return x_, hidden 115 | 116 | 117 | class TCN(nn.Module): 118 | def __init__(self, input_size, output_size, num_channels, kernel_size, dropout): 119 | super(TCN, self).__init__() 120 | self.tcn = TemporalConvNet( 121 | input_size, num_channels, kernel_size=kernel_size, dropout=dropout 122 | ) 123 | self.linear = nn.Linear(num_channels[-1], output_size) 124 | self.init_weights() 125 | 126 | def init_weights(self): 127 | self.linear.weight.data.normal_(0, 0.01) 128 | 129 | def forward(self, x): 130 | y1, _ = self.tcn(x) 131 | return self.linear(y1.transpose(1, 2)) 132 | 133 | def get_hidden(self, x): 134 | # self.tcn.network[layer_idx].register_forward_hook(get_activation("Dropout")) 135 | _, hidden = self.tcn(x) 136 | return hidden 137 | -------------------------------------------------------------------------------- /code/learn/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | import torch.nn.functional as F 5 | from torchvision.models._utils import IntermediateLayerGetter 6 | from utils.misc import NestedTensor 7 | 8 | 9 | def convrelu(in_channels, out_channels, kernel, padding): 10 | return nn.Sequential( 11 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding), 
12 | nn.ReLU(inplace=True), 13 | ) 14 | 15 | 16 | class ResNetBackbone(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.base_model = models.resnet50(pretrained=True) 20 | self.base_layers = list(self.base_model.children()) 21 | 22 | self.conv_original_size0 = convrelu(3, 64, 3, 1) 23 | self.conv_original_size1 = convrelu(64, 64, 3, 1) 24 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) 25 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) 26 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) 27 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) 28 | self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) 29 | 30 | self.strides = [8, 16, 32] 31 | self.num_channels = [512, 1024, 2048] 32 | 33 | def forward(self, inputs): 34 | x_original = self.conv_original_size0(inputs) 35 | x_original = self.conv_original_size1(x_original) 36 | layer0 = self.layer0(inputs) 37 | layer1 = self.layer1(layer0) 38 | layer2 = self.layer2(layer1) 39 | layer3 = self.layer3(layer2) 40 | layer4 = self.layer4(layer3) 41 | 42 | xs = {"0": layer2, "1": layer3, "2": layer4} 43 | all_feats = {'layer0': layer0, 'layer1': layer1, 'layer2': layer2, 44 | 'layer3': layer3, 'layer4': layer4, 'x_original': x_original} 45 | 46 | mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device) 47 | return xs, mask, all_feats 48 | 49 | def train(self, mode=True): 50 | # Override train so that the training mode is set as we want 51 | nn.Module.train(self, mode) 52 | if mode: 53 | # fix all bn layers 54 | def set_bn_eval(m): 55 | classname = m.__class__.__name__ 56 | if classname.find('BatchNorm') != -1: 57 | m.eval() 58 | 59 | self.apply(set_bn_eval) 60 | 61 | 62 | class ResNetUNet(nn.Module): 63 | def __init__(self, n_class, out_dim=None, ms_feat=False): 64 | super().__init__() 65 | 66 | self.return_ms_feat = ms_feat 67 | self.out_dim = out_dim 68 | 69 | self.base_model = models.resnet50(pretrained=True) 70 | self.base_layers = list(self.base_model.children()) 71 | 72 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) 73 | # self.layer0_1x1 = convrelu(64, 64, 1, 0) 74 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) 75 | # self.layer1_1x1 = convrelu(256, 256, 1, 0) 76 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) 77 | # self.layer2_1x1 = convrelu(512, 512, 1, 0) 78 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) 79 | # self.layer3_1x1 = convrelu(1024, 1024, 1, 0) 80 | self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) 81 | # self.layer4_1x1 = convrelu(2048, 2048, 1, 0) 82 | 83 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 84 | 85 | self.conv_up3 = convrelu(1024 + 2048, 1024, 3, 1) 86 | self.conv_up2 = convrelu(512 + 1024, 512, 3, 1) 87 | self.conv_up1 = convrelu(256 + 512, 256, 3, 1) 88 | self.conv_up0 = convrelu(64 + 256, 128, 3, 1) 89 | # self.conv_up1 = convrelu(512, 256, 3, 1) 90 | # self.conv_up0 = convrelu(256, 128, 3, 1) 91 | 92 | self.conv_original_size0 = convrelu(3, 64, 3, 1) 93 | self.conv_original_size1 = convrelu(64, 64, 3, 1) 94 | self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1) 95 | # self.conv_last = nn.Conv2d(128, n_class, 1) 96 | self.conv_last = nn.Conv2d(64, n_class, 1) 97 | if out_dim: 98 | self.conv_out = nn.Conv2d(64, out_dim, 1) 99 | # self.conv_out = nn.Conv2d(128, out_dim, 1) 100 | 101 | # return_layers = {"layer2": 
"0", "layer3": "1", "layer4": "2"} 102 | self.strides = [8, 16, 32] 103 | self.num_channels = [512, 1024, 2048] 104 | 105 | def forward(self, inputs): 106 | x_original = self.conv_original_size0(inputs) 107 | x_original = self.conv_original_size1(x_original) 108 | 109 | layer0 = self.layer0(inputs) 110 | layer1 = self.layer1(layer0) 111 | layer2 = self.layer2(layer1) 112 | layer3 = self.layer3(layer2) 113 | layer4 = self.layer4(layer3) 114 | 115 | # layer4 = self.layer4_1x1(layer4) 116 | x = self.upsample(layer4) 117 | # layer3 = self.layer3_1x1(layer3) 118 | x = torch.cat([x, layer3], dim=1) 119 | x = self.conv_up3(x) 120 | layer3_up = x 121 | 122 | x = self.upsample(x) 123 | # layer2 = self.layer2_1x1(layer2) 124 | x = torch.cat([x, layer2], dim=1) 125 | x = self.conv_up2(x) 126 | layer2_up = x 127 | 128 | x = self.upsample(x) 129 | # layer1 = self.layer1_1x1(layer1) 130 | x = torch.cat([x, layer1], dim=1) 131 | x = self.conv_up1(x) 132 | 133 | x = self.upsample(x) 134 | # layer0 = self.layer0_1x1(layer0) 135 | x = torch.cat([x, layer0], dim=1) 136 | x = self.conv_up0(x) 137 | 138 | x = self.upsample(x) 139 | x = torch.cat([x, x_original], dim=1) 140 | x = self.conv_original_size2(x) 141 | 142 | out = self.conv_last(x) 143 | out = out.sigmoid().squeeze(1) 144 | 145 | # xs = {"0": layer2, "1": layer3, "2": layer4} 146 | xs = {"0": layer2_up, "1": layer3_up, "2": layer4} 147 | mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device) 148 | # ms_feats = self.ms_feat(xs, mask) 149 | 150 | if self.return_ms_feat: 151 | if self.out_dim: 152 | out_feat = self.conv_out(x) 153 | out_feat = out_feat.permute(0, 2, 3, 1) 154 | return xs, mask, out, out_feat 155 | else: 156 | return xs, mask, out 157 | else: 158 | return out 159 | 160 | def train(self, mode=True): 161 | # Override train so that the training mode is set as we want 162 | nn.Module.train(self, mode) 163 | if mode: 164 | # fix all bn layers 165 | def set_bn_eval(m): 166 | classname = m.__class__.__name__ 167 | if classname.find('BatchNorm') != -1: 168 | m.eval() 169 | 170 | self.apply(set_bn_eval) 171 | -------------------------------------------------------------------------------- /code/learn/models/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def pos_encode_1d(x, d_model=128): 6 | pe = torch.zeros(len(x), d_model, device=x.device) 7 | 8 | div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)) 9 | div_term = div_term.to(x.device) 10 | 11 | pos_w = x.clone().float().unsqueeze(1) 12 | 13 | pe[:, 0:d_model:2] = torch.sin(pos_w * div_term) 14 | pe[:, 1:d_model:2] = torch.cos(pos_w * div_term) 15 | 16 | return pe 17 | 18 | 19 | def pos_encode_2d(x, y, d_model=128): 20 | assert len(x) == len(y) 21 | pe = torch.zeros(len(x), d_model, device=x.device) 22 | 23 | d_model = int(d_model / 2) 24 | div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)) 25 | div_term = div_term.to(x.device) 26 | 27 | pos_w = x.clone().float().unsqueeze(1) 28 | pos_h = y.clone().float().unsqueeze(1) 29 | 30 | pe[:, 0:d_model:2] = torch.sin(pos_w * div_term) 31 | pe[:, 1:d_model:2] = torch.cos(pos_w * div_term) 32 | pe[:, d_model::2] = torch.sin(pos_h * div_term) 33 | pe[:, d_model + 1 :: 2] = torch.cos(pos_h * div_term) 34 | 35 | return pe 36 | 37 | 38 | def get_geom_feats(coords): 39 | (bs, N, _) = coords.shape 40 | _coords = coords.reshape(bs * N, -1) 41 | _geom_enc_a = pos_encode_2d(x=_coords[:, 0], y=_coords[:, 
1]) 42 | _geom_enc_b = pos_encode_2d(x=_coords[:, 2], y=_coords[:, 3]) 43 | _geom_feats = torch.cat([_geom_enc_a, _geom_enc_b], dim=-1) 44 | geom_feats = _geom_feats.reshape(bs, N, -1) 45 | 46 | return geom_feats 47 | -------------------------------------------------------------------------------- /code/learn/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer: 4 | 5 | def __init__(self, name): 6 | self.name = name 7 | self._start = None 8 | self._end = None 9 | 10 | 11 | def __enter__(self): 12 | self._start = time.time() 13 | 14 | 15 | def __exit__(self, *exc_info): 16 | self._end = time.time() 17 | 18 | print('%s: %f' % (self.name, self._end - self._start)) -------------------------------------------------------------------------------- /code/learn/train_order_class.py: -------------------------------------------------------------------------------- 1 | from turtle import pos 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | import time 6 | import numpy as np 7 | import datetime 8 | import argparse 9 | from tqdm import tqdm 10 | from pathlib import Path 11 | from torch.utils.data import DataLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | from datasets.building_order_enum import ( 14 | BuildingCornerDataset, 15 | collate_fn_seq, 16 | get_pixel_features, 17 | ) 18 | from models.corner_models import CornerEnum 19 | from models.order_class_models import EdgeTransformer, EdgeTransformer2 20 | from models.loss import CornerCriterion, EdgeCriterion 21 | from models.corner_to_edge import prepare_edge_data 22 | import utils.misc as utils 23 | from utils.nn_utils import pos_encode_2d 24 | import torch.nn.functional as F 25 | 26 | # from infer_full import FloorHEAT 27 | 28 | # for debugging NaNs 29 | # torch.autograd.set_detect_anomaly(True) 30 | 31 | 32 | def get_args_parser(): 33 | parser = argparse.ArgumentParser("Set transformer detector", add_help=False) 34 | parser.add_argument("--lr", default=2e-4, type=float) 35 | parser.add_argument("--batch_size", default=128, type=int) 36 | parser.add_argument("--weight_decay", default=1e-5, type=float) 37 | parser.add_argument("--epochs", default=80, type=int) 38 | parser.add_argument("--lr_drop", default=50, type=int) 39 | parser.add_argument( 40 | "--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm" 41 | ) 42 | 43 | parser.add_argument("--resume", action="store_true") 44 | parser.add_argument( 45 | "--output_dir", 46 | default="../../ckpts/order_class", 47 | help="path where to save, empty for no saving", 48 | ) 49 | 50 | # my own 51 | parser.add_argument("--test_idx", type=int, default=0) 52 | parser.add_argument("--grad_accum", type=int, default=1) 53 | parser.add_argument("--rand_aug", action="store_true", default=True) 54 | parser.add_argument("--last_first", type=bool, default=False) 55 | parser.add_argument("--lock_pos", type=str, default="none") 56 | parser.add_argument("--normalize_by_seq", action="store_true") 57 | 58 | return parser 59 | 60 | 61 | def train_one_epoch( 62 | image_size, 63 | edge_model, 64 | edge_criterion, 65 | data_loader, 66 | optimizer, 67 | writer, 68 | epoch, 69 | max_norm, 70 | args, 71 | ): 72 | # backbone.train() 73 | edge_model.train() 74 | edge_criterion.train() 75 | optimizer.zero_grad() 76 | 77 | acc_avg = 0 78 | loss_avg = 0 79 | 80 | pbar = tqdm(data_loader) 81 | for batch_i, data in enumerate(pbar): 82 | logits, loss, acc = run_model(data, edge_model, epoch, edge_criterion) 83 | 84 
| loss = loss / args.grad_accum 85 | loss.backward() 86 | 87 | num_iter = epoch * len(data_loader) + batch_i 88 | writer.add_scalar("train/loss_mb", loss, num_iter) 89 | writer.add_scalar("train/acc_mb", acc, num_iter) 90 | pbar.set_description("Train Loss: %.3f Acc: %.3f" % (loss, acc)) 91 | 92 | acc_avg += acc.item() 93 | loss_avg += loss.item() 94 | 95 | if ((batch_i + 1) % args.grad_accum == 0) or ( 96 | (batch_i + 1) == len(data_loader) 97 | ): 98 | if max_norm > 0: 99 | # torch.nn.utils.clip_grad_norm_(backbone.parameters(), max_norm) 100 | torch.nn.utils.clip_grad_norm_(edge_model.parameters(), max_norm) 101 | 102 | optimizer.step() 103 | optimizer.zero_grad() 104 | 105 | acc_avg /= len(data_loader) 106 | loss_avg /= len(data_loader) 107 | writer.add_scalar("train/acc", acc_avg, epoch) 108 | writer.add_scalar("train/loss", loss_avg, epoch) 109 | 110 | print("Train loss: %.3f acc: %.3f" % (loss_avg, acc_avg)) 111 | 112 | return -1 113 | 114 | 115 | @torch.no_grad() 116 | def evaluate(image_size, edge_model, edge_criterion, data_loader, writer, epoch, args): 117 | # backbone.train() 118 | edge_model.eval() 119 | edge_criterion.eval() 120 | 121 | loss_total = 0 122 | acc_total = 0 123 | 124 | pbar = tqdm(data_loader) 125 | for batch_i, data in enumerate(pbar): 126 | logits, loss, acc = run_model(data, edge_model, epoch, edge_criterion) 127 | pbar.set_description("Eval Loss: %.3f" % loss) 128 | 129 | loss_total += loss.item() 130 | acc_total += acc.item() 131 | 132 | loss_avg = loss_total / len(data_loader) 133 | acc_avg = acc_total / len(data_loader) 134 | 135 | print("Val loss: %.3f acc: %.3f\n" % (loss_avg, acc_avg)) 136 | writer.add_scalar("eval/loss", loss_avg, epoch) 137 | writer.add_scalar("eval/acc", acc_avg, epoch) 138 | 139 | return loss_avg, acc_avg 140 | 141 | 142 | def run_model(data, edge_model, epoch, edge_criterion): 143 | for key in data.keys(): 144 | if type(data[key]) is torch.Tensor: 145 | data[key] = data[key].cuda() 146 | 147 | # run the edge model 148 | assert (data["label"] >= 0).all() and (data["label"] <= 1).all() 149 | 150 | logits = edge_model(data) 151 | loss = edge_criterion(logits, data["label"]) 152 | acc = (logits.argmax(-1) == data["label"]).float().mean() 153 | 154 | return logits, loss, acc 155 | 156 | 157 | def main(args): 158 | DATAPATH = "../../data" 159 | REVIT_ROOT = "" 160 | image_size = 512 161 | 162 | # prepare datasets 163 | train_dataset = BuildingCornerDataset( 164 | DATAPATH, 165 | REVIT_ROOT, 166 | phase="train", 167 | image_size=image_size, 168 | rand_aug=args.rand_aug, 169 | test_idx=args.test_idx, 170 | loss_type="class", 171 | ) 172 | train_dataloader = DataLoader( 173 | train_dataset, 174 | batch_size=args.batch_size, 175 | shuffle=True, 176 | num_workers=8, 177 | collate_fn=collate_fn_seq, 178 | ) 179 | 180 | test_dataset = BuildingCornerDataset( 181 | DATAPATH, 182 | REVIT_ROOT, 183 | phase="valid", 184 | image_size=image_size, 185 | rand_aug=False, 186 | test_idx=args.test_idx, 187 | loss_type="class", 188 | ) 189 | test_dataloader = DataLoader( 190 | test_dataset, 191 | batch_size=args.batch_size, 192 | shuffle=False, 193 | num_workers=8, 194 | collate_fn=collate_fn_seq, 195 | ) 196 | 197 | edge_model = EdgeTransformer(d_model=256) 198 | edge_model = edge_model.cuda() 199 | 200 | edge_criterion = nn.CrossEntropyLoss() 201 | 202 | edge_params = [p for p in edge_model.parameters()] 203 | 204 | all_params = edge_params # + backbone_params 205 | optimizer = torch.optim.AdamW( 206 | all_params, lr=args.lr, 
weight_decay=args.weight_decay 207 | ) 208 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 209 | start_epoch = 0 210 | 211 | if args.resume: 212 | ckpt = torch.load(args.resume) 213 | edge_model.load_state_dict(ckpt["edge_model"]) 214 | optimizer.load_state_dict(ckpt["optimizer"]) 215 | lr_scheduler.load_state_dict(ckpt["lr_scheduler"]) 216 | lr_scheduler.step_size = args.lr_drop 217 | 218 | print( 219 | "Resume from ckpt file {}, starting from epoch {}".format( 220 | args.resume, ckpt["epoch"] 221 | ) 222 | ) 223 | start_epoch = ckpt["epoch"] + 1 224 | 225 | else: 226 | ckpt = torch.load( 227 | "../../ckpts/pretrained/%d/checkpoint_best.pth" % args.test_idx 228 | ) 229 | 230 | replacements = [ 231 | ["decoder_1", "module.transformer.relational_decoder"], 232 | ["decoder_2", "module.transformer.relational_decoder"], 233 | ] 234 | 235 | edge_model_dict = edge_model.state_dict() 236 | for key in edge_model_dict.keys(): 237 | replaced = False 238 | for old, new in replacements: 239 | if old in key: 240 | assert not replaced 241 | new_key = key.replace(old, new) 242 | edge_model_dict[key] = ckpt["edge_model"][new_key] 243 | replaced = True 244 | print(key) 245 | 246 | edge_model.load_state_dict(edge_model_dict) 247 | print("Resume from pre-trained checkpoints") 248 | 249 | n_edge_parameters = sum(p.numel() for p in edge_params if p.requires_grad) 250 | n_all_parameters = sum(p.numel() for p in all_params if p.requires_grad) 251 | print("number of trainable edge params:", n_edge_parameters) 252 | print("number of all trainable params:", n_all_parameters) 253 | 254 | print("Start training") 255 | start_time = time.time() 256 | 257 | output_dir = Path("%s/%d" % (args.output_dir, args.test_idx)) 258 | if not os.path.exists(output_dir): 259 | os.makedirs(output_dir) 260 | 261 | # prepare summary writer 262 | writer = SummaryWriter(log_dir=output_dir) 263 | 264 | best_acc = 0 265 | for epoch in range(start_epoch, args.epochs): 266 | print("Epoch: %d" % epoch) 267 | train_one_epoch( 268 | image_size, 269 | edge_model, 270 | edge_criterion, 271 | train_dataloader, 272 | optimizer, 273 | writer, 274 | epoch, 275 | args.clip_max_norm, 276 | args, 277 | ) 278 | lr_scheduler.step() 279 | 280 | val_loss, val_acc = evaluate( 281 | image_size, 282 | edge_model, 283 | edge_criterion, 284 | test_dataloader, 285 | writer, 286 | epoch, 287 | args, 288 | ) 289 | 290 | if val_acc > best_acc: 291 | is_best = True 292 | best_acc = val_acc 293 | else: 294 | is_best = False 295 | 296 | if args.output_dir: 297 | checkpoint_paths = [output_dir / ("checkpoint_latest.pth")] 298 | checkpoint_paths.append(output_dir / ("checkpoint_%03d.pth" % epoch)) 299 | if is_best: 300 | checkpoint_paths.append(output_dir / "checkpoint_best.pth") 301 | 302 | for checkpoint_path in checkpoint_paths: 303 | torch.save( 304 | { 305 | "edge_model": edge_model.state_dict(), 306 | "optimizer": optimizer.state_dict(), 307 | "lr_scheduler": lr_scheduler.state_dict(), 308 | "epoch": epoch, 309 | "args": args, 310 | "val_acc": val_acc, 311 | }, 312 | checkpoint_path, 313 | ) 314 | 315 | total_time = time.time() - start_time 316 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 317 | print("Training time {}".format(total_time_str)) 318 | 319 | 320 | if __name__ == "__main__": 321 | parser = argparse.ArgumentParser( 322 | "GeoVAE training and evaluation script", parents=[get_args_parser()] 323 | ) 324 | args = parser.parse_args() 325 | main(args) 326 | 
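
The snippet below is a minimal sketch (not a file in the repository) of how the checkpoint written by this script can be reloaded for inference. It assumes the default `--output_dir` and `--test_idx 0` used above and a batch size of 16; the dataset and model arguments mirror the defaults in `main()`.

```python
import torch
from torch.utils.data import DataLoader

from datasets.building_order_enum import BuildingCornerDataset, collate_fn_seq
from models.order_class_models import EdgeTransformer

# Best checkpoint saved by train_order_class.py (assumed default output_dir, test_idx=0).
ckpt = torch.load("../../ckpts/order_class/0/checkpoint_best.pth")

edge_model = EdgeTransformer(d_model=256).cuda()
edge_model.load_state_dict(ckpt["edge_model"])
edge_model.eval()

# Same dataset arguments as the validation split in main().
dataset = BuildingCornerDataset(
    "../../data", "", phase="valid", image_size=512,
    rand_aug=False, test_idx=0, loss_type="class",
)
loader = DataLoader(dataset, batch_size=16, shuffle=False, collate_fn=collate_fn_seq)

with torch.no_grad():
    data = next(iter(loader))
    for key in data:
        if torch.is_tensor(data[key]):
            data[key] = data[key].cuda()
    probs = edge_model(data).softmax(-1)  # (batch, 2) class probabilities
```

During training, `run_model()` feeds the same two-way logits directly into `nn.CrossEntropyLoss` against the binary `label` field produced by the dataset, so the softmax above gives the per-sequence probability of each of those two classes.
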
-------------------------------------------------------------------------------- /code/learn/train_order_metric.py: -------------------------------------------------------------------------------- 1 | from turtle import pos 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | import time 6 | import numpy as np 7 | import datetime 8 | import argparse 9 | from tqdm import tqdm 10 | from pathlib import Path 11 | from torch.utils.data import DataLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | from datasets.building_order_enum import ( 14 | BuildingCornerDataset, 15 | collate_fn_seq, 16 | get_pixel_features, 17 | ) 18 | from models.corner_models import CornerEnum 19 | from models.order_metric_models import EdgeTransformer, EdgeTransformer2 20 | from models.corner_to_edge import prepare_edge_data 21 | import utils.misc as utils 22 | from utils.nn_utils import pos_encode_2d 23 | import torch.nn.functional as F 24 | 25 | # from infer_full import FloorHEAT 26 | 27 | # for debugging NaNs 28 | # torch.autograd.set_detect_anomaly(True) 29 | 30 | 31 | def get_args_parser(): 32 | parser = argparse.ArgumentParser("Set transformer detector", add_help=False) 33 | parser.add_argument("--lr", default=2e-4, type=float) 34 | parser.add_argument("--batch_size", default=128, type=int) 35 | parser.add_argument("--weight_decay", default=1e-5, type=float) 36 | parser.add_argument("--epochs", default=100, type=int) 37 | parser.add_argument("--lr_drop", default=50, type=int) 38 | parser.add_argument( 39 | "--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm" 40 | ) 41 | 42 | parser.add_argument("--resume", action="store_true") 43 | parser.add_argument( 44 | "--output_dir", 45 | default="../../ckpts/order_metric", 46 | help="path where to save, empty for no saving", 47 | ) 48 | 49 | # my own 50 | parser.add_argument("--test_idx", type=int, default=0) 51 | parser.add_argument("--grad_accum", type=int, default=1) 52 | parser.add_argument("--rand_aug", action="store_true", default=True) 53 | parser.add_argument("--last_first", type=bool, default=False) 54 | parser.add_argument("--lock_pos", type=str, default="none") 55 | parser.add_argument("--normalize_by_seq", action="store_true") 56 | 57 | return parser 58 | 59 | 60 | def train_one_epoch( 61 | image_size, 62 | edge_model, 63 | loss_fn, 64 | miner, 65 | dist_fn, 66 | data_loader, 67 | optimizer, 68 | writer, 69 | epoch, 70 | max_norm, 71 | args, 72 | ): 73 | # backbone.train() 74 | edge_model.train() 75 | loss_fn.train() 76 | optimizer.zero_grad() 77 | 78 | acc_avg = 0 79 | loss_avg = 0 80 | 81 | pbar = tqdm(data_loader) 82 | for batch_i, data in enumerate(pbar): 83 | logits, loss, acc = run_model(data, edge_model, epoch, loss_fn, miner, dist_fn) 84 | 85 | loss = loss / args.grad_accum 86 | loss.backward() 87 | 88 | num_iter = epoch * len(data_loader) + batch_i 89 | writer.add_scalar("train/loss_mb", loss, num_iter) 90 | writer.add_scalar("train/acc_mb", acc, num_iter) 91 | pbar.set_description("Train Loss: %.3f Acc: %.3f" % (loss, acc)) 92 | 93 | acc_avg += acc.item() 94 | loss_avg += loss.item() 95 | 96 | if ((batch_i + 1) % args.grad_accum == 0) or ( 97 | (batch_i + 1) == len(data_loader) 98 | ): 99 | if max_norm > 0: 100 | # torch.nn.utils.clip_grad_norm_(backbone.parameters(), max_norm) 101 | torch.nn.utils.clip_grad_norm_(edge_model.parameters(), max_norm) 102 | 103 | optimizer.step() 104 | optimizer.zero_grad() 105 | 106 | acc_avg /= len(data_loader) 107 | loss_avg /= len(data_loader) 108 | 
writer.add_scalar("train/acc", acc_avg, epoch) 109 | writer.add_scalar("train/loss", acc_avg, epoch) 110 | 111 | print("Train loss: %.3f acc: %.3f" % (loss_avg, acc_avg)) 112 | 113 | return -1 114 | 115 | 116 | @torch.no_grad() 117 | def evaluate( 118 | image_size, 119 | edge_model, 120 | loss_fn, 121 | miner, 122 | dist_fn, 123 | data_loader, 124 | writer, 125 | epoch, 126 | args, 127 | ): 128 | # backbone.train() 129 | edge_model.eval() 130 | loss_fn.eval() 131 | 132 | loss_total = 0 133 | acc_total = 0 134 | 135 | pbar = tqdm(data_loader) 136 | for batch_i, data in enumerate(pbar): 137 | logits, loss, acc = run_model(data, edge_model, epoch, loss_fn, miner, dist_fn) 138 | pbar.set_description("Eval Loss: %.3f" % loss) 139 | 140 | loss_total += loss.item() 141 | acc_total += acc.item() 142 | 143 | loss_avg = loss_total / len(data_loader) 144 | acc_avg = acc_total / len(data_loader) 145 | 146 | print("Val loss: %.3f acc: %.3f\n" % (loss_avg, acc_avg)) 147 | writer.add_scalar("eval/loss", loss_avg, epoch) 148 | writer.add_scalar("eval/acc", acc_avg, epoch) 149 | 150 | return loss_avg, acc_avg 151 | 152 | 153 | def run_model(data, edge_model, epoch, loss_fn, miner, dist_fn): 154 | for key in data.keys(): 155 | if type(data[key]) is torch.Tensor: 156 | data[key] = data[key].cuda() 157 | 158 | # run the edge model 159 | embeddings = edge_model(data) 160 | 161 | # ignore the provided edges 162 | all_neg = [] 163 | all_anc = [] 164 | all_pos = [] 165 | 166 | total_loss = 0 167 | total_acc = 0 168 | 169 | for ex_i in range(len(embeddings)): 170 | mask = ~data["edge_coords_mask"][ex_i] 171 | 172 | _embeddings = embeddings[ex_i][mask] 173 | _label = data["edge_label"][ex_i][mask] 174 | 175 | neg = _embeddings[_label == 0] 176 | anc = _embeddings[_label == 1].repeat(len(neg), 1) 177 | pos = _embeddings[_label == 2].repeat(len(neg), 1) 178 | 179 | all_neg.append(neg) 180 | all_anc.append(anc) 181 | all_pos.append(pos) 182 | 183 | anc2pos = dist_fn(anc, pos) 184 | anc2neg = dist_fn(anc, neg) 185 | total_acc += (anc2pos < anc2neg).all().float() 186 | 187 | all_neg = torch.cat(all_neg, dim=0) 188 | all_anc = torch.cat(all_anc, dim=0) 189 | all_pos = torch.cat(all_pos, dim=0) 190 | 191 | total_loss = loss_fn(all_anc, all_pos, all_neg) 192 | total_loss /= len(embeddings) 193 | 194 | total_acc = total_acc / len(embeddings) 195 | 196 | return embeddings, total_loss, total_acc 197 | 198 | 199 | def main(args): 200 | DATAPATH = "../../data" 201 | REVIT_ROOT = "" 202 | image_size = 512 203 | 204 | # prepare datasets 205 | train_dataset = BuildingCornerDataset( 206 | DATAPATH, 207 | REVIT_ROOT, 208 | phase="train", 209 | image_size=image_size, 210 | rand_aug=args.rand_aug, 211 | test_idx=args.test_idx, 212 | loss_type="metric", 213 | ) 214 | train_dataloader = DataLoader( 215 | train_dataset, 216 | batch_size=args.batch_size, 217 | shuffle=True, 218 | num_workers=8, 219 | collate_fn=collate_fn_seq, 220 | ) 221 | 222 | # for blah in train_dataset: 223 | # pass 224 | 225 | test_dataset = BuildingCornerDataset( 226 | DATAPATH, 227 | REVIT_ROOT, 228 | phase="valid", 229 | image_size=image_size, 230 | rand_aug=False, 231 | test_idx=args.test_idx, 232 | loss_type="metric", 233 | ) 234 | test_dataloader = DataLoader( 235 | test_dataset, 236 | batch_size=args.batch_size, 237 | shuffle=False, 238 | num_workers=8, 239 | collate_fn=collate_fn_seq, 240 | ) 241 | 242 | edge_model = EdgeTransformer(d_model=256) 243 | edge_model = edge_model.cuda() 244 | 245 | # loss_fn = losses.TripletMarginLoss(distance=dist_fn) 246 | # 
miner = miners.MultiSimilarityMiner() 247 | dist_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y) 248 | loss_fn = nn.TripletMarginWithDistanceLoss( 249 | distance_function=dist_fn, reduction="sum" 250 | ) 251 | miner = None 252 | 253 | edge_params = [p for p in edge_model.parameters()] 254 | 255 | all_params = edge_params # + backbone_params 256 | optimizer = torch.optim.AdamW( 257 | all_params, lr=args.lr, weight_decay=args.weight_decay 258 | ) 259 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 260 | start_epoch = 0 261 | 262 | if args.resume: 263 | ckpt = torch.load(args.resume) 264 | edge_model.load_state_dict(ckpt["edge_model"]) 265 | optimizer.load_state_dict(ckpt["optimizer"]) 266 | lr_scheduler.load_state_dict(ckpt["lr_scheduler"]) 267 | lr_scheduler.step_size = args.lr_drop 268 | 269 | print( 270 | "Resume from ckpt file {}, starting from epoch {}".format( 271 | args.resume, ckpt["epoch"] 272 | ) 273 | ) 274 | start_epoch = ckpt["epoch"] + 1 275 | 276 | else: 277 | ckpt = torch.load( 278 | "../../ckpts/pretrained/%d/checkpoint_best.pth" % args.test_idx 279 | ) 280 | 281 | replacements = [ 282 | ["decoder_1", "module.transformer.relational_decoder"], 283 | ["decoder_2", "module.transformer.relational_decoder"], 284 | ] 285 | 286 | edge_model_dict = edge_model.state_dict() 287 | for key in edge_model_dict.keys(): 288 | replaced = False 289 | for old, new in replacements: 290 | if old in key: 291 | assert not replaced 292 | new_key = key.replace(old, new) 293 | edge_model_dict[key] = ckpt["edge_model"][new_key] 294 | replaced = True 295 | print(key) 296 | 297 | edge_model.load_state_dict(edge_model_dict) 298 | print("Resume from pre-trained checkpoints") 299 | 300 | n_edge_parameters = sum(p.numel() for p in edge_params if p.requires_grad) 301 | n_all_parameters = sum(p.numel() for p in all_params if p.requires_grad) 302 | print("number of trainable edge params:", n_edge_parameters) 303 | print("number of all trainable params:", n_all_parameters) 304 | 305 | print("Start training") 306 | start_time = time.time() 307 | 308 | output_dir = Path("%s/%d" % (args.output_dir, args.test_idx)) 309 | if not os.path.exists(output_dir): 310 | os.makedirs(output_dir) 311 | 312 | # prepare summary writer 313 | writer = SummaryWriter(log_dir=output_dir) 314 | 315 | best_acc = 0 316 | for epoch in range(start_epoch, args.epochs): 317 | print("Epoch: %d" % epoch) 318 | train_one_epoch( 319 | image_size, 320 | edge_model, 321 | loss_fn, 322 | miner, 323 | dist_fn, 324 | train_dataloader, 325 | optimizer, 326 | writer, 327 | epoch, 328 | args.clip_max_norm, 329 | args, 330 | ) 331 | lr_scheduler.step() 332 | 333 | # is_best = False 334 | # val_acc = 0 335 | val_loss, val_acc = evaluate( 336 | image_size, 337 | edge_model, 338 | loss_fn, 339 | miner, 340 | dist_fn, 341 | test_dataloader, 342 | writer, 343 | epoch, 344 | args, 345 | ) 346 | 347 | if val_acc > best_acc: 348 | is_best = True 349 | best_acc = val_acc 350 | else: 351 | is_best = False 352 | 353 | if args.output_dir: 354 | checkpoint_paths = [output_dir / ("checkpoint_latest.pth")] 355 | # checkpoint_paths.append(output_dir / ("checkpoint_%03d.pth" % epoch)) 356 | if is_best: 357 | checkpoint_paths.append(output_dir / "checkpoint_best.pth") 358 | 359 | for checkpoint_path in checkpoint_paths: 360 | torch.save( 361 | { 362 | "edge_model": edge_model.state_dict(), 363 | "optimizer": optimizer.state_dict(), 364 | "lr_scheduler": lr_scheduler.state_dict(), 365 | "epoch": epoch, 366 | "args": args, 367 | "val_acc": 
val_acc, 368 | }, 369 | checkpoint_path, 370 | ) 371 | 372 | total_time = time.time() - start_time 373 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 374 | print("Training time {}".format(total_time_str)) 375 | 376 | 377 | if __name__ == "__main__": 378 | parser = argparse.ArgumentParser( 379 | "GeoVAE training and evaluation script", parents=[get_args_parser()] 380 | ) 381 | args = parser.parse_args() 382 | main(args) 383 | -------------------------------------------------------------------------------- /code/learn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/code/learn/utils/__init__.py -------------------------------------------------------------------------------- /code/learn/utils/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | 5 | def add_path(path): 6 | if path not in sys.path: sys.path.insert(0, path) 7 | 8 | 9 | this_dir = osp.dirname(__file__) 10 | 11 | project_path = osp.abspath(osp.join(this_dir, '..')) 12 | add_path(project_path) -------------------------------------------------------------------------------- /code/learn/utils/geometry_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def building_metric(logits, label): 7 | preds = torch.argmax(logits, dim=-1) 8 | true_ids = torch.where(label==1) 9 | num_true = true_ids[0].shape[0] 10 | tp = (preds[true_ids] == 1).sum().double() 11 | recall = tp / num_true 12 | prec = tp / (preds == 1).sum() 13 | fscore = 2 * recall * prec / (prec + recall) 14 | return recall, prec, fscore 15 | 16 | 17 | def edge_acc(logits, label, lengths, gt_values): 18 | all_acc = list() 19 | for i in range(logits.shape[0]): 20 | length = lengths[i] 21 | gt_value = gt_values[i, :length] 22 | pred_idx = torch.where(gt_value == 2) 23 | if len(pred_idx[0]) == 0: 24 | continue 25 | else: 26 | preds = torch.argmax(logits[i, :, :length][:, pred_idx[0]], dim=0) 27 | gts = label[i, :length][pred_idx[0]] 28 | pos_ids = torch.where(gts == 1) 29 | correct = (preds[pos_ids] == gts[pos_ids]).sum().float() 30 | num_pos_gt = len(pos_ids[0]) 31 | recall = correct / num_pos_gt if num_pos_gt > 0 else torch.tensor(0) 32 | num_pos_pred = (preds == 1).sum().float() 33 | prec = correct / num_pos_pred if num_pos_pred > 0 else torch.tensor(0) 34 | f_score = 2.0 * prec * recall / (recall + prec + 1e-8) 35 | f_score = f_score.cpu() 36 | all_acc.append(f_score) 37 | if len(all_acc) > 1: 38 | all_acc = torch.stack(all_acc, 0) 39 | avg_acc = all_acc.mean() 40 | else: 41 | avg_acc = all_acc[0] 42 | return avg_acc 43 | 44 | 45 | def corner_eval(targets, outputs): 46 | assert isinstance(targets, np.ndarray) 47 | assert isinstance(outputs, np.ndarray) 48 | output_to_gt = dict() 49 | gt_to_output = dict() 50 | for target_i, target in enumerate(targets): 51 | dist = (outputs - target) ** 2 52 | dist = np.sqrt(dist.sum(axis=-1)) 53 | min_dist = dist.min() 54 | min_idx = dist.argmin() 55 | if min_dist < 5 and min_idx not in output_to_gt: # a positive match 56 | output_to_gt[min_idx] = target_i 57 | gt_to_output[target_i] = min_idx 58 | tp = len(output_to_gt) 59 | prec = tp / len(outputs) 60 | recall = tp / len(targets) 61 | return prec, recall 62 | 63 | 64 | def rectify_data(image, annot): 65 | rows, cols, ch = 
image.shape 66 | bins = [0 for _ in range(180)] # 5 degree per bin 67 | # edges vote for directions 68 | 69 | gauss_weights = [0.1, 0.2, 0.5, 1, 0.5, 0.2, 0.1] 70 | 71 | for src, connections in annot.items(): 72 | for end in connections: 73 | edge = [(end[0] - src[0]), -(end[1] - src[1])] 74 | edge_len = np.sqrt(edge[0] ** 2 + edge[1] ** 2) 75 | if edge_len <= 10: # skip too short edges 76 | continue 77 | if edge[0] == 0: 78 | bin_id = 90 79 | else: 80 | theta = np.arctan(edge[1] / edge[0]) / np.pi * 180 81 | if edge[0] * edge[1] < 0: 82 | theta += 180 83 | bin_id = int(theta.round()) 84 | if bin_id == 180: 85 | bin_id = 0 86 | for offset in range(-3, 4): 87 | bin_idx = bin_id + offset 88 | if bin_idx >= 180: 89 | bin_idx -= 180 90 | bins[bin_idx] += np.sqrt(edge[1] ** 2 + edge[0] ** 2) * gauss_weights[offset + 2] 91 | 92 | bins = np.array(bins) 93 | sorted_ids = np.argsort(bins)[::-1] 94 | bin_1 = sorted_ids[0] 95 | remained_ids = [idx for idx in sorted_ids if angle_dist(bin_1, idx) >= 30] 96 | bin_2 = remained_ids[0] 97 | if bin_1 < bin_2: 98 | bin_1, bin_2 = bin_2, bin_1 99 | 100 | dir_1, dir_2 = bin_1, bin_2 101 | # compute the affine parameters, and apply affine transform to the image 102 | origin = [127, 127] 103 | p1_old = [127 + 100 * np.cos(dir_1 / 180 * np.pi), 127 - 100 * np.sin(dir_1 / 180 * np.pi)] 104 | p2_old = [127 + 100 * np.cos(dir_2 / 180 * np.pi), 127 - 100 * np.sin(dir_2 / 180 * np.pi)] 105 | pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32) 106 | p1_new = [127, 27] # y_axis 107 | p2_new = [227, 127] # x_axis 108 | pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32) 109 | 110 | M1 = cv2.getAffineTransform(pts1, pts2) 111 | 112 | all_corners = list(annot.keys()) 113 | all_corners_ = np.array(all_corners) 114 | ones = np.ones([all_corners_.shape[0], 1]) 115 | all_corners_ = np.concatenate([all_corners_, ones], axis=-1) 116 | new_corners = np.matmul(M1, all_corners_.T).T 117 | 118 | M = np.concatenate([M1, np.array([[0, 0, 1]])], axis=0) 119 | 120 | x_max = new_corners[:, 0].max() 121 | x_min = new_corners[:, 0].min() 122 | y_max = new_corners[:, 1].max() 123 | y_min = new_corners[:, 1].min() 124 | 125 | side_x = (x_max - x_min) * 0.1 126 | side_y = (y_max - y_min) * 0.1 127 | right_border = x_max + side_x 128 | left_border = x_min - side_x 129 | bot_border = y_max + side_y 130 | top_border = y_min - side_y 131 | pts1 = np.array([[left_border, top_border], [right_border, top_border], [right_border, bot_border]]).astype( 132 | np.float32) 133 | pts2 = np.array([[5, 5], [250, 5], [250, 250]]).astype(np.float32) 134 | M_scale = cv2.getAffineTransform(pts1, pts2) 135 | 136 | M = np.matmul(np.concatenate([M_scale, np.array([[0, 0, 1]])], axis=0), M) 137 | 138 | new_image = cv2.warpAffine(image, M[:2, :], (cols, rows), borderValue=(255, 255, 255)) 139 | all_corners_ = np.concatenate([all_corners, ones], axis=-1) 140 | new_corners = np.matmul(M[:2, :], all_corners_.T).T 141 | 142 | corner_mapping = dict() 143 | for idx, corner in enumerate(all_corners): 144 | corner_mapping[corner] = new_corners[idx] 145 | 146 | new_annot = dict() 147 | for corner, connections in annot.items(): 148 | new_corner = corner_mapping[corner] 149 | tuple_new_corner = tuple(new_corner) 150 | new_annot[tuple_new_corner] = list() 151 | for to_corner in connections: 152 | new_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)]) 153 | 154 | # do the affine transform 155 | return new_image, new_annot, M 156 | 157 | 158 | def angle_dist(a1, a2): 159 | if a1 > a2: 160 | a1, a2 = 
a2, a1 161 | d1 = a2 - a1 162 | d2 = a1 + 180 - a2 163 | dist = min(d1, d2) 164 | return dist 165 | -------------------------------------------------------------------------------- /code/learn/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from collections import defaultdict, deque 4 | import datetime 5 | from typing import Optional, List 6 | from torch import Tensor 7 | 8 | 9 | @torch.no_grad() 10 | def accuracy(output, target, topk=(1,)): 11 | """Computes the precision@k for the specified values of k""" 12 | if target.numel() == 0: 13 | return [torch.zeros([], device=output.device)] 14 | maxk = max(topk) 15 | batch_size = target.size(0) 16 | 17 | _, pred = output.topk(maxk, 1, True, True) 18 | pred = pred.t() 19 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 20 | 21 | res = [] 22 | for k in topk: 23 | correct_k = correct[:k].view(-1).float().sum(0) 24 | res.append(correct_k.mul_(100.0 / batch_size)) 25 | return res 26 | 27 | 28 | class SmoothedValue(object): 29 | """Track a series of values and provide access to smoothed values over a 30 | window or the global series average. 31 | """ 32 | 33 | def __init__(self, window_size=100, fmt=None): 34 | if fmt is None: 35 | fmt = "{median:.3f} ({global_avg:.3f})" 36 | self.deque = deque(maxlen=window_size) 37 | self.total = 0.0 38 | self.count = 0 39 | self.fmt = fmt 40 | 41 | def update(self, value, n=1): 42 | self.deque.append(value) 43 | self.count += n 44 | self.total += value * n 45 | 46 | @property 47 | def median(self): 48 | d = torch.tensor(list(self.deque)) 49 | return d.median().item() 50 | 51 | @property 52 | def avg(self): 53 | d = torch.tensor(list(self.deque), dtype=torch.float32) 54 | return d.mean().item() 55 | 56 | @property 57 | def global_avg(self): 58 | return self.total / self.count 59 | 60 | @property 61 | def max(self): 62 | return max(self.deque) 63 | 64 | @property 65 | def value(self): 66 | #return self.deque[-1] 67 | return self.avg 68 | 69 | def __str__(self): 70 | return self.fmt.format( 71 | median=self.median, 72 | avg=self.avg, 73 | global_avg=self.global_avg, 74 | max=self.max, 75 | value=self.value) 76 | 77 | 78 | class MetricLogger(object): 79 | def __init__(self, delimiter="\t"): 80 | self.meters = defaultdict(SmoothedValue) 81 | self.delimiter = delimiter 82 | 83 | def update(self, **kwargs): 84 | for k, v in kwargs.items(): 85 | if isinstance(v, torch.Tensor): 86 | v = v.item() 87 | assert isinstance(v, (float, int)) 88 | self.meters[k].update(v) 89 | 90 | def __getattr__(self, attr): 91 | if attr in self.meters: 92 | return self.meters[attr] 93 | if attr in self.__dict__: 94 | return self.__dict__[attr] 95 | raise AttributeError("'{}' object has no attribute '{}'".format( 96 | type(self).__name__, attr)) 97 | 98 | def __str__(self): 99 | loss_str = [] 100 | for name, meter in self.meters.items(): 101 | loss_str.append( 102 | "{}: {}".format(name, str(meter)) 103 | ) 104 | return self.delimiter.join(loss_str) 105 | 106 | def synchronize_between_processes(self): 107 | for meter in self.meters.values(): 108 | meter.synchronize_between_processes() 109 | 110 | def add_meter(self, name, meter): 111 | self.meters[name] = meter 112 | 113 | def log_every(self, iterable, print_freq, header=None, length_total=None): 114 | i = 0 115 | if length_total is None: 116 | length_total = len(iterable) 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = 
SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(length_total))) + 'd' 124 | if torch.cuda.is_available(): 125 | log_msg = self.delimiter.join([ 126 | header, 127 | '[{0' + space_fmt + '}/{1}]', 128 | 'eta: {eta}', 129 | '{meters}', 130 | 'time: {time}', 131 | 'data: {data}', 132 | 'max mem: {memory:.0f}' 133 | ]) 134 | else: 135 | log_msg = self.delimiter.join([ 136 | header, 137 | '[{0' + space_fmt + '}/{1}]', 138 | 'eta: {eta}', 139 | '{meters}', 140 | 'time: {time}', 141 | 'data: {data}' 142 | ]) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == length_total - 1: 149 | eta_seconds = iter_time.global_avg * (length_total - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | try: 153 | print(log_msg.format( 154 | i, length_total, eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time), 157 | memory=torch.cuda.max_memory_allocated() / MB)) 158 | except Exception as e: 159 | import pdb; pdb.set_trace() 160 | else: 161 | print(log_msg.format( 162 | i, length_total, eta=eta_string, 163 | meters=str(self), 164 | time=str(iter_time), data=str(data_time))) 165 | i += 1 166 | end = time.time() 167 | total_time = time.time() - start_time 168 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 169 | print('{} Total time: {} ({:.4f} s / it)'.format( 170 | header, total_time_str, total_time / length_total)) 171 | 172 | 173 | class NestedTensor(object): 174 | def __init__(self, tensors, mask: Optional[Tensor]): 175 | self.tensors = tensors 176 | self.mask = mask 177 | 178 | def to(self, device, non_blocking=False): 179 | # type: (Device) -> NestedTensor # noqa 180 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking) 181 | mask = self.mask 182 | if mask is not None: 183 | assert mask is not None 184 | cast_mask = mask.to(device, non_blocking=non_blocking) 185 | else: 186 | cast_mask = None 187 | return NestedTensor(cast_tensor, cast_mask) 188 | 189 | def record_stream(self, *args, **kwargs): 190 | self.tensors.record_stream(*args, **kwargs) 191 | if self.mask is not None: 192 | self.mask.record_stream(*args, **kwargs) 193 | 194 | def decompose(self): 195 | return self.tensors, self.mask 196 | 197 | def __repr__(self): 198 | return str(self.tensors) 199 | -------------------------------------------------------------------------------- /code/learn/utils/nn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def pos_encode_2d(x, y, d_model=128): 6 | assert len(x) == len(y) 7 | pe = torch.zeros(len(x), d_model, device=x.device) 8 | 9 | d_model = int(d_model / 2) 10 | div_term = torch.exp(torch.arange(0., d_model, 2) * 11 | -(math.log(10000.0) / d_model)) 12 | div_term = div_term.to(x.device) 13 | 14 | pos_w = torch.tensor(x).float().unsqueeze(1) 15 | pos_h = torch.tensor(y).float().unsqueeze(1) 16 | 17 | pe[:, 0:d_model:2] = torch.sin(pos_w * div_term) 18 | pe[:, 1:d_model:2] = torch.cos(pos_w * div_term) 19 | pe[:, d_model::2] = torch.sin(pos_h * div_term) 20 | pe[:, d_model+1::2] = torch.cos(pos_h * div_term) 21 | 22 | return pe 23 | 24 | 25 | def positional_encoding_2d(d_model, height, width): 26 | """ 27 | :param d_model: dimension of the model 28 | :param height: height of the positions 29 | 
:param width: width of the positions 30 | :return: d_model*height*width position matrix 31 | """ 32 | if d_model % 4 != 0: 33 | raise ValueError("Cannot use sin/cos positional encoding with " 34 | "odd dimension (got dim={:d})".format(d_model)) 35 | pe = torch.zeros(d_model, height, width) 36 | # Each dimension use half of d_model 37 | d_model = int(d_model / 2) 38 | div_term = torch.exp(torch.arange(0., d_model, 2) * 39 | -(math.log(10000.0) / d_model)) 40 | pos_w = torch.arange(0., width).unsqueeze(1) 41 | pos_h = torch.arange(0., height).unsqueeze(1) 42 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 43 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 44 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 45 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 46 | 47 | return pe 48 | 49 | 50 | def positional_encoding_1d(d_model, length): 51 | """ 52 | :param d_model: dimension of the model 53 | :param length: length of positions 54 | :return: length*d_model position matrix 55 | """ 56 | if d_model % 2 != 0: 57 | raise ValueError("Cannot use sin/cos positional encoding with " 58 | "odd dim (got dim={:d})".format(d_model)) 59 | pe = torch.zeros(length, d_model) 60 | position = torch.arange(0, length).unsqueeze(1) 61 | div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) * 62 | -(math.log(10000.0) / d_model))) 63 | pe[:, 0::2] = torch.sin(position.float() * div_term) 64 | pe[:, 1::2] = torch.cos(position.float() * div_term) 65 | 66 | return pe 67 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.22 2 | defusedxml==0.6.0 3 | einops==0.4.1 4 | future==0.18.2 5 | imageio==2.16.1 6 | numpy==1.20.1 7 | opencv-python==4.4.0.44 8 | matplotlib==3.3.4 9 | packaging==20.9 10 | Pillow==9.0.1 11 | prometheus-client==0.9.0 12 | prompt-toolkit==3.0.16 13 | ptyprocess==0.7.0 14 | pycparser==2.20 15 | Pygments==2.8.0 16 | python-dateutil==2.8.1 17 | scikit-image==0.19.2 18 | scikit-learn==1.0 19 | scipy==1.6.1 20 | six==1.15.0 21 | cairosvg==2.5.2 22 | svgwrite==1.4.2 23 | shapely==1.8.2 -------------------------------------------------------------------------------- /resources/addin_warning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/resources/addin_warning.png -------------------------------------------------------------------------------- /resources/crop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/resources/crop.gif -------------------------------------------------------------------------------- /resources/rotate.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/resources/rotate.gif -------------------------------------------------------------------------------- /resources/translate.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/weiliansong/A-Scan2BIM/7c01d0495789160095ad61532af8f76797a00c50/resources/translate.gif -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | # for automatic conda environment setup 2 | . "$HOME/miniconda3/etc/profile.d/conda.sh" 3 | conda remove -y -n bim --all 4 | conda create -y -n bim python=3.8 5 | conda activate bim 6 | 7 | conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 8 | 9 | pip install -r requirements.txt 10 | 11 | cd code/learn/models/ops 12 | python setup.py build install 13 | 14 | conda install -y tqdm shapely 15 | pip install tensorboard rtree shapely pytorch-metric-learning laspy[lazrs] open3d typer[all] --------------------------------------------------------------------------------
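For reference, a quick sketch of calling the 2D positional-encoding helper `pos_encode_2d` from `code/learn/utils/nn_utils.py` above. This assumes it is run from `code/learn/` with the `bim` environment active so that `utils` is importable; the coordinate values and `d_model` below are arbitrary illustrations, not values taken from the training scripts:

```python
import torch
from utils.nn_utils import pos_encode_2d

# x/y pixel coordinates of three points (arbitrary example values)
xs = torch.tensor([10.0, 200.0, 355.5])
ys = torch.tensor([42.0, 128.0, 500.0])

pe = pos_encode_2d(xs, ys, d_model=256)
print(pe.shape)  # torch.Size([3, 256]) -- one sin/cos embedding per (x, y) pair
```

The first half of each embedding encodes the x coordinate and the second half the y coordinate, each with interleaved sine/cosine terms.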