├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── calibrator.py ├── convert ├── M-LSD_512_tiny_fp32.tflite ├── README.md └── convert_tf_tiny.ipynb ├── data ├── frame_1.jpg └── wireframe_anno.zip ├── demo.py ├── demo_MLSD_flask.py ├── docker-compose.flask.yml ├── docker-compose.yml ├── github ├── img.png ├── mlsd_mobile.png └── teaser.png ├── mlsd_pytorch ├── README.md ├── __init__.py ├── cfg │ └── default.py ├── configs │ ├── mobilev2_mlsd_large_512_base2_bsize24.yaml │ └── mobilev2_mlsd_tiny_512_base2_bsize24.yaml ├── data │ ├── __init__.py │ ├── utils.py │ └── wireframe_dset.py ├── learner.py ├── loss │ ├── __init__.py │ ├── _func.py │ └── mlsd_multi_loss.py ├── metric.py ├── models │ ├── __init__.py │ ├── build_model.py │ ├── layers.py │ ├── mbv2_mlsd.py │ └── mbv2_mlsd_large.py ├── optim │ ├── __init__.py │ └── lr_scheduler.py ├── pred_and_eval_sAP.py ├── tf_pred_and_eval_sAP.py ├── train.py └── utils │ ├── __init__.py │ ├── comm.py │ ├── decode.py │ ├── logger.py │ └── meter.py ├── models ├── mbv2_mlsd_large.py ├── mbv2_mlsd_tiny.py ├── mlsd_large_512_fp32.pth └── mlsd_tiny_512_fp32.pth ├── requirements.txt ├── static ├── css │ └── app.css └── favicon.ico ├── templates └── index_scan.html ├── trt_converter.py ├── utils.py └── workdir └── pretrained_models ├── mobilev2_mlsd_large_512_bsize24 └── best.pth └── mobilev2_mlsd_tiny_512_bsize24 └── best.pth /.dockerignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .gitignore 3 | convert 4 | data 5 | docker-compose.*.yml 6 | Dockerfile 7 | github 8 | LICENSE 9 | README.md 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | static/results 3 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.1.1-runtime-ubuntu20.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | COPY requirements.txt /requirements.txt 6 | 7 | RUN apt-get update \ 8 | && apt-get install -y --no-install-recommends \ 9 | libgl1-mesa-glx \ 10 | libglib2.0-0 \ 11 | python-is-python3 \ 12 | python3 \ 13 | python3-pip \ 14 | && python -m pip install --no-cache-dir -r requirements.txt \ 15 | && rm -f requirements.txt \ 16 | && rm -rf /var/lib/apt/lists/* 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021-present NAVER Corp. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # M-LSD: Towards Light-weight and Real-time Line Segment Detection 2 | 3 | ### update 2021.07.20 4 | 5 | We have pushed our training code to mlsd_pytorch/ 6 | 7 | [detail](mlsd_pytorch/README.md) 8 | 9 | model| img_size| sAP10 10 | |---|---|:---:| 11 | mlsd_tiny (this repo)| 512| 56.4 12 | mlsd_tiny (in the paper)| 512| 58.0 13 | mlsd_large (this repo)| 512| 59.6 14 | mlsd_large (in the paper)| 512| 62.1 15 | 16 | (this repo uses: min_score=0.05, min_len=5, top_k_lines=500) 17 | 18 | --- 19 | 20 | Pytorch implementation of *"M-LSD: Towards Light-weight and Real-time Line Segment Detection"*
21 | 22 | origin repo: https://github.com/navervision/mlsd 23 | 24 | - [Paper](https://arxiv.org/abs/2106.00186) 25 | - [PPT](https://www.slideshare.net/ByungSooKo1/towards-lightweight-and-realtime-line-segment-detection) 26 | 27 | 28 | ## Overview 29 |

30 | ![](github/teaser.png) 31 | 32 | ![](github/mlsd_mobile.png)
33 | 34 | 35 | **First figure**: Comparison of M-LSD and existing LSD methods on *GPU*. 36 | **Second figure**: Inference speed and memory usage on *mobile devices*. 37 | 38 | ## Demo 39 | ![](github/img.png) 40 | 41 | 42 | ## How to run demo 43 | ### Install requirements 44 | ``` 45 | pip install -r requirements.txt 46 | ``` 47 | 48 | ### Run demo 49 | 50 | The following demo runs the simplest line detection: 51 | 52 | ``` 53 | python demo.py 54 | ``` 55 | 56 | The following demo runs a Flask web app locally:
57 | 58 | ``` 59 | python demo_MLSD_flask.py 60 | ``` 61 | You can then upload an image and click submit to see the result at:
62 | http://0.0.0.0:5000/ 63 | 64 | 65 | ### Run in docker 66 | 67 | 68 | Follow the Docker and NVIDIA Container Toolkit installation instructions to set up your environment. 69 | 70 | 71 | 72 | - Build the image 73 | 74 | ``` 75 | docker-compose build 76 | 77 | ``` 78 | 79 | - Run the demo 80 | 81 | ``` 82 | docker-compose up 83 | 84 | ``` 85 | 86 | - Run the flask demo 87 | 88 | ``` 89 | docker-compose -f docker-compose.yml -f docker-compose.flask.yml up 90 | 91 | ``` 92 | 93 | ### TensorRT / Jetson Device Support 94 | 95 | #### Prerequisites 96 | 97 | Go ahead and complete installation of NVIDIA's torch2trt library with the following [instructions](https://github.com/NVIDIA-AI-IOT/torch2trt), ensuring that a compatible CUDA-compiled torch wheel is available first. For instance: 98 | 99 | ``` 100 | # Jetpack 4.6.1 101 | export TORCH_INSTALL=https://developer.download.nvidia.com/compute/redist/jp/v461/pytorch/torch-1.11.0a0+17540c5+nv22.01-cp36-cp36m-linux_aarch64.whl 102 | 103 | python3 -m pip install --upgrade pip; python3 -m pip install expecttest xmlrunner hypothesis aiohttp numpy=='1.19.4' pyyaml scipy=='1.5.3' ninja cython typing_extensions protobuf; export "LD_LIBRARY_PATH=/usr/lib/llvm-8/lib:$LD_LIBRARY_PATH"; python3 -m pip install --upgrade protobuf; python3 -m pip install --no-cache $TORCH_INSTALL 104 | 105 | ``` 106 | 107 | #### Usage 108 | 109 | For simple usage, go ahead and dial in the following: 110 | 111 | ``` 112 | python trt_converter.py --model tiny --conversion fp16 --bench 113 | ``` 114 | All model locations default to `./models/mlsd_{model_type}__512_trt_{conversion}.pth`. 115 | The tool also supports int8 conversion, provided that a representative subset of images is provided as follows: 116 | 117 | ``` 118 | python trt_converter.py --model tiny --conversion int8 --calibration_data calib-folder 119 | ``` 120 | 121 | **Note** You may also convert each torch2trt wrapped representation to a standard serialized engine for use with native TensorRT with the `--engine` and `--serialize` arguments. 122 | 123 | #### Benchmarks 124 | 125 | Device| Raw FPS| Speed (ms) 126 | |---|---|:---:| 127 | Xavier NX - FP16| 134 | 7.35 128 | Xavier NX - int8| 238 | 4.13 129 | AGX Xavier - FP16 | 280 | 3.53 130 | AGX Xavier - int8 | 451 | 2.18 131 | 132 | 133 | *Tested on a Xavier NX Developer Kit (Jetpack 5.0.1 - developer preview) and an AGX Xavier Developer Kit (Jetpack 4.6.1) 134 | 135 | 136 | 137 | ## Citation 138 | If you find *M-LSD* useful in your project, please consider citing the following paper. 139 | 140 | ``` 141 | @misc{gu2021realtime, 142 | title={Towards Real-time and Light-weight Line Segment Detection}, 143 | author={Geonmo Gu and Byungsoo Ko and SeoungHyun Go and Sung-Hyun Lee and Jingeun Lee and Minchul Shin}, 144 | year={2021}, 145 | eprint={2106.00186}, 146 | archivePrefix={arXiv}, 147 | primaryClass={cs.CV} 148 | } 149 | ``` 150 | 151 | ## License 152 | ``` 153 | Copyright 2021-present NAVER Corp. 154 | 155 | Licensed under the Apache License, Version 2.0 (the "License"); 156 | you may not use this file except in compliance with the License. 157 | You may obtain a copy of the License at 158 | 159 | http://www.apache.org/licenses/LICENSE-2.0 160 | 161 | Unless required by applicable law or agreed to in writing, software 162 | distributed under the License is distributed on an "AS IS" BASIS, 163 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 164 | See the License for the specific language governing permissions and 165 | limitations under the License.
166 | ``` 167 | -------------------------------------------------------------------------------- /calibrator.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | import torch 7 | from torchvision.datasets import ImageFolder 8 | 9 | class ImageFolderCalibDataset(): 10 | 11 | def __init__(self, root): 12 | self.dataset = ImageFolder( 13 | root=root 14 | ) 15 | self.input_shape=[512, 512] 16 | 17 | def __len__(self): 18 | return len(self.dataset) 19 | 20 | def __getitem__(self, idx): 21 | image, _ = self.dataset[idx] 22 | image = np.asarray(image) 23 | img = cv2.resize(image, (512, 512)) 24 | resized_image = np.concatenate([cv2.resize(image, (self.input_shape[0], self.input_shape[1]), interpolation=cv2.INTER_AREA), 25 | np.ones([self.input_shape[0], self.input_shape[1], 1])], axis=-1) 26 | 27 | resized_image = resized_image.transpose((2,0,1)) 28 | batch_image = np.expand_dims(resized_image, axis=0).astype('float32') 29 | batch_image = (batch_image / 127.5) - 1.0 30 | batch_image = torch.from_numpy(batch_image).float().cuda() 31 | 32 | return batch_image 33 | 34 | -------------------------------------------------------------------------------- /convert/M-LSD_512_tiny_fp32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/convert/M-LSD_512_tiny_fp32.tflite -------------------------------------------------------------------------------- /convert/README.md: -------------------------------------------------------------------------------- 1 | ## Convert from TFLite model 2 | 3 | I couldn't find an easy-to-use tool to convert the model from TFLite to PyTorch.
4 | 5 | So one option is to write a network that strictly follows the TFLite one, and use tflite + numpy to extract the weights.
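As a rough illustration of that idea (a sketch only, not the notebook itself; the tensor and layer names are hypothetical), the TFLite interpreter can dump every stored tensor as a numpy array, which can then be copied into a hand-written PyTorch model:

```python
# Sketch: pull weights out of the TFLite file with tf.lite.Interpreter, then
# copy them (after a layout transpose) into a PyTorch model that mirrors the
# TFLite graph. Tensor / attribute names below are illustrative placeholders.
import numpy as np
import tensorflow as tf
import torch

interpreter = tf.lite.Interpreter(model_path="convert/M-LSD_512_tiny_fp32.tflite")
interpreter.allocate_tensors()

tflite_weights = {}
for detail in interpreter.get_tensor_details():
    try:
        # only constant tensors (weights / biases) actually hold data
        tflite_weights[detail["name"]] = interpreter.get_tensor(detail["index"])
    except ValueError:
        pass  # activation tensors raise because they have no stored data

# Example transfer for one conv layer (names are hypothetical):
# TFLite Conv2D kernels are laid out (O, H, W, I); PyTorch expects (O, I, H, W).
# kernel = np.transpose(tflite_weights["model/conv1/Conv2D"], (0, 3, 1, 2))
# my_torch_model.conv1.weight.data.copy_(torch.from_numpy(kernel))
```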
6 | 7 | See [convert_tf_tiny.ipynb](./convert_tf_tiny.ipynb) 8 | -------------------------------------------------------------------------------- /data/frame_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/data/frame_1.jpg -------------------------------------------------------------------------------- /data/wireframe_anno.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/data/wireframe_anno.zip -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | import cv2 5 | 6 | from models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny 7 | from models.mbv2_mlsd_large import MobileV2_MLSD_Large 8 | 9 | from utils import pred_lines 10 | 11 | 12 | 13 | def main(): 14 | current_dir = os.path.dirname(__file__) 15 | if current_dir == "": 16 | current_dir = "./" 17 | # model_path = current_dir+'/models/mlsd_tiny_512_fp32.pth' 18 | # model = MobileV2_MLSD_Tiny().cuda().eval() 19 | 20 | model_path = current_dir + '/models/mlsd_large_512_fp32.pth' 21 | model = MobileV2_MLSD_Large().cuda().eval() 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | model.load_state_dict(torch.load(model_path, map_location=device), strict=True) 25 | 26 | img_fn = current_dir+'/data/frame_1.jpg' 27 | 28 | img = cv2.imread(img_fn) 29 | img = cv2.resize(img, (512, 512)) 30 | 31 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 32 | lines = pred_lines(img, model, [512, 512], 0.1, 20) 33 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 34 | 35 | for l in lines: 36 | cv2.line(img, (int(l[0]), int(l[1])), (int(l[2]), int(l[3])), (0,200,200), 1,16) 37 | cv2.imwrite(current_dir+'/data/frame_1_out.jpg', img) 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /demo_MLSD_flask.py: -------------------------------------------------------------------------------- 1 | ''' 2 | modified by lihaoweicv 3 | pytorch version 4 | ''' 5 | 6 | ''' 7 | M-LSD 8 | Copyright 2021-present NAVER Corp. 
9 | Apache License v2.0 10 | ''' 11 | # for demo 12 | import os 13 | from flask import Flask, request, session, json, Response, render_template, abort, send_from_directory 14 | import requests 15 | from urllib.request import urlopen 16 | from io import BytesIO 17 | import uuid 18 | import cv2 19 | import time 20 | import argparse 21 | 22 | import numpy as np 23 | from PIL import Image 24 | import torch 25 | 26 | # for square detector 27 | from utils import pred_squares 28 | from models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny 29 | from models.mbv2_mlsd_large import MobileV2_MLSD_Large 30 | 31 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # CPU mode 32 | 33 | # flask 34 | current_dir = os.path.dirname(__file__) 35 | if current_dir == "": 36 | current_dir = "./" 37 | app = Flask(__name__, template_folder=current_dir+ '/templates/') 38 | logger = app.logger 39 | logger.info('init demo app') 40 | 41 | # config 42 | parser = argparse.ArgumentParser() 43 | 44 | ## model parameters 45 | parser.add_argument('--model_type', default='large', type=str) 46 | parser.add_argument('--model_dir', default='./models/', type=str) 47 | parser.add_argument('--input_size', default=512, type=int, 48 | help='The size of input images.') 49 | 50 | ## LSD parameter 51 | parser.add_argument('--score_thr', default=0.10, type=float, 52 | help='Discard center points when the score < score_thr.') 53 | 54 | ## intersection point parameters 55 | parser.add_argument('--outside_ratio', default=0.10, type=float, 56 | help='''Discard an intersection point 57 | when it is located outside a line segment farther than line_length * outside_ratio.''') 58 | parser.add_argument('--inside_ratio', default=0.50, type=float, 59 | help='''Discard an intersection point 60 | when it is located inside a line segment farther than line_length * inside_ratio.''') 61 | 62 | ## ranking boxes parameters 63 | parser.add_argument('--w_overlap', default=0.0, type=float, 64 | help='''When increasing w_overlap, the final box tends to overlap with 65 | the detected line segments as much as possible.''') 66 | parser.add_argument('--w_degree', default=1.14, type=float, 67 | help='''When increasing w_degree, the final box tends to be 68 | a parallel quadrilateral with reference to the angle of the box.''') 69 | parser.add_argument('--w_length', default=0.03, type=float, 70 | help='''When increasing w_length, the final box tends to be 71 | a parallel quadrilateral with reference to the length of the box.''') 72 | parser.add_argument('--w_area', default=1.84, type=float, 73 | help='When increasing w_area, the final box tends to be the largest one out of candidates.') 74 | parser.add_argument('--w_center', default=1.46, type=float, 75 | help='When increasing w_center, the final box tends to be located in the center of input image.') 76 | 77 | ## flask demo parameter 78 | parser.add_argument('--port', default=5000, type=int, 79 | help='flask demo will be running on http://0.0.0.0:port/') 80 | 81 | 82 | class model_graph: 83 | def __init__(self, args): 84 | self.model = self.load(args.model_dir, args.model_type) 85 | self.params = {'score': args.score_thr,'outside_ratio': args.outside_ratio,'inside_ratio': args.inside_ratio, 86 | 'w_overlap': args.w_overlap,'w_degree': args.w_degree,'w_length': args.w_length, 87 | 'w_area': args.w_area,'w_center': args.w_center} 88 | self.args = args 89 | 90 | 91 | def load(self, model_dir, mode_type): 92 | model_path = model_dir +"/mlsd_tiny_512_fp32.pth" 93 | if mode_type == 'large': 94 | model_path = model_dir 
+"/mlsd_large_512_fp32.pth" 95 | torch_model = MobileV2_MLSD_Large().cuda().eval() 96 | else: 97 | torch_model = MobileV2_MLSD_Tiny().cuda().eval() 98 | 99 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 100 | torch_model.load_state_dict(torch.load(model_path, map_location=device), strict=True) 101 | self.torch_model = torch_model 102 | 103 | return torch_model 104 | 105 | 106 | def pred(self, image): 107 | segments, squares, score_array, inter_points = pred_squares(image, self.torch_model, 108 | [self.args.input_size, self.args.input_size], 109 | params=self.params) 110 | 111 | output = {} 112 | output['segments'] = segments 113 | output['squares'] = squares 114 | output['scores'] = score_array 115 | output['inter_points'] = inter_points 116 | 117 | return output 118 | 119 | 120 | def read_image(self, image_url): 121 | response = requests.get(image_url, stream=True) 122 | image = np.asarray(Image.open(BytesIO(response.content)).convert('RGB')) 123 | 124 | max_len = 1024 125 | h, w, _ = image.shape 126 | org_shape = [h, w] 127 | max_idx = np.argmax(org_shape) 128 | 129 | max_val = org_shape[max_idx] 130 | if max_val > max_len: 131 | min_idx = (max_idx + 1) % 2 132 | ratio = max_len / max_val 133 | new_min = org_shape[min_idx] * ratio 134 | new_shape = [0, 0] 135 | new_shape[max_idx] = 1024 136 | new_shape[min_idx] = new_min 137 | 138 | image = cv2.resize(image, (int(new_shape[1]), int(new_shape[0])), interpolation=cv2.INTER_AREA) 139 | 140 | return image 141 | 142 | 143 | def init_resize_image(self, im, maximum_size=1024): 144 | h, w, _ = im.shape 145 | size = [h, w] 146 | max_arg = np.argmax(size) 147 | max_len = size[max_arg] 148 | min_arg = max_arg - 1 149 | min_len = size[min_arg] 150 | if max_len < maximum_size: 151 | return im 152 | else: 153 | ratio = maximum_size / max_len 154 | max_len = max_len * ratio 155 | min_len = min_len * ratio 156 | size[max_arg] = int(max_len) 157 | size[min_arg] = int(min_len) 158 | 159 | im = cv2.resize(im, (size[1], size[0]), interpolation = cv2.INTER_AREA) 160 | 161 | return im 162 | 163 | 164 | def decode_image(self, session_id, rawimg): 165 | dirpath = os.path.join('static/results', session_id) 166 | 167 | if not os.path.exists(dirpath): 168 | os.makedirs(dirpath) 169 | save_path = os.path.join(dirpath, 'input.png') 170 | input_image_url = os.path.join(dirpath, 'input.png') 171 | 172 | img = cv2.imdecode(np.frombuffer(rawimg, dtype='uint8'), 1)[:,:,::-1] 173 | img = self.init_resize_image(img) 174 | cv2.imwrite(save_path, img[:,:,::-1]) 175 | 176 | return img, input_image_url 177 | 178 | 179 | def draw_output(self, image, output, save_path='test.png'): 180 | color_dict = {'red': [255, 0, 0], 181 | 'green': [0, 255, 0], 182 | 'blue': [0, 0, 255], 183 | 'cyan': [0, 255, 255], 184 | 'black': [0, 0, 0], 185 | 'yellow': [255, 255, 0], 186 | 'dark_yellow': [200, 200, 0]} 187 | 188 | line_image = image.copy() 189 | square_image = image.copy() 190 | square_candidate_image = image.copy() 191 | 192 | line_thick = 5 193 | 194 | # output > line array 195 | for line in output['segments']: 196 | x_start, y_start, x_end, y_end = [int(val) for val in line] 197 | cv2.line(line_image, (x_start, y_start), (x_end, y_end), color_dict['red'], line_thick) 198 | 199 | inter_image = line_image.copy() 200 | 201 | for pt in output['inter_points']: 202 | x, y = [int(val) for val in pt] 203 | cv2.circle(inter_image, (x, y), 10, color_dict['blue'], -1) 204 | 205 | for square in output['squares']: 206 | cv2.polylines(square_candidate_image, 
[square.reshape([-1, 1, 2])], True, color_dict['dark_yellow'], line_thick) 207 | 208 | for square in output['squares'][0:1]: 209 | cv2.polylines(square_image, [square.reshape([-1, 1, 2])], True, color_dict['yellow'], line_thick) 210 | for pt in square: 211 | cv2.circle(square_image, (int(pt[0]), int(pt[1])), 10, color_dict['cyan'], -1) 212 | 213 | ''' 214 | square image | square candidates image 215 | inter image | line image 216 | ''' 217 | output_image = self.init_resize_image(square_image, 512) 218 | output_image = np.concatenate([output_image, self.init_resize_image(square_candidate_image, 512)], axis=1) 219 | output_image_tmp = np.concatenate([self.init_resize_image(inter_image, 512), self.init_resize_image(line_image, 512)], axis=1) 220 | output_image = np.concatenate([output_image, output_image_tmp], axis=0) 221 | 222 | cv2.imwrite(save_path, output_image[:,:,::-1]) 223 | 224 | return output_image 225 | 226 | 227 | def save_output(self, session_id, input_image_url, image, output): 228 | dirpath = os.path.join('static/results', session_id) 229 | 230 | if not os.path.exists(dirpath): 231 | os.makedirs(dirpath) 232 | 233 | save_path = os.path.join(dirpath, 'output.png') 234 | self.draw_output(image, output, save_path=save_path) 235 | 236 | output_image_url = os.path.join(dirpath, 'output.png') 237 | 238 | rst = {} 239 | rst['input_image_url'] = input_image_url 240 | rst['session_id'] = session_id 241 | rst['output_image_url'] = output_image_url 242 | 243 | with open(os.path.join(dirpath, 'results.json'), 'w') as f: 244 | json.dump(rst, f) 245 | 246 | 247 | def init_worker(args): 248 | global model 249 | 250 | model = model_graph(args) 251 | 252 | 253 | @app.route('/') 254 | def index(): 255 | return render_template('index_scan.html', session_id='dummy_session_id') 256 | 257 | 258 | @app.route('/', methods=['POST']) 259 | def index_post(): 260 | request_start = time.time() 261 | configs = request.form 262 | 263 | session_id = str(uuid.uuid1()) 264 | 265 | image_url = configs['image_url'] # image_url 266 | 267 | if len(image_url) == 0: 268 | bio = BytesIO() 269 | request.files['image'].save(bio) 270 | rawimg = bio.getvalue() 271 | image, image_url = model.decode_image(session_id, rawimg) 272 | else: 273 | image = model.read_image(image_url) 274 | 275 | output = model.pred(image) 276 | 277 | model.save_output(session_id, image_url, image, output) 278 | 279 | return render_template('index_scan.html', session_id=session_id) 280 | 281 | 282 | @app.route('/favicon.ico') 283 | def favicon(): 284 | return send_from_directory(os.path.join(app.root_path, 'static'), 285 | 'favicon.ico', mimetype='image/vnd.microsoft.icon') 286 | 287 | 288 | if __name__ == '__main__': 289 | args = parser.parse_args() 290 | 291 | init_worker(args) 292 | 293 | app.run(host='0.0.0.0', port=args.port) 294 | -------------------------------------------------------------------------------- /docker-compose.flask.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | services: 4 | app: 5 | command: "python demo_MLSD_flask.py" 6 | ports: 7 | - "5000:5000" 8 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | services: 4 | app: 5 | build: 6 | context: . 
7 | dockerfile: Dockerfile 8 | image: mlsd 9 | environment: 10 | - NVIDIA_VISIBLE_DEVICES=0 11 | - NVIDIA_DRIVER_CAPABILITIES=compute 12 | volumes: 13 | - ${PWD}:/ws 14 | working_dir: /ws 15 | command: "python demo.py" 16 | -------------------------------------------------------------------------------- /github/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/github/img.png -------------------------------------------------------------------------------- /github/mlsd_mobile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/github/mlsd_mobile.png -------------------------------------------------------------------------------- /github/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/github/teaser.png -------------------------------------------------------------------------------- /mlsd_pytorch/README.md: -------------------------------------------------------------------------------- 1 | # M-LSD: Towards Light-weight and Real-time Line Segment Detection 2 | 3 | Pytorch implementation with training code.
4 | (this is a sub-project of https://github.com/lhwcv/mlsd_pytorch) 5 | 6 | ## result 7 | 8 | model| img_size| sAP10 9 | |---|---|:---:| 10 | mlsd_tiny (this repo)| 512| 56.4 11 | mlsd_tiny (in the paper)| 512| 58.0 12 | mlsd_large (this repo)| 512| 59.6 13 | mlsd_large (in the paper)| 512| 62.1 14 | 15 | (this repo uses: min_score=0.05, min_len=5, top_k_lines=500) 16 | 17 | ## differences 18 | Since there is no official open-source training code, I tried to reproduce the results,
19 | but could not match the numbers reported in the paper, so I made some changes.
20 | (Looking forward to the official open-source training code)
21 | 22 | Main differences compared to the paper (for the results above): 23 | 24 | - The center map uses Focal Loss instead of WCE (can be switched back to WCE) 25 | - Step LR schedule instead of Cosine (can be switched to the latter) 26 | - Deconv instead of upsample in the tiny model 27 | - No matching loss (I suspect my matching loss implementation has bugs) 28 | - Batch size = 24 (64 in the paper; a larger batch size may help) 29 | 30 | ## val 31 | (put the pretrained models in workdir/models) 32 | 33 | eval tiny: 34 | 35 | ``` 36 | python mlsd_pytorch/pred_and_eval_sAP.py 37 | ``` 38 | (modify some args to eval the large model) 39 | 40 | ## train 41 | 42 | ### Data Preparation 43 | (You can also follow [AFM](https://github.com/cherubicXN/afm_cvpr2019) or others; the steps are almost the same.) 44 | - Download the [Wireframe dataset](https://github.com/huangkuns/wireframe) and the [YorkUrban dataset](http://www.elderlab.yorku.ca/resources/york-urban-line-segment-database-information/) from their project pages. 45 | - Download the JSON-format annotations ([Google Drive](https://drive.google.com/file/d/15z3-xgIzj_-9bep8l6s8dgIbpjKp_8VK/view?usp=sharing)). 46 | - Place the images in "data/wireframe_raw/images/" 47 | - Unzip the JSON-format annotations to "data/wireframe_raw/" 48 | 49 | The structure of the data folder should be: 50 | ```shell 51 | data/ 52 | wireframe_raw/images/*.png 53 | wireframe_raw/train.json 54 | wireframe_raw/valid.json 55 | 56 | ``` 57 | ### Train 58 | tiny: 59 | ``` 60 | 61 | python mlsd_pytorch/train.py \ 62 | --config mlsd_pytorch/configs/mobilev2_mlsd_tiny_512_base2_bsize24.yaml 63 | ``` 64 | 65 | large: 66 | ``` 67 | 68 | python mlsd_pytorch/train.py \ 69 | --config mlsd_pytorch/configs/mobilev2_mlsd_large_512_base2_bsize24.yaml 70 | ``` -------------------------------------------------------------------------------- /mlsd_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/mlsd_pytorch/__init__.py -------------------------------------------------------------------------------- /mlsd_pytorch/cfg/default.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | __all__ = ['get_cfg_defaults'] 3 | ## 4 | _C = CN() 5 | _C.sys = CN() 6 | _C.sys.cpu = False 7 | _C.sys.gpus = 1 8 | _C.sys.num_workers = 8 9 | ## 10 | _C.datasets = CN() 11 | _C.datasets.name = '' 12 | _C.datasets.input_size = 512 13 | _C.datasets.with_centermap_extend = False 14 | 15 | 16 | ## 17 | _C.model = CN() 18 | _C.model.model_name = '' 19 | _C.model.with_deconv = False 20 | _C.model.num_classes = 1 21 | 22 | 23 | ## 24 | _C.train = CN() 25 | _C.train.do_train = True 26 | _C.train.batch_size = 48 27 | _C.train.save_dir = '' 28 | _C.train.gradient_accumulation_steps = 1 29 | _C.train.num_train_epochs = 170 30 | _C.train.use_step_lr_policy = False 31 | _C.train.warmup_steps = 200 32 | _C.train.learning_rate = 0.0008 33 | _C.train.dropout = 0.1 34 | _C.train.milestones = [100, 150] 35 | _C.train.milestones_in_epo = True 36 | _C.train.lr_decay_gamma = 0.1 37 | _C.train.weight_decay = 0.000001 38 | _C.train.device_ids_str = "0" 39 | _C.train.device_ids = [0] 40 | _C.train.adam_epsilon = 1e-6 41 | _C.train.early_stop_n = 200 42 | _C.train.device_ids_str = "0" 43 | _C.train.device_ids = [0] 44 | _C.train.num_workers = 8 45 | _C.train.log_steps = 50 46 | 47 | _C.train.img_dir = '' 48 | _C.train.label_fn = '' 49 | _C.train.data_cache_dir = '' 50 | _C.train.with_cache
= False 51 | _C.train.cache_to_mem = False 52 | 53 | 54 | _C.train.load_from = "" 55 | ## 56 | _C.val = CN() 57 | _C.val.batch_size = 8 58 | 59 | _C.val.img_dir = '' 60 | _C.val.label_fn = '' 61 | 62 | _C.val.val_after_epoch = 0 63 | 64 | _C.loss = CN() 65 | _C.loss.loss_weight_dict_list = [] 66 | _C.loss.loss_type = '1*L1' 67 | _C.loss.with_sol_loss = True 68 | _C.loss.with_match_loss = False 69 | _C.loss.with_focal_loss = True 70 | _C.loss.match_sap_thresh = 5.0 71 | _C.loss.focal_loss_level = 0 72 | 73 | _C.decode = CN() 74 | _C.decode.score_thresh = 0.05 75 | _C.decode.len_thresh = 5 76 | _C.decode.top_k = 500 77 | 78 | 79 | def get_cfg_defaults(merge_from = None): 80 | cfg = _C.clone() 81 | if merge_from is not None: 82 | cfg.merge_from_other_cfg(merge_from) 83 | return cfg -------------------------------------------------------------------------------- /mlsd_pytorch/configs/mobilev2_mlsd_large_512_base2_bsize24.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | name: 'wireframe' 3 | input_size: 512 4 | 5 | model: 6 | model_name: 'mobilev2_mlsd_large' 7 | with_deconv: True 8 | 9 | 10 | train: 11 | save_dir: './workdir/models/mobilev2_mlsd_large_512_bsize24/' 12 | img_dir: "./data/wireframe_raw/images/" 13 | label_fn: "./data/wireframe_raw/train.json" 14 | num_train_epochs: 155 15 | batch_size: 24 16 | learning_rate: 0.003 17 | use_step_lr_policy: True 18 | weight_decay: 0.000001 19 | load_from: "" 20 | warmup_steps: 100 21 | milestones: [ 50, 100, 150 ] 22 | milestones_in_epo: True 23 | lr_decay_gamma: 0.2 24 | 25 | data_cache_dir: "./data/wireframe_cache/" 26 | with_cache: False 27 | cache_to_mem: False 28 | 29 | val: 30 | img_dir: "./data/wireframe_raw/images/" 31 | label_fn: "./data/wireframe_raw/valid.json" 32 | batch_size: 8 33 | val_after_epoch: 50 34 | 35 | loss: 36 | loss_weight_dict_list: [ { 'tp_center_loss': 10.0,'sol_center_loss': 1.0,'tp_match_loss':1.0 } ] 37 | 38 | with_match_loss: False 39 | with_focal_loss: True 40 | focal_loss_level: 0 41 | with_sol_loss: True 42 | match_sap_thresh: 5.0 43 | 44 | decode: 45 | score_thresh: 0.05 46 | len_thresh: 5 47 | top_k: 500 -------------------------------------------------------------------------------- /mlsd_pytorch/configs/mobilev2_mlsd_tiny_512_base2_bsize24.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | name: 'wireframe' 3 | input_size: 512 4 | 5 | model: 6 | model_name: 'mobilev2_mlsd' 7 | with_deconv: True 8 | 9 | 10 | train: 11 | save_dir: './workdir/models/mobilev2_mlsd_tiny_512_bsize24/' 12 | img_dir: "./data/wireframe_raw/images/" 13 | label_fn: "./data/wireframe_raw/train.json" 14 | num_train_epochs: 155 15 | batch_size: 24 16 | learning_rate: 0.003 17 | use_step_lr_policy: True 18 | weight_decay: 0.000001 19 | load_from: "" 20 | warmup_steps: 100 21 | milestones: [ 50, 100, 150 ] 22 | milestones_in_epo: True 23 | lr_decay_gamma: 0.2 24 | 25 | data_cache_dir: "./data/wireframe_cache/" 26 | with_cache: False 27 | cache_to_mem: False 28 | 29 | val: 30 | img_dir: "./data/wireframe_raw/images/" 31 | label_fn: "./data/wireframe_raw/valid.json" 32 | batch_size: 8 33 | val_after_epoch: 50 34 | 35 | loss: 36 | loss_weight_dict_list: [ { 'tp_center_loss': 10.0,'sol_center_loss': 1.0,'tp_match_loss':1.0 } ] 37 | 38 | with_match_loss: False 39 | with_focal_loss: True 40 | focal_loss_level: 0 41 | with_sol_loss: True 42 | match_sap_thresh: 5.0 43 | 44 | decode: 45 | score_thresh: 0.05 46 | len_thresh: 5 47 | top_k: 500 
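For reference, here is a minimal sketch (not part of the repo, assuming `mlsd_pytorch` is importable from the project root) of how these YAML files pair with the yacs defaults in `mlsd_pytorch/cfg/default.py`: the YAML only overrides the keys it lists, and everything else keeps its default value.

```python
# Minimal sketch: load the defaults, then overlay one of the YAML configs.
from mlsd_pytorch.cfg.default import get_cfg_defaults

cfg = get_cfg_defaults()
# merge_from_file is a standard yacs CfgNode method.
cfg.merge_from_file("mlsd_pytorch/configs/mobilev2_mlsd_tiny_512_base2_bsize24.yaml")

print(cfg.model.model_name)   # 'mobilev2_mlsd'
print(cfg.train.batch_size)   # 24 (the default in default.py is 48)
print(cfg.decode.top_k)       # 500
```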
-------------------------------------------------------------------------------- /mlsd_pytorch/data/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset,DataLoader 2 | from mlsd_pytorch.data.wireframe_dset import Line_Dataset, LineDataset_collate_fn 3 | 4 | __mapping_dataset = { 5 | 'wireframe': Line_Dataset, 6 | } 7 | 8 | __mapping_dataset_collate_fn = { 9 | 'wireframe': LineDataset_collate_fn, 10 | } 11 | 12 | def get_dataset(cfg, is_train = True): 13 | if cfg.datasets.name not in __mapping_dataset.keys(): 14 | raise NotImplementedError('Dataset Type : {} not supported!'.format(cfg.datasets.name)) 15 | return __mapping_dataset[cfg.datasets.name]( 16 | cfg, 17 | is_train = is_train 18 | ) 19 | 20 | def get_collate_fn(cfg): 21 | if cfg.datasets.name not in __mapping_dataset_collate_fn.keys(): 22 | raise NotImplementedError('Dataset Type not supported!') 23 | return __mapping_dataset_collate_fn[cfg.datasets.name] 24 | 25 | def get_train_dataloader(cfg): 26 | ds = get_dataset(cfg, True) 27 | dloader = DataLoader( 28 | ds, 29 | batch_size = cfg.train.batch_size, 30 | shuffle = True, 31 | num_workers = cfg.sys.num_workers, 32 | drop_last=True, 33 | collate_fn= get_collate_fn(cfg) 34 | ) 35 | return dloader 36 | 37 | def get_val_dataloader(cfg): 38 | ds = get_dataset(cfg, False) 39 | dloader = DataLoader( 40 | ds, 41 | batch_size = cfg.val.batch_size, 42 | shuffle = False, 43 | num_workers = cfg.sys.num_workers, 44 | drop_last=False, 45 | collate_fn= get_collate_fn(cfg) 46 | ) 47 | return dloader 48 | 49 | -------------------------------------------------------------------------------- /mlsd_pytorch/data/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import cv2 5 | from torch.nn import functional as F 6 | 7 | def swap_line_pt_maybe(line): 8 | ''' 9 | [x0, y0, x1, y1] 10 | ''' 11 | L = line 12 | # if line[0] > line[2]: 13 | # L = [line[2], line[3], line[0], line[1]] 14 | if abs(line[0] - line[2]) > abs(line[1] - line[3]): 15 | if line[0] > line[2]: 16 | L = [line[2], line[3], line[0], line[1]] 17 | else: 18 | if line[1] > line[3]: 19 | L = [line[2], line[3], line[0], line[1]] 20 | return L 21 | 22 | def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): 23 | ''' 24 | tpMap: 25 | center: tpMap[1, 0, :, :] 26 | displacement: tpMap[1, 1:5, :, :] 27 | ''' 28 | b, c, h, w = tpMap.shape 29 | assert b==1, 'only support bsize==1' 30 | displacement = tpMap[:, 1:5, :, :][0] 31 | center = tpMap[:, 0, :, :] 32 | heat = torch.sigmoid(center) 33 | hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) 34 | keep = (hmax == heat).float() 35 | heat = heat * keep 36 | heat = heat.reshape(-1, ) 37 | 38 | scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) 39 | yy = torch.floor_divide(indices, w).unsqueeze(-1) 40 | xx = torch.fmod(indices, w).unsqueeze(-1) 41 | ptss = torch.cat((yy, xx),dim=-1) 42 | 43 | ptss = ptss.detach().cpu().numpy() 44 | scores = scores.detach().cpu().numpy() 45 | displacement = displacement.detach().cpu().numpy() 46 | 47 | return ptss, scores, displacement 48 | 49 | # def deccode_lines(tpMap,score_thr=0.1, dist_thr= 20, topk_n = 200, ksize = 3): 50 | # pts, pts_score, vmap = deccode_output_score_and_ptss(tpMap, topk_n=topk_n, ksize=ksize) 51 | # 52 | # start = vmap[:2, :, :] 53 | # end = vmap[2:, :, :] 54 | # dist_map = np.sqrt(np.sum((start - end) ** 2, 
axis=0)) 55 | # 56 | # segments_list = [] 57 | # scores = [] 58 | # for center, score in zip(pts, pts_score): 59 | # y, x = center 60 | # distance = dist_map[y, x] 61 | # if score > score_thr and distance > dist_thr: 62 | # disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[:, y, x] 63 | # x_start = x + disp_x_start 64 | # y_start = y + disp_y_start 65 | # x_end = x + disp_x_end 66 | # y_end = y + disp_y_end 67 | # segments_list.append([x_start, y_start, x_end, y_end]) 68 | # scores.append(score) 69 | # 70 | # lines = np.array(segments_list) 71 | # return lines, scores 72 | 73 | def deccode_lines(tpMap,score_thresh = 0.1, len_thresh=2, topk_n = 1000, ksize = 3 ): 74 | ''' 75 | tpMap: 76 | center: tpMap[1, 0, :, :] 77 | displacement: tpMap[1, 1:5, :, :] 78 | ''' 79 | b, c, h, w = tpMap.shape 80 | assert b==1, 'only support bsize==1' 81 | displacement = tpMap[:, 1:5, :, :] 82 | center = tpMap[:, 0, :, :] 83 | heat = torch.sigmoid(center) 84 | hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) 85 | keep = (hmax == heat).float() 86 | heat = heat * keep 87 | heat = heat.reshape(-1, ) 88 | 89 | heat = torch.where(heat < score_thresh, torch.zeros_like(heat), heat) 90 | 91 | scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) 92 | valid_inx = torch.where(scores > score_thresh) 93 | scores = scores[valid_inx] 94 | indices = indices[valid_inx] 95 | 96 | yy = torch.floor_divide(indices, w).unsqueeze(-1) 97 | xx = torch.fmod(indices, w).unsqueeze(-1) 98 | center_ptss = torch.cat((xx, yy),dim=-1) 99 | 100 | start_point = center_ptss + displacement[0, :2, yy, xx].reshape(2, -1).permute(1,0) 101 | end_point = center_ptss + displacement[0, 2:, yy, xx].reshape(2, -1).permute(1,0) 102 | 103 | lines = torch.cat((start_point, end_point), dim=-1) 104 | 105 | all_lens = (end_point - start_point) ** 2 106 | all_lens = all_lens.sum(dim=-1) 107 | all_lens = torch.sqrt(all_lens) 108 | valid_inx = torch.where(all_lens > len_thresh) 109 | 110 | center_ptss = center_ptss[valid_inx] 111 | lines = lines[valid_inx] 112 | scores = scores[valid_inx] 113 | 114 | return center_ptss, lines, scores 115 | 116 | 117 | def _max_pool_np(x, kernel=5): 118 | heat = torch.from_numpy(x).unsqueeze(0).unsqueeze(0) 119 | pad = (kernel - 1) // 2 120 | hmax = nn.functional.max_pool2d( 121 | heat, (kernel, kernel), stride=1, padding=pad) 122 | heat = hmax.numpy()[0] 123 | return heat 124 | 125 | 126 | def _nms(heat, kernel=3): 127 | is_np = isinstance(heat, np.ndarray) 128 | if is_np: 129 | heat = torch.from_numpy(heat).unsqueeze(0) 130 | 131 | pad = (kernel - 1) // 2 132 | hmax = nn.functional.max_pool2d( 133 | heat, (kernel, kernel), stride=1, padding=pad) 134 | 135 | keep = (hmax == heat).float() 136 | heat = heat * keep 137 | 138 | if is_np: 139 | heat = heat.cpu().numpy()[0] 140 | 141 | return heat 142 | 143 | 144 | def TP_map_to_line_numpy(centermap, dis_map, thresh=0.2, inputW = 512, inputH= 512): 145 | """ 146 | centermap: (1, h, w) 147 | dis_map: (4, h, w) 148 | """ 149 | _, h, w = centermap.shape 150 | h_ratio, w_ratio = [h / inputH, w / inputW] 151 | 152 | center_nms = _nms(centermap, kernel=3)[0] 153 | # print('center_nms.shape:', center_nms.shape) 154 | 155 | center_pos = np.where(center_nms > thresh) 156 | ## [y, x] 157 | center_pos = np.array([center_pos[1], center_pos[0]]) 158 | # print("center_pos.shape:", center_pos.shape) 159 | 160 | dis_list = dis_map[:, center_pos[1], center_pos[0]] 161 | #print(dis_list) 162 | ## [x, y] 163 | dis_list = dis_list.transpose(1, 0) 164 | 165 | center_pos = center_pos.transpose(1, 0) 166 | 167 | #cale = np.array([w / 100.0, h / 100.0]) 168 | scale = np.array([w_ratio, h_ratio]) 169 | start_point =
center_pos + dis_list[:, 0:2] * scale * 2 170 | end_point = center_pos + dis_list[:, 2:4] * scale * 2 171 | 172 | line = np.stack([start_point, end_point], axis=1) 173 | return 2 *line.reshape((-1, 4)) 174 | 175 | # def work_around_line(x0, y0, x1, y1, n=2, r=0.0, thickness = 3): 176 | # t = (thickness - 1) // 2 177 | # 178 | # if abs(x0 - x1) > abs(y0 -y1): 179 | # ## y = k* x + b 180 | # k = (y1 - y0) / (x1 - x0) 181 | # b = y1 - k * x1 182 | # 183 | # ptss = [] 184 | # xc = (x0 + x1) / 2 185 | # if n is None: 186 | # n = int(abs(x1 - x0) * r) 187 | # 188 | # xmin = int(xc - n) 189 | # xmax = int(xc + n) 190 | # for x in range(xmin, xmax+1): 191 | # y = k * x + b 192 | # for iy in range(thickness): 193 | # ptss.append([x, y + t - iy]) 194 | # 195 | # return ptss 196 | # else: 197 | # ## x = k* y + b 198 | # k = (x1 - x0) / (y1 - y0) 199 | # b = x1 - k * y1 200 | # ptss = [] 201 | # 202 | # yc = (y0 + y1) / 2 203 | # if n is None: 204 | # n = int(abs(y1 - y0) * r) 205 | # ymin = int(yc - n) 206 | # ymax = int(yc + n) 207 | # 208 | # for y in range(ymin, ymax+1): 209 | # x =k * y + b 210 | # for ix in range(thickness): 211 | # ptss.append([x + t - ix, y]) 212 | # return ptss 213 | 214 | 215 | # def near_area_n(xc, yc, n= 5): 216 | # n = n // 2 217 | # ptss = [] 218 | # for x in range(xc-n, xc + n +1): 219 | # for y in range(yc - n, yc +n +1): 220 | # ptss.append([x, y]) 221 | # return ptss 222 | 223 | # def line_len_and_angle(x0, y0, x1, y1): 224 | # if abs(x0 - x1) < 1e-6: 225 | # ang = np.pi / 2 226 | # else: 227 | # ang = np.arctan( abs ( (y0 -y1) / (x0 -x1) ) ) 228 | # 229 | # ang = ang / (2 * np.pi) + 0.5 230 | # lens = np.sqrt( (x0 - x1) **2 + (y0 - y1) **2) 231 | # 232 | # return lens, ang 233 | 234 | 235 | def line_len_and_angle(x0, y0, x1, y1): 236 | if abs(x0 - x1) < 1e-6: 237 | ang = np.pi / 2 238 | else: 239 | ang = np.arctan(abs((y0 - y1) / (x0 - x1))) 240 | 241 | ang = ang / (2 * np.pi) + 0.5 242 | len = np.sqrt((x0 - x1) ** 2 + (y0 - y1) ** 2) 243 | return len, ang 244 | 245 | 246 | def near_area_n(xc, yc, n=5): 247 | if n <= 1: 248 | return [[xc, yc]] 249 | n = n // 2 250 | ptss = [] 251 | for x in range(xc - n, xc + n + 1): 252 | for y in range(yc - n, yc + n + 1): 253 | ptss.append([x, y]) 254 | return ptss 255 | 256 | def cut_line_by_xmin(line, xmin): 257 | if line[0] > xmin and line[2] > xmin: 258 | return True, line 259 | if line[0] <= xmin and line[2] <= xmin: 260 | return False, line 261 | if abs(line[0] - line[2]) < 1: 262 | return False, line 263 | # y = k*x + b 264 | k = (line[3] - line[1]) / (line[2] - line[0]) 265 | b = line[3] - k * line[2] 266 | y = k * xmin + b 267 | p0 = [xmin, y] 268 | if line[0] < line[2]: 269 | p1 = [line[2], line[3]] 270 | else: 271 | p1 = [line[0], line[1]] 272 | line = [p0[0], p0[1], p1[0], p1[1]] 273 | 274 | return True, line 275 | 276 | def cut_line_by_xmax(line, xmax): 277 | if line[0] < xmax and line[2] < xmax: 278 | return True, line 279 | if line[0] >= xmax and line[2] >= xmax: 280 | return False, line 281 | if abs(line[0] - line[2]) < 1: 282 | return False, line 283 | # y = k*x + b 284 | k = (line[3] - line[1]) / (line[2] - line[0]) 285 | b = line[3] - k * line[2] 286 | y = k * xmax + b 287 | p1 = [xmax, y] 288 | if line[0] > line[2]: 289 | p0 = [line[2], line[3]] 290 | else: 291 | p0 = [line[0], line[1]] 292 | return True, [p0[0], p0[1], p1[0], p1[1]] 293 | 294 | def work_around_line(x0, y0, x1, y1, n=2, r=0.0, thickness=3): 295 | t = (thickness - 1) // 2 296 | # print("p:", x0, y0, x1, y1) 297 | if abs(x0 - x1) > abs(y0 - y1): 
298 | ## y = k* x + b 299 | k = (y1 - y0) / (x1 - x0) 300 | b = y1 - k * x1 301 | 302 | ptss = [] 303 | xc = (x0 + x1) / 2 304 | if n is None: 305 | n = int(abs(x1 - x0) * r) 306 | 307 | xmin = int(xc - n) 308 | xmax = int(xc + n) 309 | for x in range(xmin, xmax + 1): 310 | y = k * x + b 311 | for iy in range(thickness): 312 | ptss.append([x, y + t - iy]) 313 | 314 | return ptss 315 | else: 316 | ## x = k* y + b 317 | k = (x1 - x0) / (y1 - y0) 318 | b = x1 - k * y1 319 | ptss = [] 320 | 321 | yc = (y0 + y1) / 2 322 | if n is None: 323 | n = int(abs(y1 - y0) * r) 324 | ymin = int(yc - n) 325 | ymax = int(yc + n) 326 | 327 | for y in range(ymin, ymax + 1): 328 | x = k * y + b 329 | for ix in range(thickness): 330 | ptss.append([x + t - ix, y]) 331 | return ptss 332 | 333 | 334 | # def gen_TP_mask(norm_lines, h=256, w=256, with_ext=True): 335 | # """ 336 | # 1 cengter + 4 dis + 2 337 | # return [7, h, w] 338 | # """ 339 | # 340 | # # h, w, _ = img.shape 341 | # 342 | # len_divide_v = np.sqrt(h ** 2 + w ** 2) 343 | # 344 | # centermap = np.zeros((1, h, w), dtype=np.uint8) 345 | # 346 | # displacement_map = np.zeros((4, h, w), dtype=np.float32) 347 | # length_map = np.zeros((1, h, w), dtype=np.float32) 348 | # degree_map = np.zeros((1, h, w), dtype=np.float32) 349 | # 350 | # for l in norm_lines: 351 | # x0, y0, x1, y1 = w * l[0], h * l[1], w * l[2], h * l[3] 352 | # 353 | # # print("p:", x0, y0, x1, y1) 354 | # 355 | # xc = (x0 + x1) / 2 356 | # yc = (y0 + y1) / 2 357 | # 358 | # if with_ext: 359 | # len_max = max(abs(x1 - x0), abs(y1 - y0)) 360 | # # len_max = int(0.5 * len_max) 361 | # exp_pix = min(7, len_max) 362 | # ptss = work_around_line(x0, y0, x1, y1, n=exp_pix, thickness=1) 363 | # 364 | # for p in ptss: 365 | # xx = int(round(p[0])) 366 | # yy = int(round(p[1])) 367 | # 368 | # xx = np.clip(xx, 0, w - 1) 369 | # yy = np.clip(yy, 0, h - 1) 370 | # 371 | # sx = (1 - abs(xx - xc) / (2 * exp_pix)) 372 | # sy = (1 - abs(yy - yc) / (2 * exp_pix)) 373 | # 374 | # centermap[0, yy, xx] = 255 * sx * sy 375 | # 376 | # x0d = x0 - xx 377 | # y0d = y0 - yy 378 | # x1d = x1 - xx 379 | # y1d = y1 - yy 380 | # 381 | # displacement_map[0, yy, xx] = x0d 382 | # displacement_map[1, yy, xx] = y0d 383 | # displacement_map[2, yy, xx] = x1d 384 | # displacement_map[3, yy, xx] = y1d 385 | # 386 | # line_len, ang = line_len_and_angle(x0, y0, x1, y1) 387 | # line_len /= len_divide_v 388 | # 389 | # x0d = x0 - xc 390 | # y0d = y0 - yc 391 | # x1d = x1 - xc 392 | # y1d = y1 - yc 393 | # 394 | # # ptss = [ 395 | # # [int(np.floor(xc)), int(np.floor(yc))], 396 | # # [int(np.ceil(xc)), int(np.ceil(yc))], 397 | # # [int(np.floor(xc)), int(np.ceil(yc))], 398 | # # [int(np.ceil(xc)), int(np.floor(yc))], 399 | # # ] 400 | # xc = int(round(xc)) 401 | # yc = int(round(yc)) 402 | # ptss = near_area_n(xc, yc, 3) 403 | # 404 | # for p in ptss: 405 | # xx = int(round(p[0])) 406 | # yy = int(round(p[1])) 407 | # 408 | # xx = np.clip(xx, 0, w - 1) 409 | # yy = np.clip(yy, 0, h - 1) 410 | # 411 | # centermap[0, yy, xx] = 255 412 | # length_map[0, yy, xx] = line_len 413 | # degree_map[0, yy, xx] = ang 414 | # 415 | # displacement_map[0, yy, xx] = x0d 416 | # displacement_map[1, yy, xx] = y0d 417 | # displacement_map[2, yy, xx] = x1d 418 | # displacement_map[3, yy, xx] = y1d 419 | # 420 | # centermap[0, :, :] = cv2.GaussianBlur(centermap[0, :, :], (3, 3), 0.0) 421 | # centermap = np.array(centermap, dtype=np.float32) / 255.0 422 | # b = centermap.max() - centermap.min() 423 | # if b != 0: 424 | # centermap = (centermap - 
centermap.min()) / b 425 | # 426 | # tp_mask = np.concatenate((centermap, displacement_map, length_map, degree_map), axis=0) 427 | # return tp_mask 428 | 429 | def gen_TP_mask2(norm_lines, h = 256, w = 256, with_ext=False): 430 | """ 431 | 1 cengter + 4 dis + 2 432 | return [7, h, w] 433 | """ 434 | 435 | #h, w, _ = img.shape 436 | 437 | len_divide_v = np.sqrt(h**2 + w**2) 438 | radius = 1 439 | 440 | centermap = np.zeros((1, h, w), dtype=np.uint8) 441 | #displacement_map = -np.ones((4, h, w), dtype=np.float32) * 1000.0 442 | 443 | displacement_map = np.zeros((4, h, w), dtype=np.float32) 444 | length_map = np.zeros((1, h, w), dtype=np.float32) 445 | degree_map = np.zeros((1, h, w), dtype=np.float32) 446 | 447 | for l in norm_lines: 448 | x0 = int(round(l[0] * w)) 449 | y0 = int(round(l[1] * h)) 450 | x1 = int(round(l[2] * w)) 451 | y1 = int(round(l[3] * h)) 452 | 453 | xc = round(w * (l[0] + l[2]) / 2) 454 | yc = round(h * (l[1] + l[3]) / 2) 455 | 456 | xc = int(np.clip(xc, 0, w - 1)) 457 | yc = int(np.clip(yc, 0, h - 1)) 458 | 459 | centermap[0, yc, xc] = 255 460 | 461 | line_len, ang = line_len_and_angle(x0, y0, x1, y1) 462 | line_len /= len_divide_v 463 | length_map[0, yc, xc] = line_len 464 | degree_map[0, yc, xc] = ang 465 | 466 | x0d = x0 - xc 467 | y0d = y0 - yc 468 | x1d = x1 - xc 469 | y1d = y1 - yc 470 | 471 | #print('x0d: ', x0d) 472 | 473 | displacement_map[0, yc, xc] = x0d # / 2 474 | displacement_map[1, yc, xc] = y0d # / 2 475 | displacement_map[2, yc, xc] = x1d # / 2 476 | displacement_map[3, yc, xc] = y1d # / 2 477 | 478 | ## walk around line 479 | #ptss = work_around_line(x0, y0, x1, y1, n=5, r=0.0, thickness=3) 480 | 481 | # extrapolated to a 3×3 window 482 | ptss = near_area_n(xc, yc, n=3) 483 | for p in ptss: 484 | xc = round(p[0]) 485 | yc = round(p[1]) 486 | xc = int(np.clip(xc, 0, w - 1)) 487 | yc = int(np.clip(yc, 0, h - 1)) 488 | # x0d = x0 - xc 489 | # y0d = y0 - yc 490 | # x1d = x1 - xc 491 | # y1d = y1 - yc 492 | displacement_map[0, yc, xc] = x0d# / 2 493 | displacement_map[1, yc, xc] = y0d# / 2 494 | displacement_map[2, yc, xc] = x1d# / 2 495 | displacement_map[3, yc, xc] = y1d# / 2 496 | 497 | length_map[0, yc, xc] = line_len 498 | degree_map[0, yc, xc] = ang 499 | 500 | xc = round(w * (l[0] + l[2]) / 2) 501 | yc = round(h * (l[1] + l[3]) / 2) 502 | 503 | xc = int(np.clip(xc, 0, w - 1)) 504 | yc = int(np.clip(yc, 0, h - 1)) 505 | 506 | centermap[0, yc, xc] = 255 507 | 508 | line_len, ang = line_len_and_angle(x0, y0, x1, y1) 509 | line_len /= len_divide_v 510 | length_map[0, yc, xc] = line_len 511 | degree_map[0, yc, xc] = ang 512 | 513 | x0d = x0 - xc 514 | y0d = y0 - yc 515 | x1d = x1 - xc 516 | y1d = y1 - yc 517 | 518 | displacement_map[0, yc, xc] = x0d # / 2 519 | displacement_map[1, yc, xc] = y0d # / 2 520 | displacement_map[2, yc, xc] = x1d # / 2 521 | displacement_map[3, yc, xc] = y1d # / 2 522 | 523 | centermap[0, :, :] = cv2.GaussianBlur(centermap[0, :, :], (3,3), 0.0) 524 | centermap = np.array(centermap, dtype=np.float32) / 255.0 525 | b = centermap.max() - centermap.min() 526 | if b !=0: 527 | centermap = ( centermap - centermap.min() ) / b 528 | 529 | tp_mask = np.concatenate((centermap, displacement_map, length_map, degree_map), axis=0) 530 | return tp_mask 531 | 532 | 533 | def get_ext_lines(norm_lines, h=256, w=256, min_len=0.125): 534 | mu_half = min_len / 2 535 | ext_lines = [] 536 | for line in norm_lines: 537 | x0, y0, x1, y1 = line 538 | line_len = np.sqrt((x0 - x1) ** 2 + (y0 - y1) ** 2) 539 | nn = int(line_len / mu_half) - 1 540 | # 
print("nn: ", nn) 541 | if nn <= 1: 542 | ext_lines.append(line) 543 | else: 544 | ## y = k * x + b 545 | if abs(x0 - x1) > abs(y0 - y1): 546 | ## y = k* x + b 547 | k = (y1 - y0) / (x1 - x0) 548 | b = y1 - k * x1 549 | step = (x1 - x0) / (nn + 1) 550 | len_step = 2 * step # (x1 - x0) / (nn - 1) 551 | for ix in range(nn): 552 | ix0 = x0 + ix * step 553 | # ix1 = x0 + (ix + 1) * step 554 | ix1 = ix0 + len_step 555 | iy0 = k * ix0 + b 556 | iy1 = k * ix1 + b 557 | ext_lines.append([ix0, iy0, ix1, iy1]) 558 | 559 | else: 560 | ## x = k* y + b 561 | k = (x1 - x0) / (y1 - y0) 562 | b = x1 - k * y1 563 | step = (y1 - y0) / (nn + 1) 564 | len_step = 2 * step # (y1 - y0) / (nn - 1) 565 | for iy in range(nn): 566 | iy0 = y0 + iy * step 567 | # iy1 = y0 + (iy + 1) * step 568 | iy1 = iy0 + len_step 569 | ix0 = k * iy0 + b 570 | ix1 = k * iy1 + b 571 | ext_lines.append([ix0, iy0, ix1, iy1]) 572 | # print("ext_lines: ", len(ext_lines)) 573 | return ext_lines 574 | 575 | def gen_SOL_map(norm_lines, h =256, w =256, min_len =0.125, with_ext= False): 576 | """ 577 | 1 + 4 + 2 578 | return [7, h, w] 579 | """ 580 | ext_lines = get_ext_lines(norm_lines, h, w, min_len) 581 | return gen_TP_mask2(ext_lines, h, w, with_ext), ext_lines 582 | 583 | 584 | def gen_junction_and_line_mask(norm_lines, h = 256, w = 256): 585 | junction_map = np.zeros((h, w, 1), dtype=np.float32) 586 | line_map = np.zeros((h, w, 1), dtype=np.float32) 587 | 588 | radius = 1 589 | for l in norm_lines: 590 | x0 = int(round(l[0] * w)) 591 | y0 = int(round(l[1] * h)) 592 | x1 = int(round(l[2] * w)) 593 | y1 = int(round(l[3] * h)) 594 | cv2.line(line_map, (x0, y0), (x1, y1), (255, 255, 255), radius) 595 | #cv2.circle(junction_map, (x0, y0), radius, (255, 255, 255), radius) 596 | #cv2.circle(junction_map, (x1, y1), radius, (255, 255, 255), radius) 597 | 598 | ptss = near_area_n(x0, y0, n=3) 599 | ptss.extend( near_area_n(x1, y1, n=3) ) 600 | for p in ptss: 601 | xc = round(p[0]) 602 | yc = round(p[1]) 603 | xc = int(np.clip(xc, 0, w - 1)) 604 | yc = int(np.clip(yc, 0, h - 1)) 605 | junction_map[yc, xc, 0] = 255 606 | 607 | junction_map[:, :, 0] = cv2.GaussianBlur(junction_map[:, :, 0], (3,3), 0.0) 608 | junction_map = np.array(junction_map, dtype=np.float32) / 255.0 609 | b = junction_map.max() - junction_map.min() 610 | if b !=0: 611 | junction_map = ( junction_map - junction_map.min() ) / b 612 | # line map use binary one 613 | line_map = np.array(line_map, dtype=np.float32) / 255.0 614 | # line_map[:, :, 0] = cv2.GaussianBlur(line_map[:, :, 0], (3, 3), 0.0) 615 | # line_map = np.array(line_map, dtype=np.float32) / 255.0 616 | # b = line_map.max() - line_map.min() 617 | # if b !=0: 618 | # line_map = ( line_map - line_map.min() ) / b 619 | 620 | return junction_map, line_map 621 | -------------------------------------------------------------------------------- /mlsd_pytorch/data/wireframe_dset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from mlsd_pytorch.utils.comm import create_dir 3 | import tqdm 4 | import os 5 | import cv2 6 | import torch 7 | import json 8 | import random 9 | import pickle 10 | import numpy as np 11 | from torch.utils.data import Dataset 12 | 13 | from albumentations import ( 14 | RandomBrightnessContrast, 15 | OneOf, 16 | HueSaturationValue, 17 | Compose, 18 | Normalize 19 | ) 20 | 21 | from mlsd_pytorch.data.utils import \ 22 | ( swap_line_pt_maybe, 23 | get_ext_lines, 24 | #gen_TP_mask, 25 | gen_TP_mask2, 26 | gen_SOL_map, 27 | gen_junction_and_line_mask, 28 
| TP_map_to_line_numpy, 29 | cut_line_by_xmin, 30 | cut_line_by_xmax) 31 | 32 | def parse_label_file_info(img_dir, label_file): 33 | infos = [] 34 | contens = json.load(open(label_file, 'r')) 35 | for c in tqdm.tqdm(contens): 36 | w = c['width'] 37 | h = c['height'] 38 | lines = c['lines'] 39 | fn = c['filename'][:-4]+'.jpg' 40 | full_fn = img_dir + fn 41 | assert os.path.exists(full_fn), full_fn 42 | 43 | json_content = { 44 | 'version': '4.5.6', 45 | 'flags': {}, 46 | 'shapes': [], 47 | 'imagePath': fn, 48 | 'imageData': None, 49 | 'imageHeight': h, 50 | 'imageWidth': w, 51 | } 52 | for l in lines: 53 | item = { 54 | "label": "line", 55 | "points": [ 56 | [ 57 | np.clip( np.float(l[0]), 0, w), 58 | np.clip( np.float(l[1]), 0, h) 59 | ], 60 | [ 61 | np.clip( np.float(l[2]), 0, w), 62 | np.clip( np.float(l[3]), 0, h) 63 | ] 64 | ], 65 | "group_id": None, 66 | "shape_type": "line", 67 | "flags": {} 68 | } 69 | json_content['shapes'].append(item) 70 | infos.append(json_content) 71 | return infos 72 | 73 | 74 | class Line_Dataset(Dataset): 75 | def __init__(self, cfg, is_train): 76 | super(Line_Dataset, self).__init__() 77 | 78 | self.cfg = cfg 79 | self.min_len = cfg.decode.len_thresh 80 | self.is_train = is_train 81 | 82 | self.img_dir = cfg.train.img_dir 83 | self.label_fn = cfg.train.label_fn 84 | 85 | if not is_train: 86 | self.img_dir = cfg.val.img_dir 87 | self.label_fn = cfg.val.label_fn 88 | 89 | self.cache_dir = cfg.train.data_cache_dir 90 | self.with_cache = cfg.train.with_cache 91 | 92 | print("==> load label..") 93 | if self.with_cache: 94 | ann_cache_fn = self.cache_dir+"/"+os.path.basename(self.label_fn)+".cache" 95 | if os.path.exists(ann_cache_fn): 96 | print("==> load {} from cache dir..".format(ann_cache_fn)) 97 | self.anns = pickle.load(open(ann_cache_fn, 'rb')) 98 | else: 99 | self.anns = self._load_anns(self.img_dir, self.label_fn) 100 | print("==> cache to {}".format(ann_cache_fn)) 101 | pickle.dump(self.anns, open(ann_cache_fn, 'wb')) 102 | else: 103 | self.anns = self._load_anns(self.img_dir, self.label_fn) 104 | #random.shuffle(self.anns) 105 | print("==> valid samples: ", len(self.anns)) 106 | 107 | 108 | self.input_size = cfg.datasets.input_size 109 | self.train_aug = self._aug_train() 110 | self.test_aug = self._aug_test(input_size=self.input_size) 111 | 112 | self.cache_to_mem = cfg.train.cache_to_mem 113 | self.cache_dict = {} 114 | if self.with_cache: 115 | print("===> cache...") 116 | for ann in tqdm.tqdm(self.anns): 117 | self.load_label(ann, False) 118 | 119 | 120 | def __len__(self): 121 | return len(self.anns) 122 | 123 | def _aug_train(self): 124 | aug = Compose( 125 | [ 126 | OneOf( 127 | [ 128 | HueSaturationValue(hue_shift_limit=10, 129 | sat_shift_limit=10, 130 | val_shift_limit=10, 131 | p=0.5), 132 | RandomBrightnessContrast(brightness_limit=0.2, 133 | contrast_limit=0.2, 134 | p=0.5) 135 | ] 136 | ), 137 | # OneOf( 138 | # [ 139 | # Blur(blur_limit=3, p=0.5), 140 | # GaussianBlur(blur_limit=3, p=0.5), 141 | # MedianBlur(blur_limit=3, p=0.5) 142 | # ] 143 | # ), 144 | 145 | ], 146 | p=1.0) 147 | return aug 148 | 149 | def _aug_test(self, input_size=384): 150 | aug = Compose( 151 | [ 152 | #Resize(height=input_size, 153 | # width=input_size), 154 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 155 | # Normalize(mean=(0.0, 0.0, 0.0), std=(1.0 / 255, 1.0 / 255, 1.0 / 255)) 156 | ], 157 | p=1.0) 158 | return aug 159 | def _line_len_fn(self, l1): 160 | len1 = np.sqrt((l1[2] - l1[0]) ** 2 + (l1[3] - l1[1]) ** 2) 161 | return len1 162 | 
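    # Usage sketch (paths are placeholders; the __main__ block at the bottom of
    # this file is the original example):
    #
    #   cfg = get_cfg_defaults()
    #   cfg.train.img_dir = "/path/to/wireframe_raw/images/"
    #   cfg.train.label_fn = "/path/to/wireframe_raw/train.json"
    #   cfg.train.with_cache = False
    #   dset = Line_Dataset(cfg, is_train=True)
    #   img_norm, img, label, *rest = dset[0]
    #
    # label has shape (16, input_size // 2, input_size // 2): channels 0-6 hold
    # the SOL map, 7-13 the TP map, 14 the junction map, 15 the line map.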
163 | def _load_anns(self, img_dir, label_fn): 164 | infos = parse_label_file_info(img_dir, label_fn) 165 | anns = [] 166 | for c in infos: 167 | 168 | img_full_fn = os.path.join(img_dir, c['imagePath']) 169 | if not os.path.exists(img_full_fn): 170 | print(" not exist!".format(img_full_fn) ) 171 | exit(0) 172 | 173 | lines = [] 174 | for s in c['shapes']: 175 | pt = s["points"] 176 | line = [pt[0][0], pt[0][1], pt[1][0], pt[1][1]] 177 | line = swap_line_pt_maybe(line) 178 | #line = np.array(line, np.float32) 179 | 180 | if not self.is_train: 181 | lines.append(line) 182 | elif self._line_len_fn(line) > self.min_len: 183 | lines.append(line) 184 | 185 | dst_ann = { 186 | 'img_full_fn': img_full_fn, 187 | 'lines': lines, 188 | 'img_w': c['imageWidth'], 189 | 'img_h': c['imageHeight'] 190 | } 191 | anns.append(dst_ann) 192 | 193 | return anns 194 | 195 | def _crop_aug(self, img, ann_origin): 196 | assert img.shape[1] == ann_origin['img_w'] 197 | assert img.shape[0] == ann_origin['img_h'] 198 | img_w = ann_origin['img_w'] 199 | img_h = ann_origin['img_h'] 200 | lines = ann_origin['lines'] 201 | xmin = random.randint(1, int(0.1 * img_w)) 202 | 203 | 204 | #ymin = random.randint(1, 0.1 * img_h) 205 | #ymax = img_h - random.randint(1, 0.1 * img_h) 206 | 207 | 208 | ## xmin 209 | xmin_lines = [] 210 | for line in lines: 211 | flg, line = cut_line_by_xmin(line, xmin) 212 | line[0] -= xmin 213 | line[2] -= xmin 214 | if flg and self._line_len_fn(line) > self.min_len: 215 | xmin_lines.append(line) 216 | lines = xmin_lines 217 | 218 | img = img[:, xmin: , :] 219 | ## xmax 220 | xmax = img.shape[1] - random.randint(1, int(0.1 * img.shape[1])) 221 | img = img[:, :xmax , :].copy() 222 | xmax_lines = [] 223 | for line in lines: 224 | flg, line = cut_line_by_xmax(line, xmax) 225 | if flg and self._line_len_fn(line) > self.min_len: 226 | xmax_lines.append(line) 227 | lines = xmax_lines 228 | 229 | ann_origin['lines'] = lines 230 | ann_origin['img_w'] = img.shape[1] 231 | ann_origin['img_h'] = img.shape[0] 232 | 233 | return img, ann_origin 234 | 235 | 236 | def _geo_aug(self, img, ann_origin): 237 | do_aug = False 238 | 239 | lines = ann_origin['lines'].copy() 240 | if random.random() < 0.5: 241 | do_aug = True 242 | flipped_lines = [] 243 | img = np.fliplr(img) 244 | for l in lines: 245 | flipped_lines.append( 246 | swap_line_pt_maybe([ann_origin['img_w'] - l[0], 247 | l[1], ann_origin['img_w'] - l[2], l[3]])) 248 | ann_origin['lines'] = flipped_lines 249 | 250 | lines = ann_origin['lines'].copy() 251 | if random.random() < 0.5: 252 | do_aug = True 253 | flipped_lines = [] 254 | img = np.flipud(img) 255 | for l in lines: 256 | flipped_lines.append( 257 | swap_line_pt_maybe([l[0], 258 | ann_origin['img_h'] - l[1], 259 | l[2], 260 | ann_origin['img_h'] - l[3]])) 261 | ann_origin['lines'] = flipped_lines 262 | 263 | lines = ann_origin['lines'].copy() 264 | if random.random() < 0.5: 265 | do_aug = True 266 | r_lines = [] 267 | img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) 268 | for l in lines: 269 | r_lines.append( 270 | swap_line_pt_maybe([ann_origin['img_h'] - l[1], 271 | l[0], ann_origin['img_h'] - l[3], l[2]])) 272 | ann_origin['lines'] = r_lines 273 | ann_origin['img_w'] = img.shape[1] 274 | ann_origin['img_h'] = img.shape[0] 275 | 276 | if random.random() < 0.5: 277 | do_aug = True 278 | img, ann_origin = self._crop_aug(img, ann_origin) 279 | 280 | ann_origin['img_w'] = img.shape[1] 281 | ann_origin['img_h'] = img.shape[0] 282 | 283 | return do_aug, img, ann_origin 284 | 285 | 286 | def 
load_label(self, ann, do_aug): 287 | norm_lines = [] 288 | for l in ann['lines']: 289 | 290 | ll = [ 291 | np.clip(l[0] / ann['img_w'], 0, 1), 292 | np.clip(l[1] / ann['img_h'], 0, 1), 293 | np.clip(l[2] / ann['img_w'], 0, 1), 294 | np.clip(l[3] / ann['img_h'], 0, 1) 295 | ] 296 | x0, y0, x1, y1 = 256 * ll[0], 256 * ll[1], 256 * ll[2], 256 * ll[3] 297 | if x0 == x1 and y0 == y1: 298 | print('fatal err!') 299 | print(ann['img_w'], ann['img_h']) 300 | print(ll) 301 | print(l) 302 | print(ann) 303 | exit(0) 304 | 305 | norm_lines.append(ll) 306 | 307 | ann['norm_lines'] = norm_lines 308 | 309 | label_cache_path = os.path.basename(ann['img_full_fn'])[:-4] + '.npy' 310 | label_cache_path = self.cache_dir + '/' + label_cache_path 311 | 312 | can_load = self.with_cache and not do_aug 313 | 314 | if can_load and self.cache_to_mem and label_cache_path in self.cache_dict.keys(): 315 | label = self.cache_dict[label_cache_path] 316 | 317 | elif can_load and os.path.exists(label_cache_path): 318 | label = np.load(label_cache_path) 319 | if self.cache_to_mem: 320 | self.cache_dict[label_cache_path] = label 321 | else: 322 | 323 | tp_mask = gen_TP_mask2(ann['norm_lines'], self.input_size // 2, self.input_size // 2, 324 | with_ext=self.cfg.datasets.with_centermap_extend) 325 | sol_mask, _ = gen_SOL_map(ann['norm_lines'], self.input_size // 2, self.input_size // 2, 326 | with_ext=False) 327 | 328 | junction_map, line_map = gen_junction_and_line_mask(ann['norm_lines'], 329 | self.input_size // 2, self.input_size // 2) 330 | 331 | label = np.zeros((2 * 7 + 2, self.input_size // 2, self.input_size // 2), dtype=np.float32) 332 | label[0:7, :, :] = sol_mask 333 | label[7:14, :, :] = tp_mask 334 | label[14, :, :] = junction_map[0] 335 | label[15, :, :] = line_map[0] 336 | if not do_aug and self.with_cache: 337 | # 338 | if self.cache_to_mem: 339 | #print("cache to mem: {} [ total: {} ]".format(label_cache_path,len(self.cache_dict) )) 340 | self.cache_dict[label_cache_path] = label 341 | else: 342 | #print("cache to cache dir:", label_cache_path) 343 | np.save(label_cache_path, label) 344 | 345 | return label 346 | 347 | def __getitem__(self, index): 348 | 349 | ann = self.anns[index].copy() 350 | img = cv2.imread(ann['img_full_fn']) 351 | 352 | do_aug = False 353 | if self.is_train and random.random() < 0.5: 354 | do_aug, img, ann = self._geo_aug(img, ann) 355 | 356 | img = cv2.resize(img, (self.input_size, self.input_size)) 357 | 358 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 359 | # if not self.is_train: 360 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 361 | # elif random.random() > 0.5: 362 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 363 | 364 | label = self.load_label(ann, do_aug) 365 | ext_lines = get_ext_lines(ann['norm_lines'], self.input_size // 2, self.input_size // 2) 366 | 367 | norm_lines = ann['norm_lines'] 368 | norm_lines_512_list = [] 369 | for l in norm_lines: 370 | norm_lines_512_list.append([ 371 | l[0] * 512, 372 | l[1] * 512, 373 | l[2] * 512, 374 | l[3] * 512, 375 | ]) 376 | 377 | if self.is_train: 378 | img = self.train_aug(image=img)['image'] 379 | #img_norm = (img / 127.5) - 1.0 380 | img_norm = self.test_aug(image=img)['image'] 381 | 382 | 383 | norm_lines_512_tensor = torch.from_numpy(np.array(norm_lines_512_list, np.float32)) 384 | sol_lines_512_tensor = torch.from_numpy(np.array(ext_lines, np.float32) * 512) 385 | 386 | return img_norm, img, label, \ 387 | norm_lines_512_list, \ 388 | norm_lines_512_tensor, \ 389 | sol_lines_512_tensor, \ 390 | ann['img_full_fn'] 391 | 392 | 
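# Rough DataLoader wiring sketch (train.py is the authoritative entry point).
# LineDataset_collate_fn below packs a list of per-sample tuples into the
# batch dict consumed by the learner:
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(Line_Dataset(cfg, is_train=True),
#                       batch_size=cfg.train.batch_size,
#                       shuffle=True,
#                       collate_fn=LineDataset_collate_fn)
#   batch = next(iter(loader))
#   # batch["xs"]: (B, 3, input_size, input_size) image tensor
#   # batch["ys"]: (B, 16, input_size // 2, input_size // 2) label tensor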
393 | def LineDataset_collate_fn(batch): 394 | batch_size = len(batch) 395 | h, w, c = batch[0][0].shape 396 | images = np.zeros((batch_size, 3, h, w), dtype=np.float32) 397 | labels = np.zeros((batch_size, 16, h // 2, w // 2), dtype=np.float32) 398 | img_fns = [] 399 | img_origin_list = [] 400 | norm_lines_512_all = [] 401 | norm_lines_512_all_tensor_list = [] 402 | sol_lines_512_all_tensor_list = [] 403 | 404 | for inx in range(batch_size): 405 | im, img_origin, label_mask, \ 406 | norm_lines_512, norm_lines_512_tensor, \ 407 | sol_lines_512, img_fn = batch[inx] 408 | 409 | images[inx] = im.transpose((2, 0, 1)) 410 | labels[inx] = label_mask 411 | img_origin_list.append(img_origin) 412 | img_fns.append(img_fn) 413 | norm_lines_512_all.append(norm_lines_512) 414 | norm_lines_512_all_tensor_list.append(norm_lines_512_tensor) 415 | sol_lines_512_all_tensor_list.append(sol_lines_512) 416 | 417 | images = torch.from_numpy(images) 418 | labels = torch.from_numpy(labels) 419 | 420 | return { 421 | "xs": images, 422 | "ys": labels, 423 | "img_fns": img_fns, 424 | "origin_imgs": img_origin_list, 425 | "gt_lines_512": norm_lines_512_all, 426 | "gt_lines_tensor_512_list": norm_lines_512_all_tensor_list, 427 | "sol_lines_512_all_tensor_list": sol_lines_512_all_tensor_list 428 | } 429 | 430 | 431 | 432 | if __name__ == '__main__': 433 | from mlsd_pytorch.cfg.default import get_cfg_defaults 434 | 435 | cfg = get_cfg_defaults() 436 | 437 | root_dir = "/home/lhw/m2_disk/data/czcv_2021/wireframe_raw/" 438 | cfg.train.img_dir = root_dir+ "/images/" 439 | cfg.train.label_fn = root_dir+ "/valid.json" 440 | cfg.train.batch_size = 1 441 | cfg.train.data_cache_dir = "/home/lhw/m2_disk/data/czcv_2021/wireframe_cache/" 442 | cfg.train.with_cache = True 443 | cfg.datasets.with_centermap_extend = False 444 | 445 | dset = Line_Dataset(cfg, True) 446 | for img_norm, img, label, norm_lines_512, \ 447 | norm_lines_512_tensor, sol_lines_512, fn in dset: 448 | #continue 449 | print(img.shape) 450 | print(label.shape) 451 | print(norm_lines_512_tensor.shape) 452 | centermap = label[7] 453 | centermap = centermap[np.newaxis, :] 454 | displacement_map = label[8:12] 455 | reverse_lines = TP_map_to_line_numpy(centermap, displacement_map) 456 | print("reverse_lines: ", len(reverse_lines)) 457 | for i, l in enumerate(reverse_lines): 458 | #color = (random.randint(100, 255), random.randint(100, 255), 255) 459 | color = (0, 0, 255) 460 | x0, y0, x1, y1 = l 461 | cv2.line(img, (int(round(x0)), int(round(y0))), 462 | (int(round(x1)), int(round(y1))), color, 1, 16) 463 | 464 | cv2.imwrite(root_dir+ "/gui/gui_lines.jpg", img) 465 | cv2.imwrite(root_dir+ "/gui/gui_centermap.jpg", centermap[0] * 255) 466 | 467 | displacement_map = displacement_map[0] 468 | displacement_map = np.where(displacement_map != 0, 255, 0) 469 | cv2.imwrite(root_dir+ "/gui/gui_dis0.jpg", displacement_map) 470 | 471 | 472 | len_map = np.where(label[12] != 0, 255, 0) 473 | cv2.imwrite(root_dir+ "/gui/gui_lenmap.jpg", len_map) 474 | 475 | centermap = label[0] 476 | centermap = centermap[np.newaxis, :] 477 | displacement_map = label[1:5] 478 | reverse_lines = TP_map_to_line_numpy(centermap, displacement_map) 479 | print("SOL reverse_lines: ", len(reverse_lines)) 480 | for i, l in enumerate(reverse_lines): 481 | color = (random.randint(100, 255), random.randint(100, 255), 255) 482 | x0, y0, x1, y1 = l 483 | cv2.line(img, (int(round(x0)), int(round(y0))), 484 | (int(round(x1)), int(round(y1))), color, 1, 16) 485 | 486 | cv2.imwrite(root_dir + 
"/gui/gui_SOL_lines.jpg", img) 487 | cv2.imwrite(root_dir + "/gui/gui_SOL_centermap.jpg", centermap[0] * 255) 488 | 489 | displacement_map = displacement_map[0] 490 | displacement_map = np.where(displacement_map != 0, 255, 0) 491 | cv2.imwrite(root_dir + "/gui/gui_SOL_dis0.jpg", displacement_map) 492 | 493 | cv2.imwrite(root_dir + "/gui/gui_line_seg.jpg", label[0][15] * 255) 494 | cv2.imwrite(root_dir + "/gui/gui_junc_seg.jpg", label[0][14] * 255) 495 | 496 | break -------------------------------------------------------------------------------- /mlsd_pytorch/learner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import tqdm 4 | import numpy as np 5 | from torch.nn import functional as F 6 | import torch.nn as nn 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.optim.optimizer import Optimizer 9 | 10 | from mlsd_pytorch.utils.logger import TxtLogger 11 | from mlsd_pytorch.utils.meter import AverageMeter 12 | 13 | from mlsd_pytorch.utils.decode import deccode_lines_TP 14 | from mlsd_pytorch.data.utils import deccode_lines 15 | from mlsd_pytorch.loss import LineSegmentLoss 16 | from mlsd_pytorch.metric import F1_score_128, TPFP, msTPFP, AP 17 | 18 | # from apex.fp16_utils import * 19 | # from apex import amp, optimizers 20 | 21 | class Simple_MLSD_Learner(): 22 | def __init__(self, 23 | cfg, 24 | model : torch.nn.Module, 25 | optimizer: Optimizer, 26 | scheduler, 27 | logger : TxtLogger, 28 | save_dir : str, 29 | log_steps = 100, 30 | device_ids = [0,1], 31 | gradient_accum_steps = 1, 32 | max_grad_norm = 100.0, 33 | batch_to_model_inputs_fn = None, 34 | early_stop_n = 4, 35 | ): 36 | self.cfg = cfg 37 | self.model = model 38 | self.optimizer = optimizer 39 | self.scheduler = scheduler 40 | self.save_dir = save_dir 41 | self.log_steps = log_steps 42 | self.logger = logger 43 | self.device_ids = device_ids 44 | self.gradient_accum_steps = gradient_accum_steps 45 | self.max_grad_norm = max_grad_norm 46 | self.batch_to_model_inputs_fn = batch_to_model_inputs_fn 47 | self.early_stop_n = early_stop_n 48 | self.global_step = 0 49 | 50 | self.input_size = self.cfg.datasets.input_size 51 | self.loss_fn = LineSegmentLoss(cfg) 52 | self.epo = 0 53 | 54 | 55 | def step(self,step_n, batch_data : dict): 56 | imgs = batch_data["xs"].cuda() 57 | label = batch_data["ys"].cuda() 58 | outputs = self.model(imgs) 59 | loss_dict = self.loss_fn(outputs, label, 60 | batch_data["gt_lines_tensor_512_list"], 61 | batch_data["sol_lines_512_all_tensor_list"]) 62 | loss = loss_dict['loss'] 63 | if self.gradient_accum_steps > 1: 64 | loss = loss / self.gradient_accum_steps 65 | 66 | #with amp.scale_loss(loss, self.optimizer) as scaled_loss: 67 | # scaled_loss.backward() 68 | loss.backward() 69 | 70 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 71 | if (step_n + 1) % self.gradient_accum_steps == 0: 72 | self.optimizer.step() 73 | self.scheduler.step() # Update learning rate schedule 74 | self.model.zero_grad() 75 | self.global_step += 1 76 | return loss, loss_dict 77 | 78 | 79 | def val(self, model, val_dataloader : DataLoader): 80 | thresh = self.cfg.decode.score_thresh 81 | topk = self.cfg.decode.top_k 82 | min_len = self.cfg.decode.len_thresh 83 | 84 | model = model.eval() 85 | sap_thresh = 10 86 | data_iter = tqdm.tqdm(val_dataloader) 87 | f_scores = [] 88 | recalls = [] 89 | precisions = [] 90 | 91 | tp_list, fp_list, scores_list = [], [], [] 92 | n_gt = 0 93 | 94 | for batch_data in data_iter: 
95 | imgs = batch_data["xs"].cuda() 96 | label = batch_data["ys"].cuda() 97 | batch_outputs = model(imgs) 98 | 99 | # keep TP mask 100 | label = label[:, 7:, :, :] 101 | batch_outputs = batch_outputs[:, 7:, :, :] 102 | 103 | # batch_outputs[:, 0, :, :] = label[:, 0, :, :] 104 | # batch_outputs[:, 1:5, :,: ] = label[:, 1:5, :,: ] 105 | 106 | 107 | for outputs, gt_lines_512 in zip(batch_outputs, batch_data["gt_lines_512"]): 108 | gt_lines_512 = np.array(gt_lines_512, np.float32) 109 | 110 | outputs = outputs.unsqueeze(0) 111 | #pred_lines,scores = deccode_lines(outputs, thresh, min_len, topk, 3) 112 | 113 | center_ptss, pred_lines, _, scores = \ 114 | deccode_lines_TP(outputs, thresh, min_len, topk, 3) 115 | 116 | #print('pred_lines: ', pred_lines.shape) 117 | #print('gt_lines_512: ', gt_lines_512.shape) 118 | pred_lines =pred_lines.detach().cpu().numpy() 119 | scores = scores.detach().cpu().numpy() 120 | 121 | pred_lines_128 = 128 * pred_lines / (self.input_size / 2) 122 | 123 | gt_lines_128 = gt_lines_512 / 4 124 | fscore, recall, precision = F1_score_128(pred_lines_128.tolist(),gt_lines_128.tolist(), 125 | thickness=3) 126 | f_scores.append(fscore) 127 | recalls.append(recall) 128 | precisions.append(precision) 129 | 130 | tp, fp = msTPFP(pred_lines_128, gt_lines_128, sap_thresh) 131 | 132 | n_gt += gt_lines_128.shape[0] 133 | tp_list.append(tp) 134 | fp_list.append(fp) 135 | scores_list.append(scores) 136 | 137 | 138 | 139 | f_score = np.array(f_scores, np.float32).mean() 140 | recall = np.array(recalls, np.float32).mean() 141 | precision = np.array(precisions, np.float32).mean() 142 | 143 | 144 | tp_list = np.concatenate(tp_list) 145 | fp_list = np.concatenate(fp_list) 146 | scores_list = np.concatenate(scores_list) 147 | idx = np.argsort(scores_list)[::-1] 148 | tp = np.cumsum(tp_list[idx]) / n_gt 149 | fp = np.cumsum(fp_list[idx]) / n_gt 150 | rcs = tp 151 | pcs = tp / np.maximum(tp + fp, 1e-9) 152 | sAP = AP(tp, fp) * 100 153 | self.logger.write("==>step: {}, f_score: {}, recall: {}, precision:{}, sAP10: {}\n ". 
154 | format(self.global_step, f_score, recall, precision, sAP)) 155 | 156 | 157 | return { 158 | 'fscore': f_score, 159 | 'recall': recall, 160 | 'precision':precision, 161 | 'sAP10': sAP 162 | } 163 | 164 | def train(self, train_dataloader : DataLoader, 165 | val_dataloader : DataLoader, 166 | epoches = 100): 167 | best_score = 0 168 | early_n = 0 169 | # self.model, self.optimizer = amp.initialize(self.model, self.optimizer, 170 | # opt_level="O1", 171 | # loss_scale=1.0 172 | # ) 173 | for self.epo in range(epoches): 174 | step_n = 0 175 | train_avg_loss = AverageMeter() 176 | train_avg_center_loss = AverageMeter() 177 | train_avg_replacement_loss = AverageMeter() 178 | train_avg_line_seg_loss = AverageMeter() 179 | train_avg_junc_seg_loss = AverageMeter() 180 | 181 | train_avg_match_loss = AverageMeter() 182 | train_avg_match_rario = AverageMeter() 183 | train_avg_t_loss = AverageMeter() 184 | 185 | data_iter = tqdm.tqdm(train_dataloader) 186 | for batch in data_iter: 187 | self.model.train() 188 | train_loss,loss_dict = self.step(step_n, batch) 189 | train_avg_loss.update(train_loss.item(),1) 190 | 191 | train_avg_center_loss.update(loss_dict['center_loss'].item() ,1) 192 | train_avg_replacement_loss.update(loss_dict['displacement_loss'].item(), 1) 193 | train_avg_line_seg_loss.update(loss_dict['line_seg_loss'].item(), 1) 194 | train_avg_junc_seg_loss.update(loss_dict['junc_seg_loss'].item(), 1) 195 | train_avg_match_loss.update(float(loss_dict['match_loss']), 1) 196 | train_avg_match_rario.update(loss_dict['match_ratio'], 1) 197 | 198 | status = '[{0}] lr= {1:.6f} loss= {2:.3f} avg = {3:.3f},c: {4:.3f}, d: {5:.3f}, l: {6:.3f}, ' \ 199 | 'junc:{7:.3f},m:{8:.3f},m_r:{9:.2f} '.format( 200 | self.epo + 1, 201 | self.scheduler.get_lr()[0], 202 | train_loss.item(), 203 | train_avg_loss.avg, 204 | train_avg_center_loss.avg, 205 | train_avg_replacement_loss.avg, 206 | train_avg_line_seg_loss.avg, 207 | train_avg_junc_seg_loss.avg, 208 | train_avg_match_loss.avg, 209 | train_avg_match_rario.avg 210 | ) 211 | 212 | #if step_n%self.log_steps ==0: 213 | # print(status) 214 | data_iter.set_description(status) 215 | step_n +=1 216 | 217 | ##self.scheduler.step() ## we update every step instead 218 | if self.epo > self.cfg.val.val_after_epoch: 219 | ## val 220 | m = self.val(self.model, val_dataloader) 221 | fscore = m['sAP10'] 222 | if best_score < fscore: 223 | early_n = 0 224 | best_score = fscore 225 | model_path = os.path.join(self.save_dir, 'best.pth') 226 | torch.save(self.model.state_dict(), model_path) 227 | else: 228 | early_n += 1 229 | self.logger.write("epo: {}, steps: {} ,sAP10 : {:.4f} , best sAP10: {:.4f}". 
\ 230 | format(self.epo, self.global_step, fscore, best_score)) 231 | self.logger.write(str(m)) 232 | self.logger.write("=="*50) 233 | 234 | if early_n > self.early_stop_n: 235 | print('early stopped!') 236 | return best_score 237 | model_path = os.path.join(self.save_dir, 'latest.pth') 238 | torch.save(self.model.state_dict(), model_path) 239 | return best_score 240 | -------------------------------------------------------------------------------- /mlsd_pytorch/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlsd_multi_loss import LineSegmentLoss -------------------------------------------------------------------------------- /mlsd_pytorch/loss/_func.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | __all__ = [ 7 | "focal_neg_loss_with_logits", 8 | "weighted_bce_with_logits", 9 | ] 10 | 11 | 12 | def focal_neg_loss_with_logits(preds, gt, alpha=2, belta=4): 13 | """ 14 | borrow from https://github.com/princeton-vl/CornerNet 15 | """ 16 | 17 | preds = torch.sigmoid(preds) 18 | 19 | pos_inds = gt.eq(1) 20 | neg_inds = gt.lt(1) 21 | 22 | # pos_inds = gt.gt(0) 23 | # neg_inds = gt.eq(0) 24 | 25 | neg_weights = torch.pow(1 - gt[neg_inds], belta) 26 | 27 | loss = 0 28 | pos_pred = preds[pos_inds] 29 | neg_pred = preds[neg_inds] 30 | 31 | pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, alpha) 32 | neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, alpha) * neg_weights 33 | 34 | num_pos = pos_inds.float().sum() 35 | pos_loss = pos_loss.sum() 36 | neg_loss = neg_loss.sum() 37 | 38 | if pos_pred.nelement() == 0: 39 | loss = loss - neg_loss 40 | else: 41 | loss = loss - (pos_loss + neg_loss) / num_pos 42 | 43 | return loss 44 | 45 | 46 | # def weighted_bce_with_logits(out, gt, pos_w=1.0, neg_w=30.0): 47 | # pos_mask = torch.where(gt == 1, torch.ones_like(gt), torch.zeros_like(gt)) 48 | # neg_mask = torch.ones_like(pos_mask) - pos_mask 49 | 50 | # losses = F.binary_cross_entropy_with_logits(out, gt, reduction='none') 51 | 52 | # loss_neg = (losses * neg_mask).sum() / (torch.sum(neg_mask)) 53 | # loss_v = loss_neg * neg_w 54 | 55 | # pos_sum = torch.sum(pos_mask) 56 | # if pos_sum != 0: 57 | # loss_pos = (losses * pos_mask).sum() / pos_sum 58 | # loss_v += (loss_pos * pos_w) 59 | # return loss_v 60 | 61 | 62 | def weighted_bce_with_logits(out, gt, pos_w=1.0, neg_w=30.0): 63 | pos_mask = torch.where(gt != 0.0, torch.ones_like(gt), torch.zeros_like(gt)) 64 | #pos_mask = torch.where(gt == 1, torch.ones_like(gt), torch.zeros_like(gt)) 65 | neg_mask = torch.ones_like(pos_mask) - pos_mask 66 | loss = F.binary_cross_entropy_with_logits(out, gt, reduction='none') 67 | loss_pos = (loss * pos_mask).sum() / ( torch.sum(pos_mask) + 1e-5) 68 | loss_neg = (loss * neg_mask).sum() / ( torch.sum(neg_mask) + 1e-5) 69 | loss = loss_pos * pos_w + loss_neg * neg_w 70 | return loss 71 | 72 | -------------------------------------------------------------------------------- /mlsd_pytorch/loss/mlsd_multi_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from ._func import focal_neg_loss_with_logits#, weighted_bce_with_logits 8 | from mlsd_pytorch.utils.decode import deccode_lines_TP 9 | 10 | __all__ = [ 11 | "LineSegmentLoss", 12 | ] 13 | 14 | def 
weighted_bce_with_logits(out, gt, pos_w=1.0, neg_w=30.0): 15 | pos_mask = torch.where(gt != 0, torch.ones_like(gt), torch.zeros_like(gt)) 16 | neg_mask = torch.ones_like(pos_mask) - pos_mask 17 | 18 | loss = F.binary_cross_entropy_with_logits(out, gt, reduction='none') 19 | loss_pos = (loss * pos_mask).sum() / torch.sum(pos_mask) 20 | loss_neg = (loss * neg_mask).sum() / torch.sum(neg_mask) 21 | 22 | loss = loss_pos * pos_w + loss_neg * neg_w 23 | return loss 24 | 25 | 26 | # def displacement_loss_func(pred_dis, gt_dis): 27 | # # only consider non zero part 28 | # pos_mask = torch.where(gt_dis[:, 0, :, :].unsqueeze(1) != 0, torch.ones_like(gt_dis), torch.zeros_like(gt_dis)) 29 | # pos_mask_sum = pos_mask.sum() 30 | 31 | # pred_dis = pred_dis * pos_mask 32 | # gt_dis = gt_dis * pos_mask 33 | 34 | # displacement_loss1 = F.smooth_l1_loss(pred_dis, gt_dis, reduction='none').mean(axis=[1]) 35 | 36 | # # swap pt 37 | # pred_dis2 = torch.cat((pred_dis[:, 2:, :, :], pred_dis[:, :2, :, :]), dim=1) 38 | # displacement_loss2 = F.smooth_l1_loss(pred_dis2, gt_dis, reduction='none').mean(axis=[1]) 39 | # displacement_loss = displacement_loss1.min(displacement_loss2) 40 | 41 | # displacement_loss = displacement_loss.sum() / pos_mask_sum 42 | 43 | # return displacement_loss 44 | 45 | 46 | # def len_and_angle_loss_func(pred_len, pred_angle, gt_len, gt_angle): 47 | # # only consider non zero part 48 | # pos_mask = torch.where(gt_len != 0, torch.ones_like(gt_len), torch.zeros_like(gt_len)) 49 | # pos_mask_sum = pos_mask.sum() 50 | 51 | # len_loss = F.smooth_l1_loss(pred_len, gt_len, reduction='none') 52 | # len_loss = len_loss * pos_mask 53 | # len_loss = len_loss.sum() / pos_mask_sum 54 | 55 | # angle_loss = F.smooth_l1_loss(pred_angle, gt_angle, reduction='none') 56 | # angle_loss = angle_loss * pos_mask 57 | # angle_loss = angle_loss.sum() / pos_mask_sum 58 | 59 | # return len_loss, angle_loss 60 | 61 | def len_and_angle_loss_func(pred_len, pred_angle, gt_len, gt_angle): 62 | pred_len = torch.sigmoid(pred_len) 63 | pred_angle = torch.sigmoid(pred_angle) 64 | # only consider non zero part 65 | pos_mask = torch.where(gt_len != 0, torch.ones_like(gt_len), torch.zeros_like(gt_len)) 66 | pos_mask_sum = pos_mask.sum() 67 | 68 | len_loss = F.smooth_l1_loss(pred_len, gt_len, reduction='none') 69 | len_loss = len_loss * pos_mask 70 | len_loss = len_loss.sum() / pos_mask_sum 71 | 72 | angle_loss = F.smooth_l1_loss(pred_angle, gt_angle, reduction='none') 73 | angle_loss = angle_loss * pos_mask 74 | angle_loss = angle_loss.sum() / pos_mask_sum 75 | 76 | return len_loss, angle_loss 77 | 78 | 79 | def displacement_loss_func(pred_dis, gt_dis, gt_center_mask=None): 80 | # only consider non zero part 81 | x0 = gt_dis[:, 0, :, :] 82 | y0 = gt_dis[:, 1, :, :] 83 | x1 = gt_dis[:, 2, :, :] 84 | y1 = gt_dis[:, 3, :, :] 85 | 86 | # if gt_center_mask is not None: 87 | # pos_mask = torch.where(gt_center_mask > 0.9, torch.ones_like(x0), torch.zeros_like(x0)) 88 | # else: 89 | # pos_v = x0.abs() + y0.abs() + x1.abs() + y1.abs() 90 | # pos_mask = torch.where(pos_v != 0, torch.ones_like(x0), torch.zeros_like(x0)) 91 | pos_v = x0.abs() + y0.abs() + x1.abs() + y1.abs() 92 | pos_mask = torch.where(pos_v != 0, torch.ones_like(x0), torch.zeros_like(x0)) 93 | pos_mask_sum = pos_mask.sum() 94 | 95 | pos_mask = pos_mask.unsqueeze(1) 96 | 97 | pred_dis = pred_dis * pos_mask 98 | gt_dis = gt_dis * pos_mask 99 | 100 | displacement_loss1 = F.smooth_l1_loss(pred_dis, gt_dis, reduction='none').sum(axis=[1]) 101 | 102 | # swap pt 103 | pred_dis2 = 
torch.cat((pred_dis[:, 2:, :, :], pred_dis[:, :2, :, :]), dim=1) 104 | displacement_loss2 = F.smooth_l1_loss(pred_dis2, gt_dis, reduction='none').sum(axis=[1]) 105 | displacement_loss = displacement_loss1.min(displacement_loss2) 106 | #if gt_center_mask is not None: 107 | # displacement_loss = displacement_loss * gt_center_mask 108 | 109 | displacement_loss = displacement_loss.sum() / pos_mask_sum 110 | 111 | return displacement_loss 112 | 113 | 114 | class LineSegmentLoss(nn.Module): 115 | def __init__(self, cfg): 116 | super(LineSegmentLoss, self).__init__() 117 | self.input_size = cfg.datasets.input_size 118 | 119 | self.with_SOL_loss = cfg.loss.with_sol_loss 120 | self.with_match_loss = cfg.loss.with_match_loss 121 | self.with_focal_loss = cfg.loss.with_focal_loss 122 | 123 | # 0: only in tp 124 | # 1: in tp and sol 125 | # 2: in tp, sol, junc 126 | # 3: in tp, sol, junc, line 127 | self.focal_loss_level = cfg.loss.focal_loss_level 128 | self.match_sap_thresh = cfg.loss.match_sap_thresh # 5 129 | 130 | self.decode_score_thresh = cfg.decode.score_thresh # 0.01 131 | self.decode_len_thresh = cfg.decode.len_thresh # 10 132 | self.decode_top_k = cfg.decode.top_k # 200 133 | 134 | self.loss_w_dict = { 135 | 'tp_center_loss': 10.0, 136 | 'tp_displacement_loss': 1.0, 137 | 'tp_len_loss': 1.0, 138 | 'tp_angle_loss': 1.0, 139 | 'tp_match_loss': 1.0, 140 | 'tp_centerness_loss': 1.0, # current not support 141 | 142 | 'sol_center_loss': 1.0, 143 | 'sol_displacement_loss': 1.0, 144 | 'sol_len_loss': 1.0, 145 | 'sol_angle_loss': 1.0, 146 | 'sol_match_loss': 1.0, 147 | 'sol_centerness_loss': 1.0, # current not support 148 | 149 | 'line_seg_loss': 1.0, 150 | 'junc_seg_loss': 1.0 151 | } 152 | 153 | 154 | if len(cfg.loss.loss_weight_dict_list) > 0: 155 | self.loss_w_dict.update(cfg.loss.loss_weight_dict_list[0]) 156 | 157 | 158 | print("===> loss weight: ", self.loss_w_dict) 159 | 160 | def _m_gt_matched_n(self, p_lines, gt_lines, thresh): 161 | gt_lines = gt_lines.cuda() 162 | distance1 = torch.cdist(gt_lines[:, :2],p_lines[:, :2], p=2) 163 | distance2 = torch.cdist(gt_lines[:, 2:], p_lines[:, 2:], p=2) 164 | 165 | distance = distance1 + distance2 166 | near_inx = torch.argsort(distance, 1)[:, 0] # most neared pred one 167 | 168 | matched_pred_lines = p_lines[near_inx] 169 | 170 | distance1 = F.pairwise_distance(gt_lines[:, :2], matched_pred_lines[:, :2], p=2) 171 | distance2 = F.pairwise_distance(gt_lines[:, 2:], matched_pred_lines[:, 2:], p=2) 172 | 173 | # print("distance1: ",distance1.shape) 174 | 175 | inx = torch.where((distance1 < thresh) & (distance2 < thresh))[0] 176 | return len(inx) 177 | 178 | def _m_match_loss_fn(self, p_lines, p_centers, p_scores, gt_lines, thresh): 179 | gt_lines = gt_lines.cuda() 180 | distance1 = torch.cdist(p_lines[:, :2], gt_lines[:, :2], p=2) 181 | distance2 = torch.cdist(p_lines[:, 2:], gt_lines[:, 2:], p=2) 182 | 183 | distance = distance1 + distance2 184 | near_inx = torch.argsort(distance, 1)[:, 0] # most neared one 185 | 186 | matched_gt_lines = gt_lines[near_inx] 187 | 188 | distance1 = F.pairwise_distance(matched_gt_lines[:, :2], p_lines[:, :2], p=2) 189 | distance2 = F.pairwise_distance(matched_gt_lines[:, 2:], p_lines[:, 2:], p=2) 190 | 191 | # print("distance1: ",distance1.shape) 192 | 193 | inx = torch.where((distance1 < thresh) & (distance2 < thresh))[0] 194 | 195 | # center_distance = F.pairwise_distance( (matched_gt_lines[:, :2] + matched_gt_lines[:, 2:])/2, 196 | # (p_lines[:, :2] + p_lines[:, 2:])/2, p=2) 197 | # unmached_inx = torch.where( 
(distance1 > 3*thresh) & 198 | # (distance2 > 3*thresh) & 199 | # (center_distance > 2 * thresh) )[0] 200 | 201 | # print(inx) 202 | # print(unmached_inx) 203 | 204 | match_n = len(inx) 205 | # n_gt = gt_lines.shape[0] 206 | loss = 4 * thresh 207 | #loss = 0.0 208 | # match_ratio = inx[0].shape[0] / n_gt 209 | # match_ratio = np.clip(match_ratio, 0, 1.0) 210 | if match_n > 0: 211 | mathed_gt_lines = matched_gt_lines[inx] 212 | mathed_pred_lines = p_lines[inx] 213 | mathed_pred_centers = p_centers[inx] 214 | #mathed_pred_scores = p_scores[inx] 215 | 216 | endpoint_loss = F.l1_loss(mathed_pred_lines, mathed_gt_lines, reduction='mean')# * 2 217 | 218 | gt_centers = (mathed_gt_lines[:, :2] + mathed_gt_lines[:, 2:]) / 2 219 | # print("gt_centers: ", gt_centers.shape) 220 | # print("mathed_pred_centers: ", mathed_pred_centers.shape) 221 | center_dis_loss = F.l1_loss(mathed_pred_centers, gt_centers, reduction='mean') 222 | 223 | # center_dis_loss = torch.where(center_dis_loss< 1.0, torch.zeros_like(center_dis_loss), center_dis_loss - 1.0) 224 | # endpoint_loss = torch.where(endpoint_loss< 1.0, torch.zeros_like(endpoint_loss), endpoint_loss - 1.0) 225 | #center_dis_loss = center_dis_loss.mean() 226 | #endpoint_loss = endpoint_loss.mean() 227 | # print("mean score: ", mathed_pred_scores.mean()) 228 | 229 | # larger is better 230 | #mean_score = mathed_pred_scores.mean() 231 | #print(mean_score) 232 | loss = 1.0*endpoint_loss + 1.0 * center_dis_loss# - 1.0* mean_score 233 | 234 | # if len(unmached_inx) >0: 235 | # unmathed_pred_scores = p_scores[unmached_inx] 236 | # unmatch_mean_score = unmathed_pred_scores.mean() 237 | # #print(unmatch_mean_score) 238 | # # small is better 239 | # loss = loss + 2.0 * unmatch_mean_score 240 | 241 | # print("endpoint_loss: ", endpoint_loss/ mathed_gt_lines.shape[0]) 242 | # print("center_dis_loss: ", center_dis_loss/ mathed_gt_lines.shape[0]) 243 | 244 | # loss = loss / mathed_pred_lines.shape[0] 245 | 246 | ## match ratio large is good 247 | # loss = loss - 5 * match_ratio 248 | # print("loss: ", loss) 249 | # print("match_n: ", match_n) 250 | return loss, match_n 251 | 252 | def matching_loss_func(self, pred_tp_mask, gt_line_512_tensor_list): 253 | match_loss_all = 0.0 254 | match_ratio_all = 0.0 255 | for pred, gt_line_512 in zip(pred_tp_mask, gt_line_512_tensor_list): 256 | gt_line_128 = gt_line_512 / 4 257 | n_gt = gt_line_128.shape[0] 258 | 259 | pred_center_ptss, pred_lines, pred_lines_swap, pred_scores = \ 260 | deccode_lines_TP(pred.unsqueeze(0), 261 | score_thresh=self.decode_score_thresh, 262 | len_thresh=self.decode_len_thresh, 263 | topk_n=self.decode_top_k, 264 | ksize=3) 265 | n_pred = pred_center_ptss.shape[0] 266 | if n_pred == 0: 267 | match_loss_all += 4 * self.match_sap_thresh 268 | match_ratio_all += 0.0 269 | continue 270 | # print("pred_center_ptsssss: ",pred_center_ptss.shape) 271 | # print("gt_line_128: ", gt_line_128.shape) 272 | pred_lines_128 = 128 * pred_lines / (self.input_size / 2) 273 | pred_lines_128_swap = 128 * pred_lines_swap / (self.input_size / 2) 274 | pred_center_ptss_128 = 128 * pred_center_ptss / (self.input_size / 2) 275 | 276 | pred_lines_128 = torch.cat((pred_lines_128, pred_lines_128_swap),dim=0) 277 | pred_center_ptss_128 = torch.cat((pred_center_ptss_128,pred_center_ptss_128),dim=0) 278 | pred_scores = torch.cat((pred_scores,pred_scores),dim=0) 279 | 280 | mloss, match_n_pred = self._m_match_loss_fn(pred_lines_128, 281 | pred_center_ptss_128, 282 | pred_scores, gt_line_128, self.match_sap_thresh) 283 | 284 | match_n = 
self._m_gt_matched_n(pred_lines_128,gt_line_128, self.match_sap_thresh) 285 | match_ratio = match_n / n_gt 286 | 287 | match_loss_all += mloss 288 | match_ratio_all += match_ratio 289 | 290 | return match_loss_all / pred_tp_mask.shape[0], match_ratio_all / pred_tp_mask.shape[0] 291 | 292 | def tp_mask_loss(self, out, gt, gt_lines_tensor_512_list): 293 | out_center = out[:, 7, :, :] 294 | gt_center = gt[:, 7, :, :] 295 | 296 | if self.with_focal_loss: 297 | center_loss = focal_neg_loss_with_logits(out_center, gt_center) 298 | #center_loss += weighted_bce_with_logits(out_center, gt_center, 1.0, 10.0) 299 | else: 300 | center_loss = weighted_bce_with_logits(out_center, gt_center, 1.0, 30.0) 301 | 302 | out_displacement = out[:, 8:12, :, :] 303 | gt_displacement = gt[:, 8:12, :, :] 304 | displacement_loss = displacement_loss_func(out_displacement, gt_displacement, gt_center) 305 | 306 | len_loss, angle_loss = len_and_angle_loss_func( 307 | pred_len=out[:, 12, :, :], 308 | pred_angle=out[:, 13, :, :], 309 | gt_len=gt[:, 12, :, :], 310 | gt_angle=gt[:, 13, :, :] 311 | ) 312 | match_loss, match_ratio = 0, 0 313 | if self.with_match_loss: 314 | match_loss, match_ratio = self.matching_loss_func(out[:, 7:12], 315 | gt_lines_tensor_512_list) 316 | 317 | return { 318 | 'tp_center_loss': center_loss, 319 | 'tp_displacement_loss': displacement_loss, 320 | 'tp_len_loss': len_loss, 321 | 'tp_angle_loss': angle_loss, 322 | 'tp_match_loss': match_loss, 323 | 'tp_match_ratio': match_ratio #not included in loss, only for log 324 | } 325 | 326 | def sol_mask_loss(self, out, gt, sol_lines_512_all_tensor_list): 327 | out_center = out[:, 0, :, :] 328 | gt_center = gt[:, 0, :, :] 329 | 330 | if self.with_focal_loss and self.focal_loss_level >=1 : 331 | center_loss = focal_neg_loss_with_logits(out_center, gt_center) 332 | else: 333 | center_loss = weighted_bce_with_logits(out_center, gt_center, 1.0, 30.0) 334 | 335 | out_displacement = out[:, 1:5, :, :] 336 | gt_displacement = gt[:, 1:5, :, :] 337 | displacement_loss = displacement_loss_func(out_displacement, gt_displacement,gt_center) 338 | 339 | len_loss, angle_loss = len_and_angle_loss_func( 340 | pred_len=out[:, 5, :, :], 341 | pred_angle=out[:, 6, :, :], 342 | gt_len=gt[:, 5, :, :], 343 | gt_angle=gt[:, 6, :, :] 344 | ) 345 | match_loss, match_ratio = 0, 0 346 | if self.with_match_loss: 347 | match_loss, match_ratio = self.matching_loss_func(out[:, 0:5], 348 | sol_lines_512_all_tensor_list) 349 | return { 350 | 'sol_center_loss': center_loss, 351 | 'sol_displacement_loss': displacement_loss, 352 | 'sol_len_loss': len_loss, 353 | 'sol_angle_loss': angle_loss, 354 | 'sol_match_loss': match_loss 355 | } 356 | 357 | def line_and_juc_seg_loss(self, out, gt): 358 | # 359 | 360 | out_line_seg = out[:, 15, :, :] 361 | gt_line_seg = gt[:, 15, :, :] 362 | if self.with_focal_loss and self.focal_loss_level >= 3: 363 | line_seg_loss = focal_neg_loss_with_logits(out_line_seg, gt_line_seg) 364 | else: 365 | line_seg_loss = weighted_bce_with_logits(out_line_seg, gt_line_seg, 1.0, 1.0) 366 | 367 | out_junc_seg = out[:, 14, :, :] 368 | gt_junc_seg = gt[:, 14, :, :] 369 | if self.with_focal_loss and self.focal_loss_level >=2: 370 | junc_seg_loss = focal_neg_loss_with_logits(out_junc_seg, gt_junc_seg) 371 | else: 372 | junc_seg_loss = weighted_bce_with_logits(out_junc_seg, gt_junc_seg, 1.0, 30.0) 373 | 374 | 375 | return line_seg_loss, junc_seg_loss 376 | 377 | def forward(self, preds, gts, 378 | tp_gt_lines_512_list, 379 | sol_gt_lines_512_list): 380 | 381 | line_seg_loss, 
junc_seg_loss = self.line_and_juc_seg_loss(preds, gts) 382 | 383 | loss_dict = { 384 | 'line_seg_loss': line_seg_loss, 385 | 'junc_seg_loss': junc_seg_loss 386 | } 387 | if self.with_SOL_loss: 388 | sol_loss_dict = self.sol_mask_loss(preds, gts, 389 | sol_gt_lines_512_list) 390 | loss_dict.update(sol_loss_dict) 391 | 392 | tp_loss_dict = self.tp_mask_loss(preds, gts, 393 | tp_gt_lines_512_list) 394 | 395 | loss_dict.update(tp_loss_dict) 396 | 397 | loss = 0.0 398 | for k, v in loss_dict.items(): 399 | if not self.with_SOL_loss and 'sol_' in k: 400 | continue 401 | if k in self.loss_w_dict.keys(): 402 | v = v * self.loss_w_dict[k] 403 | loss_dict[k] = v 404 | loss += v 405 | loss_dict['loss'] = loss 406 | 407 | if self.with_SOL_loss: 408 | loss_dict['center_loss'] = loss_dict['sol_center_loss'] + loss_dict['tp_center_loss'] 409 | loss_dict['displacement_loss'] = loss_dict['sol_displacement_loss'] + loss_dict['tp_displacement_loss'] 410 | loss_dict['match_loss'] = loss_dict['tp_match_loss'] + loss_dict['sol_match_loss'] 411 | loss_dict['match_ratio'] = loss_dict['tp_match_ratio'] 412 | else: 413 | loss_dict['center_loss'] = loss_dict['tp_center_loss'] 414 | loss_dict['displacement_loss'] = loss_dict['tp_displacement_loss'] 415 | loss_dict['match_loss'] = loss_dict['tp_match_loss'] 416 | loss_dict['match_ratio'] = loss_dict['tp_match_ratio'] 417 | 418 | return loss_dict -------------------------------------------------------------------------------- /mlsd_pytorch/metric.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def F1_score_128(pred_lines_128_list, gt_lines_128_list, thickness=3): 6 | """ 7 | @brief heat F1 score, draw the lines to a 128 * 128 img 8 | @pred_lines_128 [ [x0, y0, x1, y1], ... ] 9 | @gt_lines_128_list [ [x0, y0, x1, y1], ... ] 10 | """ 11 | pred_heatmap = np.zeros((128, 128), np.uint8) 12 | gt_heatmap = np.zeros((128, 128), np.uint8) 13 | 14 | for l in pred_lines_128_list: 15 | x0, y0, x1, y1 = l 16 | x0 = int(round(x0)) 17 | y0 = int(round(y0)) 18 | x1 = int(round(x1)) 19 | y1 = int(round(y1)) 20 | cv2.line(pred_heatmap, (x0, y0), (x1, y1), (1, 1, 1), thickness, 8) 21 | 22 | for l in gt_lines_128_list: 23 | x0, y0, x1, y1 = l 24 | x0 = int(round(x0)) 25 | y0 = int(round(y0)) 26 | x1 = int(round(x1)) 27 | y1 = int(round(y1)) 28 | cv2.line(gt_heatmap, (x0, y0), (x1, y1), (1, 1, 1), thickness, 8) 29 | 30 | pred_heatmap = np.array(pred_heatmap, np.float32) 31 | gt_heatmap = np.array(gt_heatmap, np.float32) 32 | 33 | intersection = np.sum(gt_heatmap * pred_heatmap) 34 | # union = np.sum(gt_heatmap) + np.sum(gt_heatmap) 35 | eps = 0.001 36 | # dice = (2. 
* intersection + eps) / (union + eps) 37 | 38 | recall = intersection /(np.sum(gt_heatmap) + eps) 39 | precision = intersection /(np.sum(pred_heatmap) + eps) 40 | 41 | fscore = (2 * precision * recall) / (precision + recall + eps) 42 | return fscore, recall, precision 43 | 44 | 45 | 46 | def msTPFP(line_pred, line_gt, threshold): 47 | line_pred = line_pred.reshape(-1, 2, 2)[:, :, ::-1] 48 | line_gt = line_gt.reshape(-1, 2, 2)[:, :, ::-1] 49 | diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1) 50 | diff = np.minimum( 51 | diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0] 52 | ) 53 | 54 | choice = np.argmin(diff, 1) 55 | dist = np.min(diff, 1) 56 | hit = np.zeros(len(line_gt), np.bool) 57 | tp = np.zeros(len(line_pred), np.float) 58 | fp = np.zeros(len(line_pred), np.float) 59 | for i in range(len(line_pred)): 60 | if dist[i] < threshold and not hit[choice[i]]: 61 | hit[choice[i]] = True 62 | tp[i] = 1 63 | else: 64 | fp[i] = 1 65 | return tp, fp 66 | 67 | 68 | def TPFP(lines_dt, lines_gt, threshold): 69 | lines_dt = lines_dt.reshape(-1,2,2)[:,:,::-1] 70 | lines_gt = lines_gt.reshape(-1,2,2)[:,:,::-1] 71 | diff = ((lines_dt[:, None, :, None] - lines_gt[:, None]) ** 2).sum(-1) 72 | diff = np.minimum( 73 | diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0] 74 | ) 75 | choice = np.argmin(diff,1) 76 | dist = np.min(diff,1) 77 | hit = np.zeros(len(lines_gt), np.bool) 78 | tp = np.zeros(len(lines_dt), np.float) 79 | fp = np.zeros(len(lines_dt),np.float) 80 | 81 | for i in range(lines_dt.shape[0]): 82 | if dist[i] < threshold and not hit[choice[i]]: 83 | hit[choice[i]] = True 84 | tp[i] = 1 85 | else: 86 | fp[i] = 1 87 | return tp, fp 88 | 89 | def AP(tp, fp): 90 | recall = tp 91 | precision = tp/np.maximum(tp+fp, 1e-9) 92 | 93 | recall = np.concatenate(([0.0], recall, [1.0])) 94 | precision = np.concatenate(([0.0], precision, [0.0])) 95 | 96 | for i in range(precision.size - 1, 0, -1): 97 | precision[i - 1] = max(precision[i - 1], precision[i]) 98 | i = np.where(recall[1:] != recall[:-1])[0] 99 | 100 | ap = np.sum((recall[i + 1] - recall[i]) * precision[i + 1]) 101 | 102 | return ap -------------------------------------------------------------------------------- /mlsd_pytorch/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/mlsd_pytorch/models/__init__.py -------------------------------------------------------------------------------- /mlsd_pytorch/models/build_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../') 3 | 4 | from mlsd_pytorch.models.mbv2_mlsd import MobileV2_MLSD 5 | from mlsd_pytorch.models.mbv2_mlsd_large import MobileV2_MLSD_Large 6 | 7 | 8 | def build_model(cfg): 9 | model_name = cfg.model.model_name 10 | if model_name == 'mobilev2_mlsd': 11 | m = MobileV2_MLSD(cfg) 12 | return m 13 | if model_name == 'mobilev2_mlsd_large': 14 | m = MobileV2_MLSD_Large(cfg) 15 | return m 16 | raise NotImplementedError('{} no such model!'.format(model_name)) 17 | -------------------------------------------------------------------------------- /mlsd_pytorch/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class BlockTypeA(nn.Module): 7 | def __init__(self, in_c1, 
in_c2, out_c1, out_c2, upscale = True): 8 | super(BlockTypeA, self).__init__() 9 | self.conv1 = nn.Sequential( 10 | nn.Conv2d(in_c2, out_c2, kernel_size=1), 11 | nn.BatchNorm2d(out_c2), 12 | nn.ReLU(inplace=True) 13 | ) 14 | self.conv2 = nn.Sequential( 15 | nn.Conv2d(in_c1, out_c1, kernel_size=1), 16 | nn.BatchNorm2d(out_c1), 17 | nn.ReLU(inplace=True) 18 | ) 19 | self.upscale = upscale 20 | 21 | def forward(self, a, b): 22 | b = self.conv1(b) 23 | a = self.conv2(a) 24 | if self.upscale: 25 | b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) 26 | return torch.cat((a, b), dim=1) 27 | 28 | 29 | class BlockTypeB(nn.Module): 30 | def __init__(self, in_c, out_c): 31 | super(BlockTypeB, self).__init__() 32 | self.conv1 = nn.Sequential( 33 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(in_c), 35 | nn.ReLU() 36 | ) 37 | self.conv2 = nn.Sequential( 38 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 39 | nn.BatchNorm2d(out_c), 40 | nn.ReLU() 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.conv1(x) + x 45 | x = self.conv2(x) 46 | return x 47 | 48 | class BlockTypeC(nn.Module): 49 | def __init__(self, in_c, out_c): 50 | super(BlockTypeC, self).__init__() 51 | self.conv1 = nn.Sequential( 52 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), 53 | nn.BatchNorm2d(in_c), 54 | nn.ReLU() 55 | ) 56 | self.conv2 = nn.Sequential( 57 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 58 | nn.BatchNorm2d(in_c), 59 | nn.ReLU() 60 | ) 61 | self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) 62 | 63 | def forward(self, x): 64 | x = self.conv1(x) 65 | x = self.conv2(x) 66 | x = self.conv3(x) 67 | return x -------------------------------------------------------------------------------- /mlsd_pytorch/models/mbv2_mlsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import timm 3 | import torch.nn as nn 4 | from mlsd_pytorch.models.layers import * 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | def _make_divisible(v, divisor, min_value=None): 9 | """ 10 | This function is taken from the original tf repo. 11 | It ensures that all layers have a channel number that is divisible by 8 12 | It can be seen here: 13 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 14 | :param v: 15 | :param divisor: 16 | :param min_value: 17 | :return: 18 | """ 19 | if min_value is None: 20 | min_value = divisor 21 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 22 | # Make sure that round down does not go down by more than 10%. 
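    # Worked example: _make_divisible(11, 8) first rounds to 8, which is below
    # 0.9 * 11 = 9.9, so the branch below bumps it to 16; _make_divisible(32, 8)
    # simply returns 32.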
23 | if new_v < 0.9 * v: 24 | new_v += divisor 25 | return new_v 26 | 27 | 28 | class ConvBNReLU(nn.Sequential): 29 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 30 | self.channel_pad = out_planes - in_planes 31 | self.stride = stride 32 | padding = (kernel_size - 1) // 2 33 | 34 | # TFLite uses slightly different padding than PyTorch 35 | # if stride == 2: 36 | # padding = 0 37 | # else: 38 | # padding = (kernel_size - 1) // 2 39 | 40 | super(ConvBNReLU, self).__init__( 41 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 42 | nn.BatchNorm2d(out_planes), 43 | nn.ReLU6(inplace=True) 44 | ) 45 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 46 | 47 | def forward(self, x): 48 | # TFLite uses slightly different padding 49 | # if self.stride == 2: 50 | # x = F.pad(x, (0, 1, 0, 1), "constant", 0) 51 | # #print(x.shape) 52 | 53 | for module in self: 54 | if not isinstance(module, nn.MaxPool2d): 55 | x = module(x) 56 | return x 57 | 58 | 59 | class InvertedResidual(nn.Module): 60 | def __init__(self, inp, oup, stride, expand_ratio): 61 | super(InvertedResidual, self).__init__() 62 | self.stride = stride 63 | assert stride in [1, 2] 64 | 65 | hidden_dim = int(round(inp * expand_ratio)) 66 | self.use_res_connect = self.stride == 1 and inp == oup 67 | 68 | layers = [] 69 | if expand_ratio != 1: 70 | # pw 71 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 72 | layers.extend([ 73 | # dw 74 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 75 | # pw-linear 76 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 77 | nn.BatchNorm2d(oup), 78 | ]) 79 | self.conv = nn.Sequential(*layers) 80 | 81 | def forward(self, x): 82 | if self.use_res_connect: 83 | return x + self.conv(x) 84 | else: 85 | return self.conv(x) 86 | 87 | 88 | class MobileNetV2(nn.Module): 89 | def __init__(self, pretrained=True): 90 | """ 91 | MobileNet V2 main class 92 | Args: 93 | num_classes (int): Number of classes 94 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 95 | inverted_residual_setting: Network structure 96 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 97 | Set to 1 to turn off rounding 98 | block: Module specifying inverted residual building block for mobilenet 99 | """ 100 | super(MobileNetV2, self).__init__() 101 | 102 | block = InvertedResidual 103 | input_channel = 32 104 | last_channel = 1280 105 | width_mult = 1.0 106 | round_nearest = 8 107 | 108 | inverted_residual_setting = [ 109 | # t, c, n, s 110 | [1, 16, 1, 1], 111 | [6, 24, 2, 2], 112 | [6, 32, 3, 2], 113 | [6, 64, 4, 2], 114 | # [6, 96, 3, 1], 115 | # [6, 160, 3, 2], 116 | # [6, 320, 1, 1], 117 | ] 118 | 119 | # only check the first element, assuming user knows t,c,n,s are required 120 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 121 | raise ValueError("inverted_residual_setting should be non-empty " 122 | "or a 4-element list, got {}".format(inverted_residual_setting)) 123 | 124 | # building first layer 125 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 126 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 127 | features = [ConvBNReLU(3, input_channel, stride=2)] 128 | # building inverted residual blocks 129 | for t, c, n, s in inverted_residual_setting: 130 | output_channel = _make_divisible(c * width_mult, round_nearest) 131 
| for i in range(n): 132 | stride = s if i == 0 else 1 133 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 134 | input_channel = output_channel 135 | # building last several layers 136 | # features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 137 | # make it nn.Sequential 138 | # print('fea blocks:', len(features)) 139 | self.features = nn.Sequential(*features) 140 | 141 | self.fpn_selected = [1, 3, 6, 10] 142 | # weight initialization 143 | for m in self.modules(): 144 | if isinstance(m, nn.Conv2d): 145 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 146 | if m.bias is not None: 147 | nn.init.zeros_(m.bias) 148 | elif isinstance(m, nn.BatchNorm2d): 149 | nn.init.ones_(m.weight) 150 | nn.init.zeros_(m.bias) 151 | elif isinstance(m, nn.Linear): 152 | nn.init.normal_(m.weight, 0, 0.01) 153 | nn.init.zeros_(m.bias) 154 | 155 | if pretrained: 156 | self._load_pretrained_model() 157 | 158 | def _forward_impl(self, x): 159 | # This exists since TorchScript doesn't support inheritance, so the superclass method 160 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 161 | fpn_features = [] 162 | for i, f in enumerate(self.features): 163 | if i > self.fpn_selected[-1]: 164 | break 165 | x = f(x) 166 | if i in self.fpn_selected: 167 | fpn_features.append(x) 168 | 169 | c1, c2, c3, c4 = fpn_features 170 | return c1, c2, c3, c4 171 | 172 | def forward(self, x): 173 | return self._forward_impl(x) 174 | 175 | def _load_pretrained_model(self): 176 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') 177 | model_dict = {} 178 | state_dict = self.state_dict() 179 | for k, v in pretrain_dict.items(): 180 | if k in state_dict: 181 | model_dict[k] = v 182 | state_dict.update(model_dict) 183 | self.load_state_dict(state_dict) 184 | 185 | 186 | class BilinearConvTranspose2d(nn.ConvTranspose2d): 187 | """A conv transpose initialized to bilinear interpolation.""" 188 | 189 | def __init__(self, channels, stride, groups=1): 190 | """Set up the layer. 191 | 192 | Parameters 193 | ---------- 194 | channels: int 195 | The number of input and output channels 196 | 197 | stride: int or tuple 198 | The amount of upsampling to do 199 | 200 | groups: int 201 | Set to 1 for a standard convolution. Set equal to channels to 202 | make sure there is no cross-talk between channels. 
203 | """ 204 | if isinstance(stride, int): 205 | stride = (stride, stride) 206 | 207 | assert groups in (1, channels), "Must use no grouping, " + \ 208 | "or one group per channel" 209 | 210 | kernel_size = (2 * stride[0] - 1, 2 * stride[1] - 1) 211 | padding = (stride[0] - 1, stride[1] - 1) 212 | super().__init__( 213 | channels, channels, 214 | kernel_size=kernel_size, 215 | stride=stride, 216 | padding=padding, 217 | output_padding=padding, 218 | groups=groups) 219 | 220 | def reset_parameters(self): 221 | """Reset the weight and bias.""" 222 | nn.init.constant(self.bias, 0) 223 | nn.init.constant(self.weight, 0) 224 | bilinear_kernel = self.bilinear_kernel(self.stride) 225 | for i in range(self.in_channels): 226 | if self.groups == 1: 227 | j = i 228 | else: 229 | j = 0 230 | self.weight.data[i, j] = bilinear_kernel 231 | 232 | @staticmethod 233 | def bilinear_kernel(stride): 234 | """Generate a bilinear upsampling kernel.""" 235 | num_dims = len(stride) 236 | 237 | shape = (1,) * num_dims 238 | bilinear_kernel = torch.ones(*shape) 239 | 240 | # The bilinear kernel is separable in its spatial dimensions 241 | # Build up the kernel channel by channel 242 | for channel in range(num_dims): 243 | channel_stride = stride[channel] 244 | kernel_size = 2 * channel_stride - 1 245 | # e.g. with stride = 4 246 | # delta = [-3, -2, -1, 0, 1, 2, 3] 247 | # channel_filter = [0.25, 0.5, 0.75, 1.0, 0.75, 0.5, 0.25] 248 | delta = torch.arange(1 - channel_stride, channel_stride) 249 | channel_filter = (1 - torch.abs(delta / channel_stride)) 250 | # Apply the channel filter to the current channel 251 | shape = [1] * num_dims 252 | shape[channel] = kernel_size 253 | bilinear_kernel = bilinear_kernel * channel_filter.view(shape) 254 | return bilinear_kernel 255 | 256 | 257 | class MobileV2_MLSD(nn.Module): 258 | def __init__(self, cfg): 259 | super(MobileV2_MLSD, self).__init__() 260 | 261 | self.backbone = MobileNetV2(pretrained=True) 262 | 263 | self.block12 = BlockTypeA(in_c1=32, in_c2=64, 264 | out_c1=64, out_c2=64) 265 | self.block13 = BlockTypeB(128, 64) 266 | 267 | self.block14 = BlockTypeA(in_c1=24, in_c2=64, 268 | out_c1=32, out_c2=32) 269 | self.block15 = BlockTypeB(64, 64) 270 | 271 | self.block16 = BlockTypeC(64, 16) 272 | 273 | # self.block17 = nn.ConvTranspose2d(in_channels=16, 274 | # out_channels=16, 275 | # kernel_size=3, 276 | # stride=2, 277 | # padding=1, 278 | # output_padding=1, 279 | # bias=False) 280 | 281 | self.with_deconv = cfg.model.with_deconv 282 | 283 | if self.with_deconv: 284 | self.block17 = BilinearConvTranspose2d(16, 2, 1) 285 | self.block17.reset_parameters() 286 | 287 | 288 | def forward(self, x): 289 | c1, c2, c3, c4 = self.backbone(x) 290 | 291 | # print(c1.shape) 292 | 293 | x = self.block12(c3, c4) 294 | x = self.block13(x) 295 | x = self.block14(c2, x) 296 | x = self.block15(x) 297 | x = self.block16(x) 298 | 299 | if self.with_deconv: 300 | x = self.block17(x) 301 | else: 302 | x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 303 | 304 | #x = x[:, 7:, :, :] 305 | #print(x.shape) 306 | return x 307 | 308 | 309 | if __name__ == '__main__': 310 | from mlsd_pytorch.cfg.default import get_cfg_defaults 311 | cfg = get_cfg_defaults() 312 | model = MobileV2_MLSD(cfg) 313 | x = torch.randn((1, 3, 512, 512)) 314 | y = model(x) 315 | 316 | from thop import profile 317 | 318 | flops, params = profile(model, inputs=(x,)) 319 | print('Total params: %.2fM' % (params / 1000000.0)) 320 | print('Total flops: %.2fM' % (flops / 1000000.0)) 321 | 
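A minimal sketch (not part of the repository) of what the two helpers above compute, assuming this file is importable as mlsd_pytorch.models.mbv2_mlsd: _make_divisible snaps a channel count to a multiple of the divisor without shrinking it by more than 10%, and BilinearConvTranspose2d.bilinear_kernel builds the separable kernel used to initialise the 2x upsampling layer.

# assumption: the module path below matches the repo layout shown in the file tree
from mlsd_pytorch.models.mbv2_mlsd import _make_divisible, BilinearConvTranspose2d

print(_make_divisible(32 * 1.0, 8))  # 32: already a multiple of 8
print(_make_divisible(18, 8))        # 24: rounding down to 16 would lose more than 10%, so round up
print(BilinearConvTranspose2d.bilinear_kernel((2, 2)))
# tensor([[0.2500, 0.5000, 0.2500],
#         [0.5000, 1.0000, 0.5000],
#         [0.2500, 0.5000, 0.2500]])  -> the standard 2x bilinear upsampling kernel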
-------------------------------------------------------------------------------- /mlsd_pytorch/models/mbv2_mlsd_large.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import timm 3 | import torch.nn as nn 4 | from mlsd_pytorch.models.layers import * 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | def _make_divisible(v, divisor, min_value=None): 9 | """ 10 | This function is taken from the original tf repo. 11 | It ensures that all layers have a channel number that is divisible by 8 12 | It can be seen here: 13 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 14 | :param v: 15 | :param divisor: 16 | :param min_value: 17 | :return: 18 | """ 19 | if min_value is None: 20 | min_value = divisor 21 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 22 | # Make sure that round down does not go down by more than 10%. 23 | if new_v < 0.9 * v: 24 | new_v += divisor 25 | return new_v 26 | 27 | 28 | class ConvBNReLU(nn.Sequential): 29 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 30 | self.channel_pad = out_planes - in_planes 31 | self.stride = stride 32 | padding = (kernel_size - 1) // 2 33 | 34 | # TFLite uses slightly different padding than PyTorch 35 | # if stride == 2: 36 | # padding = 0 37 | # else: 38 | # padding = (kernel_size - 1) // 2 39 | 40 | super(ConvBNReLU, self).__init__( 41 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 42 | nn.BatchNorm2d(out_planes), 43 | nn.ReLU6(inplace=True) 44 | ) 45 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 46 | 47 | 48 | def forward(self, x): 49 | # TFLite uses slightly different padding 50 | # if self.stride == 2: 51 | # x = F.pad(x, (0, 1, 0, 1), "constant", 0) 52 | # #print(x.shape) 53 | 54 | for module in self: 55 | if not isinstance(module, nn.MaxPool2d): 56 | x = module(x) 57 | return x 58 | 59 | 60 | class InvertedResidual(nn.Module): 61 | def __init__(self, inp, oup, stride, expand_ratio): 62 | super(InvertedResidual, self).__init__() 63 | self.stride = stride 64 | assert stride in [1, 2] 65 | 66 | hidden_dim = int(round(inp * expand_ratio)) 67 | self.use_res_connect = self.stride == 1 and inp == oup 68 | 69 | layers = [] 70 | if expand_ratio != 1: 71 | # pw 72 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 73 | layers.extend([ 74 | # dw 75 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 76 | # pw-linear 77 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 78 | nn.BatchNorm2d(oup), 79 | ]) 80 | self.conv = nn.Sequential(*layers) 81 | 82 | def forward(self, x): 83 | if self.use_res_connect: 84 | return x + self.conv(x) 85 | else: 86 | return self.conv(x) 87 | 88 | 89 | class MobileNetV2(nn.Module): 90 | def __init__(self, pretrained=True): 91 | """ 92 | MobileNet V2 main class 93 | Args: 94 | num_classes (int): Number of classes 95 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 96 | inverted_residual_setting: Network structure 97 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 98 | Set to 1 to turn off rounding 99 | block: Module specifying inverted residual building block for mobilenet 100 | """ 101 | super(MobileNetV2, self).__init__() 102 | 103 | block = InvertedResidual 104 | input_channel = 32 105 | last_channel = 1280 106 | width_mult = 1.0 107 | round_nearest = 8 108 | 109 | 
inverted_residual_setting = [ 110 | # t, c, n, s 111 | [1, 16, 1, 1], 112 | [6, 24, 2, 2], 113 | [6, 32, 3, 2], 114 | [6, 64, 4, 2], 115 | [6, 96, 3, 1], 116 | #[6, 160, 3, 2], 117 | #[6, 320, 1, 1], 118 | ] 119 | 120 | # only check the first element, assuming user knows t,c,n,s are required 121 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 122 | raise ValueError("inverted_residual_setting should be non-empty " 123 | "or a 4-element list, got {}".format(inverted_residual_setting)) 124 | 125 | # building first layer 126 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 127 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 128 | features = [ConvBNReLU(3, input_channel, stride=2)] 129 | # building inverted residual blocks 130 | for t, c, n, s in inverted_residual_setting: 131 | output_channel = _make_divisible(c * width_mult, round_nearest) 132 | for i in range(n): 133 | stride = s if i == 0 else 1 134 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 135 | input_channel = output_channel 136 | 137 | self.features = nn.Sequential(*features) 138 | self.fpn_selected = [1, 3, 6, 10, 13] 139 | # weight initialization 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 143 | if m.bias is not None: 144 | nn.init.zeros_(m.bias) 145 | elif isinstance(m, nn.BatchNorm2d): 146 | nn.init.ones_(m.weight) 147 | nn.init.zeros_(m.bias) 148 | elif isinstance(m, nn.Linear): 149 | nn.init.normal_(m.weight, 0, 0.01) 150 | nn.init.zeros_(m.bias) 151 | if pretrained: 152 | self._load_pretrained_model() 153 | 154 | def _forward_impl(self, x): 155 | # This exists since TorchScript doesn't support inheritance, so the superclass method 156 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 157 | fpn_features = [] 158 | for i, f in enumerate(self.features): 159 | if i > self.fpn_selected[-1]: 160 | break 161 | x = f(x) 162 | if i in self.fpn_selected: 163 | fpn_features.append(x) 164 | 165 | c1, c2, c3, c4, c5 = fpn_features 166 | return c1, c2, c3, c4, c5 167 | 168 | 169 | def forward(self, x): 170 | return self._forward_impl(x) 171 | 172 | def _load_pretrained_model(self): 173 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') 174 | model_dict = {} 175 | state_dict = self.state_dict() 176 | for k, v in pretrain_dict.items(): 177 | if k in state_dict: 178 | model_dict[k] = v 179 | state_dict.update(model_dict) 180 | self.load_state_dict(state_dict) 181 | 182 | 183 | class MobileV2_MLSD_Large(nn.Module): 184 | def __init__(self, cfg): 185 | super(MobileV2_MLSD_Large, self).__init__() 186 | 187 | self.backbone = MobileNetV2(pretrained=True) 188 | ## A, B 189 | self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, 190 | out_c1= 64, out_c2=64, 191 | upscale=False) 192 | self.block16 = BlockTypeB(128, 64) 193 | 194 | ## A, B 195 | self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, 196 | out_c1= 64, out_c2= 64) 197 | self.block18 = BlockTypeB(128, 64) 198 | 199 | ## A, B 200 | self.block19 = BlockTypeA(in_c1=24, in_c2=64, 201 | out_c1=64, out_c2=64) 202 | self.block20 = BlockTypeB(128, 64) 203 | 204 | ## A, B, C 205 | self.block21 = BlockTypeA(in_c1=16, in_c2=64, 206 | out_c1=64, out_c2=64) 207 | self.block22 = BlockTypeB(128, 64) 208 | 209 | self.block23 = BlockTypeC(64, 16) 210 | 211 | def forward(self, x): 212 | c1, c2, c3, c4, c5 = 
self.backbone(x) 213 | 214 | x = self.block15(c4, c5) 215 | x = self.block16(x) 216 | 217 | x = self.block17(c3, x) 218 | x = self.block18(x) 219 | 220 | x = self.block19(c2, x) 221 | fea = self.block20(x) 222 | 223 | x = self.block21(c1, fea) 224 | x = self.block22(x) 225 | x = self.block23(x) 226 | #x = x[:, 7:, :, :] 227 | 228 | return x 229 | 230 | 231 | if __name__ == '__main__': 232 | from mlsd_pytorch.cfg.default import get_cfg_defaults 233 | cfg = get_cfg_defaults() 234 | model = MobileV2_MLSD_Large(cfg) 235 | x = torch.randn((1,3, 512, 512)) 236 | y = model(x) 237 | 238 | from thop import profile 239 | 240 | flops, params = profile(model, inputs=(x, )) 241 | print('Total params: %.2fM' % (params / 1000000.0)) 242 | print('Total flops: %.2fM' % (flops / 1000000.0)) 243 | -------------------------------------------------------------------------------- /mlsd_pytorch/optim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/mlsd_pytorch/optim/__init__.py -------------------------------------------------------------------------------- /mlsd_pytorch/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from bisect import bisect_right 3 | import torch 4 | 5 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 6 | def __init__( 7 | self, 8 | optimizer, 9 | milestones,#steps 10 | gamma=0.1, 11 | warmup_factor=1.0 / 3, 12 | warmup_iters=500, 13 | warmup_method="linear", 14 | last_epoch=-1, 15 | ): 16 | if not list(milestones) == sorted(milestones): 17 | raise ValueError( 18 | "Milestones should be a list of" " increasing integers. 
Got {}", 19 | milestones, 20 | ) 21 | 22 | if warmup_method not in ("constant", "linear"): 23 | raise ValueError( 24 | "Only 'constant' or 'linear' warmup_method accepted" 25 | "got {}".format(warmup_method) 26 | ) 27 | self.milestones = milestones 28 | self.gamma = gamma 29 | self.warmup_factor = warmup_factor 30 | self.warmup_iters = warmup_iters 31 | self.warmup_method = warmup_method 32 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 33 | 34 | def get_lr(self): 35 | warmup_factor = 1 36 | if self.last_epoch < self.warmup_iters: 37 | if self.warmup_method == "constant": 38 | warmup_factor = self.warmup_factor 39 | elif self.warmup_method == "linear": 40 | alpha = self.last_epoch / self.warmup_iters 41 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 42 | return [ 43 | base_lr 44 | * warmup_factor 45 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 46 | for base_lr in self.base_lrs 47 | ] 48 | 49 | -------------------------------------------------------------------------------- /mlsd_pytorch/pred_and_eval_sAP.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(__file__) +'/../') 4 | import cv2 5 | import torch 6 | import json 7 | import tqdm 8 | import argparse 9 | import numpy as np 10 | from mlsd_pytorch.cfg.default import get_cfg_defaults 11 | from mlsd_pytorch.models.build_model import build_model 12 | from mlsd_pytorch.data.utils import deccode_lines 13 | from mlsd_pytorch.metric import msTPFP, AP 14 | from albumentations import Normalize 15 | 16 | def get_args(): 17 | args = argparse.ArgumentParser() 18 | current_dir = default=os.path.dirname(__file__) 19 | args.add_argument("--config", type=str,default = current_dir + '/configs/mobilev2_mlsd_tiny_512_base2_bsize24.yaml') 20 | args.add_argument("--model_path", type=str, 21 | default= current_dir +"/../workdir/pretrained_models/mobilev2_mlsd_tiny_512_bsize24/best.pth") 22 | args.add_argument("--gt_json", type=str, 23 | default= current_dir +"/../data/wireframe_raw/valid.json") 24 | args.add_argument("--img_dir", type=str, 25 | default= current_dir + "/../data/wireframe_raw/images/") 26 | args.add_argument("--sap_thresh", type=float, help="sAP thresh", default=10.0) 27 | args.add_argument("--top_k", type=float, help="top k lines", default= 500) 28 | args.add_argument("--min_len", type=float, help="min len of line", default=5.0) 29 | args.add_argument("--score_thresh", type=float, help="line score thresh", default=0.05) 30 | args.add_argument("--input_size", type=int, help="image input size", default=512) 31 | 32 | return args.parse_args() 33 | 34 | test_aug = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 35 | 36 | def infer_one(img_fn, model, input_size=512, score_thresh=0.01, min_len=0, topk=200): 37 | img = cv2.imread(img_fn) 38 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 39 | h, w, _ = img.shape 40 | 41 | img = cv2.resize(img, (input_size, input_size)) 42 | #img = (img / 127.5) - 1.0 43 | img = test_aug(image=img)['image'] 44 | img = img.transpose(2,0,1) 45 | img = torch.from_numpy(img).unsqueeze(0).float().cuda() 46 | 47 | with torch.no_grad(): 48 | batch_outputs = model(img) 49 | tp_mask = batch_outputs[:, 7:, :, :] 50 | 51 | center_ptss, pred_lines, scores = deccode_lines(tp_mask, score_thresh, min_len, topk, 3) 52 | 53 | pred_lines = pred_lines.detach().cpu().numpy() 54 | scores = scores.detach().cpu().numpy() 55 | 56 | pred_lines_list = [] 57 | scores_list = [] 58 | 
for line, score in zip(pred_lines, scores): 59 | x0, y0, x1, y1 = line 60 | 61 | x0 = w * x0 / (input_size / 2) 62 | x1 = w * x1 / (input_size / 2) 63 | 64 | y0 = h * y0 / (input_size / 2) 65 | y1 = h * y1 / (input_size / 2) 66 | 67 | pred_lines_list.append([x0, y0, x1, y1]) 68 | scores_list.append(score) 69 | 70 | return { 71 | 'full_fn': img_fn, 72 | 'filename': os.path.basename(img_fn), 73 | 'width': w, 74 | 'height': h, 75 | 'lines': pred_lines_list, 76 | 'scores': scores_list 77 | } 78 | 79 | 80 | def calculate_sAP(gt_infos, pred_infos, sap_thresh): 81 | assert len(gt_infos) == len(pred_infos) 82 | 83 | tp_list, fp_list, scores_list = [], [], [] 84 | n_gt = 0 85 | 86 | for gt, pred in zip(gt_infos, pred_infos): 87 | assert gt['filename'] == pred['filename'] 88 | h, w = gt['height'], gt['width'] 89 | pred_lines = np.array(pred['lines'], np.float32) 90 | pred_scores = np.array(pred['scores'], np.float32) 91 | 92 | gt_lines = np.array(gt['lines'], np.float32) 93 | scale = np.array([128.0/ w, 128.0/h, 128.0/ w, 128.0/h], np.float32) 94 | pred_lines_128 = pred_lines * scale 95 | gt_lines_128 = gt_lines * scale 96 | 97 | tp, fp = msTPFP(pred_lines_128, gt_lines_128, sap_thresh) 98 | 99 | n_gt += gt_lines_128.shape[0] 100 | tp_list.append(tp) 101 | fp_list.append(fp) 102 | scores_list.append(pred_scores) 103 | 104 | tp_list = np.concatenate(tp_list) 105 | fp_list = np.concatenate(fp_list) 106 | scores_list = np.concatenate(scores_list) 107 | idx = np.argsort(scores_list)[::-1] 108 | tp = np.cumsum(tp_list[idx]) / n_gt 109 | fp = np.cumsum(fp_list[idx]) / n_gt 110 | rcs = tp 111 | pcs = tp / np.maximum(tp + fp, 1e-9) 112 | sAP = AP(tp, fp) * 100 113 | 114 | return sAP 115 | 116 | 117 | def main(args): 118 | cfg = get_cfg_defaults() 119 | if args.config.endswith('\r'): 120 | args.config = args.config[:-1] 121 | print('using config: ', args.config.strip()) 122 | cfg.merge_from_file(args.config) 123 | 124 | model = build_model(cfg).cuda().eval() 125 | model.load_state_dict(torch.load(args.model_path), strict=True) 126 | 127 | label_file = args.gt_json 128 | img_dir = args.img_dir 129 | contens = json.load(open(label_file, 'r')) 130 | 131 | gt_infos = [] 132 | pred_infos = [] 133 | 134 | for c in tqdm.tqdm(contens): 135 | gt_infos.append(c) 136 | fn = c['filename'][:-4] + '.png' 137 | full_fn = img_dir + '/' + fn 138 | pred_infos.append(infer_one(full_fn, model, 139 | args.input_size, 140 | args.score_thresh, 141 | args.min_len, args.top_k )) 142 | 143 | ap = calculate_sAP(gt_infos, pred_infos, args.sap_thresh) 144 | 145 | print("====> sAP{}: {}".format(args.sap_thresh, ap)) 146 | 147 | 148 | if __name__ == '__main__': 149 | main(get_args()) 150 | -------------------------------------------------------------------------------- /mlsd_pytorch/tf_pred_and_eval_sAP.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import tqdm 5 | import argparse 6 | import numpy as np 7 | from PIL import Image 8 | import tensorflow as tf 9 | 10 | 11 | def msTPFP(line_pred, line_gt, threshold): 12 | line_pred = line_pred.reshape(-1, 2, 2)[:, :, ::-1] 13 | line_gt = line_gt.reshape(-1, 2, 2)[:, :, ::-1] 14 | diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1) 15 | diff = np.minimum( 16 | diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0] 17 | ) 18 | 19 | choice = np.argmin(diff, 1) 20 | dist = np.min(diff, 1) 21 | hit = np.zeros(len(line_gt), np.bool) 22 | tp = np.zeros(len(line_pred), np.float) 
23 | fp = np.zeros(len(line_pred), np.float) 24 | for i in range(len(line_pred)): 25 | if dist[i] < threshold and not hit[choice[i]]: 26 | hit[choice[i]] = True 27 | tp[i] = 1 28 | else: 29 | fp[i] = 1 30 | return tp, fp 31 | 32 | 33 | def TPFP(lines_dt, lines_gt, threshold): 34 | lines_dt = lines_dt.reshape(-1,2,2)[:,:,::-1] 35 | lines_gt = lines_gt.reshape(-1,2,2)[:,:,::-1] 36 | diff = ((lines_dt[:, None, :, None] - lines_gt[:, None]) ** 2).sum(-1) 37 | diff = np.minimum( 38 | diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0] 39 | ) 40 | choice = np.argmin(diff,1) 41 | dist = np.min(diff,1) 42 | hit = np.zeros(len(lines_gt), np.bool) 43 | tp = np.zeros(len(lines_dt), np.float) 44 | fp = np.zeros(len(lines_dt),np.float) 45 | 46 | for i in range(lines_dt.shape[0]): 47 | if dist[i] < threshold and not hit[choice[i]]: 48 | hit[choice[i]] = True 49 | tp[i] = 1 50 | else: 51 | fp[i] = 1 52 | return tp, fp 53 | 54 | def AP(tp, fp): 55 | recall = tp 56 | precision = tp/np.maximum(tp+fp, 1e-9) 57 | 58 | recall = np.concatenate(([0.0], recall, [1.0])) 59 | precision = np.concatenate(([0.0], precision, [0.0])) 60 | 61 | for i in range(precision.size - 1, 0, -1): 62 | precision[i - 1] = max(precision[i - 1], precision[i]) 63 | i = np.where(recall[1:] != recall[:-1])[0] 64 | 65 | ap = np.sum((recall[i + 1] - recall[i]) * precision[i + 1]) 66 | 67 | return ap 68 | 69 | def get_args(): 70 | args = argparse.ArgumentParser() 71 | args.add_argument("--gt_json", type=str, default="/home/lhw/data/wireframe_parsing/wireframe_afm/test.json") 72 | args.add_argument("--img_dir", type=str, default="/home/lhw/data/wireframe_parsing/wireframe_afm/images/") 73 | args.add_argument("--thresh", type=float, help="sAP thresh", default=10.0) 74 | return args.parse_args() 75 | 76 | 77 | model_name = 'tflite_models/M-LSD_512_tiny_fp32.tflite' 78 | interpreter = tf.lite.Interpreter(model_path=model_name) 79 | 80 | interpreter.allocate_tensors() 81 | input_details = interpreter.get_input_details() 82 | output_details = interpreter.get_output_details() 83 | 84 | 85 | 86 | def pred_lines_fn(image, interpreter, input_details, output_details, input_shape=[512, 512], score_thr=0.01, dist_thr=2.0): 87 | h, w, _ = image.shape 88 | h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] 89 | 90 | resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), np.ones([input_shape[0], input_shape[1], 1])], axis=-1) 91 | batch_image = np.expand_dims(resized_image, axis=0).astype('float32') 92 | interpreter.set_tensor(input_details[0]['index'], batch_image) 93 | interpreter.invoke() 94 | 95 | pts = interpreter.get_tensor(output_details[0]['index'])[0] 96 | pts_score = interpreter.get_tensor(output_details[1]['index'])[0] 97 | vmap = interpreter.get_tensor(output_details[2]['index'])[0] 98 | 99 | start = vmap[:,:,:2] 100 | end = vmap[:,:,2:] 101 | dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) 102 | 103 | segments_list = [] 104 | scores_list = [] 105 | for center, score in zip(pts, pts_score): 106 | y, x = center 107 | distance = dist_map[y, x] 108 | if score > score_thr and distance > dist_thr: 109 | disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] 110 | x_start = x + disp_x_start 111 | y_start = y + disp_y_start 112 | x_end = x + disp_x_end 113 | y_end = y + disp_y_end 114 | segments_list.append([x_start, y_start, x_end, y_end]) 115 | scores_list.append(score) 116 | 117 | lines = 2 * np.array(segments_list) # 256 > 512 118 | 
lines[:,0] = lines[:,0] * w_ratio 119 | lines[:,1] = lines[:,1] * h_ratio 120 | lines[:,2] = lines[:,2] * w_ratio 121 | lines[:,3] = lines[:,3] * h_ratio 122 | 123 | return lines,scores_list 124 | 125 | 126 | def infer_one(img_fn, input_size=512, score_thresh=0.0, min_len=0.0, topk=200): 127 | img = cv2.imread(img_fn) 128 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 129 | h, w, _ = img.shape 130 | 131 | img = cv2.resize(img, (input_size, input_size)) 132 | 133 | pred_lines, scores = pred_lines_fn(img, interpreter, input_details, output_details, 134 | input_shape=[input_size, input_size], score_thr=score_thresh, dist_thr=min_len) 135 | 136 | pred_lines_list = [] 137 | scores_list = [] 138 | for line, score in zip(pred_lines, scores): 139 | x0, y0, x1, y1 = line 140 | 141 | x0 = w * x0 / input_size 142 | x1 = w * x1 / input_size 143 | 144 | y0 = h * y0 / input_size 145 | y1 = h * y1 / input_size 146 | 147 | pred_lines_list.append([x0, y0, x1, y1]) 148 | scores_list.append(score) 149 | 150 | return { 151 | 'full_fn': img_fn, 152 | 'filename': os.path.basename(img_fn), 153 | 'width': w, 154 | 'height': h, 155 | 'lines': pred_lines_list, 156 | 'scores': scores_list 157 | } 158 | 159 | 160 | def calculate_sAP(gt_infos, pred_infos, sap_thresh): 161 | assert len(gt_infos) == len(pred_infos) 162 | 163 | tp_list, fp_list, scores_list = [], [], [] 164 | n_gt = 0 165 | 166 | for gt, pred in zip(gt_infos, pred_infos): 167 | assert gt['filename'] == pred['filename'] 168 | h, w = gt['height'], gt['width'] 169 | pred_lines = np.array(pred['lines'], np.float32) 170 | pred_scores = np.array(pred['scores'], np.float32) 171 | 172 | gt_lines = np.array(gt['lines'], np.float32) 173 | scale = np.array([128.0/ w, 128.0/h, 128.0/ w, 128.0/h], np.float32) 174 | pred_lines_128 = pred_lines * scale 175 | gt_lines_128 = gt_lines * scale 176 | 177 | tp, fp = msTPFP(pred_lines_128, gt_lines_128, sap_thresh) 178 | 179 | n_gt += gt_lines_128.shape[0] 180 | tp_list.append(tp) 181 | fp_list.append(fp) 182 | scores_list.append(pred_scores) 183 | 184 | tp_list = np.concatenate(tp_list) 185 | fp_list = np.concatenate(fp_list) 186 | scores_list = np.concatenate(scores_list) 187 | idx = np.argsort(scores_list)[::-1] 188 | tp = np.cumsum(tp_list[idx]) / n_gt 189 | fp = np.cumsum(fp_list[idx]) / n_gt 190 | rcs = tp 191 | pcs = tp / np.maximum(tp + fp, 1e-9) 192 | sAP = AP(tp, fp) * 100 193 | 194 | return sAP 195 | 196 | 197 | def main(args): 198 | 199 | label_file = args.gt_json 200 | img_dir = args.img_dir 201 | contens = json.load(open(label_file, 'r')) 202 | 203 | gt_infos = [] 204 | pred_infos = [] 205 | 206 | for c in tqdm.tqdm(contens): 207 | gt_infos.append(c) 208 | fn = c['filename'][:-4] + '.png' 209 | full_fn = img_dir + '/' + fn 210 | pred_infos.append(infer_one(full_fn)) 211 | 212 | ap = calculate_sAP(gt_infos, pred_infos, args.thresh) 213 | 214 | print("====> sAP{}: {}".format(args.thresh, ap)) 215 | 216 | 217 | if __name__ == '__main__': 218 | main(get_args()) 219 | 220 | 221 | -------------------------------------------------------------------------------- /mlsd_pytorch/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import math 4 | import sys 5 | 6 | sys.path.append(os.path.dirname(__file__)+'/../') 7 | 8 | from mlsd_pytorch.utils.logger import TxtLogger 9 | from mlsd_pytorch.utils.comm import setup_seed, create_dir 10 | from mlsd_pytorch.cfg.default import get_cfg_defaults 11 | from mlsd_pytorch.optim.lr_scheduler import 
WarmupMultiStepLR 12 | 13 | from mlsd_pytorch.data import get_train_dataloader, get_val_dataloader 14 | from mlsd_pytorch.learner import Simple_MLSD_Learner 15 | from mlsd_pytorch.models.build_model import build_model 16 | 17 | 18 | import argparse 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--config", 22 | default= os.path.dirname(__file__)+ '/configs/mobilev2_mlsd_tiny_512_base.yaml', 23 | type=str, 24 | help="") 25 | return parser.parse_args() 26 | 27 | def train(cfg): 28 | train_loader = get_train_dataloader(cfg) 29 | val_loader = get_val_dataloader(cfg) 30 | model = build_model(cfg).cuda() 31 | 32 | 33 | #print(model) 34 | if os.path.exists(cfg.train.load_from): 35 | print('load from: ', cfg.train.load_from) 36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | model.load_state_dict(torch.load(cfg.train.load_from,map_location=device),strict=False) 38 | 39 | if cfg.train.milestones_in_epo: 40 | ns = len(train_loader) 41 | milestones = [] 42 | for m in cfg.train.milestones: 43 | milestones.append(m * ns) 44 | cfg.train.milestones = milestones 45 | 46 | optimizer = torch.optim.Adam(params=model.parameters(),lr=cfg.train.learning_rate,weight_decay=cfg.train.weight_decay) 47 | 48 | if cfg.train.use_step_lr_policy: 49 | 50 | lr_scheduler = WarmupMultiStepLR( 51 | optimizer, 52 | milestones= cfg.train.milestones, 53 | gamma = cfg.train.lr_decay_gamma, 54 | warmup_iters=cfg.train.warmup_steps, 55 | ) 56 | else: ## similiar with in the paper 57 | warmup_steps = 5 * len(train_loader) ## 5 epoch warmup 58 | min_lr_scale = 0.0001 59 | start_step = 70 * len(train_loader) 60 | end_step = 150 * len(train_loader) 61 | n_t = 0.5 62 | lr_lambda_fn = lambda step: (0.9 * step / warmup_steps + 0.1) if step < warmup_steps else \ 63 | 1.0 if step < start_step else \ 64 | min_lr_scale if \ 65 | n_t * (1 + math.cos(math.pi * (step - start_step) / (end_step - start_step))) < min_lr_scale else \ 66 | n_t * (1 + math.cos(math.pi * (step - start_step) / (end_step - start_step))) 67 | 68 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fn) 69 | 70 | create_dir(cfg.train.save_dir) 71 | logger = TxtLogger(cfg.train.save_dir + "/train_logger.txt") 72 | learner = Simple_MLSD_Learner( 73 | cfg, 74 | model = model, 75 | optimizer = optimizer, 76 | scheduler = lr_scheduler, 77 | logger = logger, 78 | save_dir = cfg.train.save_dir, 79 | log_steps = cfg.train.log_steps, 80 | device_ids = cfg.train.device_ids, 81 | gradient_accum_steps = 1, 82 | max_grad_norm = 1000.0, 83 | batch_to_model_inputs_fn = None, 84 | early_stop_n= cfg.train.early_stop_n) 85 | 86 | #learner.val(model, val_loader) 87 | #learner.val(model, train_loader) 88 | learner.train(train_loader, val_loader, epoches= cfg.train.num_train_epochs) 89 | 90 | 91 | if __name__ == '__main__': 92 | setup_seed(6666) 93 | cfg = get_cfg_defaults() 94 | args = get_args() 95 | 96 | if args.config.endswith('\r'): 97 | args.config = args.config[:-1] 98 | print('using config: ',args.config.strip()) 99 | cfg.merge_from_file(args.config) 100 | print(cfg) 101 | 102 | create_dir(cfg.train.save_dir) 103 | cfg_str = cfg.dump() 104 | with open(cfg.train.save_dir+ "/cfg.yaml", "w") as f: 105 | f.write(cfg_str) 106 | f.close() 107 | 108 | train(cfg) -------------------------------------------------------------------------------- /mlsd_pytorch/utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/mlsd_pytorch/utils/__init__.py -------------------------------------------------------------------------------- /mlsd_pytorch/utils/comm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import errno 4 | import random 5 | import torch 6 | 7 | def setup_seed(seed): 8 | random.seed(seed) 9 | os.environ['PYTHONHASHSEED'] = str(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | # some cudnn methods can be random even after fixing the seed 15 | # unless you tell it to be deterministic 16 | torch.backends.cudnn.deterministic = True 17 | 18 | 19 | def create_dir(path): 20 | if not os.path.exists(path): 21 | try: 22 | os.makedirs(path) 23 | except OSError as exc: 24 | if exc.errno != errno.EEXIST: 25 | raise 26 | 27 | -------------------------------------------------------------------------------- /mlsd_pytorch/utils/decode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def deccode_lines_TP(tpMap, score_thresh=0.1, len_thresh=2, topk_n=1000, ksize=3): 7 | ''' 8 | tpMap: 9 | center: tpMap[1, 0, :, :] 10 | displacement: tpMap[1, 1:5, :, :] 11 | ''' 12 | b, c, h, w = tpMap.shape 13 | assert b == 1, 'only support bsize==1' 14 | displacement = tpMap[:, 1:5, :, :] 15 | center = tpMap[:, 0, :, :] 16 | heat = torch.sigmoid(center) 17 | hmax = F.max_pool2d(heat, (ksize, ksize), stride=1, padding=(ksize - 1) // 2) 18 | keep = (hmax == heat).float() 19 | heat = heat * keep 20 | heat = heat.reshape(-1, ) 21 | 22 | scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) 23 | valid_inx = torch.where(scores > score_thresh) 24 | scores = scores[valid_inx] 25 | indices = indices[valid_inx] 26 | 27 | yy = torch.floor_divide(indices, w).unsqueeze(-1) 28 | xx = torch.fmod(indices, w).unsqueeze(-1) 29 | center_ptss = torch.cat((xx, yy), dim=-1) 30 | 31 | start_point = center_ptss + displacement[0, :2, yy, xx].reshape(2, -1).permute(1, 0) 32 | end_point = center_ptss + displacement[0, 2:, yy, xx].reshape(2, -1).permute(1, 0) 33 | 34 | lines = torch.cat((start_point, end_point), dim=-1) 35 | 36 | lines_swap = torch.cat((end_point, start_point), dim=-1) 37 | 38 | all_lens = (end_point - start_point) ** 2 39 | all_lens = all_lens.sum(dim=-1) 40 | all_lens = torch.sqrt(all_lens) 41 | valid_inx = torch.where(all_lens > len_thresh) 42 | 43 | center_ptss = center_ptss[valid_inx] 44 | lines = lines[valid_inx] 45 | lines_swap = lines_swap[valid_inx] 46 | scores = scores[valid_inx] 47 | 48 | return center_ptss, lines, lines_swap, scores 49 | -------------------------------------------------------------------------------- /mlsd_pytorch/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | class TxtLogger(object): 4 | def __init__(self, output_name): 5 | dirname = os.path.dirname(output_name) 6 | if not os.path.exists(dirname): 7 | os.mkdir(dirname) 8 | 9 | self.log_file = open(output_name, 'w') 10 | self.infos = {} 11 | 12 | def append(self, key, val): 13 | vals = self.infos.setdefault(key, []) 14 | vals.append(val) 15 | 16 | def log(self, extra_msg=''): 17 | msgs = [extra_msg] 18 | for key, vals in self.infos.iteritems(): 19 | msgs.append('%s %.6f' % (key, 
np.mean(vals))) 20 | msg = '\n'.join(msgs) 21 | self.log_file.write(msg + '\n') 22 | self.log_file.flush() 23 | self.infos = {} 24 | return msg 25 | 26 | def write(self, msg): 27 | self.log_file.write(msg + '\n') 28 | self.log_file.flush() 29 | print(msg) 30 | 31 | def close(self): 32 | self.log_file.close() -------------------------------------------------------------------------------- /mlsd_pytorch/utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /models/mbv2_mlsd_large.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.nn import functional as F 7 | 8 | 9 | class BlockTypeA(nn.Module): 10 | def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): 11 | super(BlockTypeA, self).__init__() 12 | self.conv1 = nn.Sequential( 13 | nn.Conv2d(in_c2, out_c2, kernel_size=1), 14 | nn.BatchNorm2d(out_c2), 15 | nn.ReLU(inplace=True) 16 | ) 17 | self.conv2 = nn.Sequential( 18 | nn.Conv2d(in_c1, out_c1, kernel_size=1), 19 | nn.BatchNorm2d(out_c1), 20 | nn.ReLU(inplace=True) 21 | ) 22 | self.upscale = upscale 23 | 24 | def forward(self, a, b): 25 | b = self.conv1(b) 26 | a = self.conv2(a) 27 | if self.upscale: 28 | b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) 29 | return torch.cat((a, b), dim=1) 30 | 31 | 32 | class BlockTypeB(nn.Module): 33 | def __init__(self, in_c, out_c): 34 | super(BlockTypeB, self).__init__() 35 | self.conv1 = nn.Sequential( 36 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(in_c), 38 | nn.ReLU() 39 | ) 40 | self.conv2 = nn.Sequential( 41 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 42 | nn.BatchNorm2d(out_c), 43 | nn.ReLU() 44 | ) 45 | 46 | def forward(self, x): 47 | x = self.conv1(x) + x 48 | x = self.conv2(x) 49 | return x 50 | 51 | class BlockTypeC(nn.Module): 52 | def __init__(self, in_c, out_c): 53 | super(BlockTypeC, self).__init__() 54 | self.conv1 = nn.Sequential( 55 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), 56 | nn.BatchNorm2d(in_c), 57 | nn.ReLU() 58 | ) 59 | self.conv2 = nn.Sequential( 60 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(in_c), 62 | nn.ReLU() 63 | ) 64 | self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) 65 | 66 | def forward(self, x): 67 | x = self.conv1(x) 68 | x = self.conv2(x) 69 | x = self.conv3(x) 70 | return x 71 | 72 | def _make_divisible(v, divisor, min_value=None): 73 | """ 74 | This function is taken from the original tf repo. 
75 | It ensures that all layers have a channel number that is divisible by 8 76 | It can be seen here: 77 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 78 | :param v: 79 | :param divisor: 80 | :param min_value: 81 | :return: 82 | """ 83 | if min_value is None: 84 | min_value = divisor 85 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 86 | # Make sure that round down does not go down by more than 10%. 87 | if new_v < 0.9 * v: 88 | new_v += divisor 89 | return new_v 90 | 91 | 92 | class ConvBNReLU(nn.Sequential): 93 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 94 | self.channel_pad = out_planes - in_planes 95 | self.stride = stride 96 | #padding = (kernel_size - 1) // 2 97 | 98 | # TFLite uses slightly different padding than PyTorch 99 | if stride == 2: 100 | padding = 0 101 | else: 102 | padding = (kernel_size - 1) // 2 103 | 104 | super(ConvBNReLU, self).__init__( 105 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 106 | nn.BatchNorm2d(out_planes), 107 | nn.ReLU6(inplace=True) 108 | ) 109 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 110 | 111 | 112 | def forward(self, x): 113 | # TFLite uses different padding 114 | if self.stride == 2: 115 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 116 | #print(x.shape) 117 | 118 | for module in self: 119 | if not isinstance(module, nn.MaxPool2d): 120 | x = module(x) 121 | return x 122 | 123 | 124 | class InvertedResidual(nn.Module): 125 | def __init__(self, inp, oup, stride, expand_ratio): 126 | super(InvertedResidual, self).__init__() 127 | self.stride = stride 128 | assert stride in [1, 2] 129 | 130 | hidden_dim = int(round(inp * expand_ratio)) 131 | self.use_res_connect = self.stride == 1 and inp == oup 132 | 133 | layers = [] 134 | if expand_ratio != 1: 135 | # pw 136 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 137 | layers.extend([ 138 | # dw 139 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 140 | # pw-linear 141 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 142 | nn.BatchNorm2d(oup), 143 | ]) 144 | self.conv = nn.Sequential(*layers) 145 | 146 | def forward(self, x): 147 | if self.use_res_connect: 148 | return x + self.conv(x) 149 | else: 150 | return self.conv(x) 151 | 152 | 153 | class MobileNetV2(nn.Module): 154 | def __init__(self, pretrained=True): 155 | """ 156 | MobileNet V2 main class 157 | Args: 158 | num_classes (int): Number of classes 159 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 160 | inverted_residual_setting: Network structure 161 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 162 | Set to 1 to turn off rounding 163 | block: Module specifying inverted residual building block for mobilenet 164 | """ 165 | super(MobileNetV2, self).__init__() 166 | 167 | block = InvertedResidual 168 | input_channel = 32 169 | last_channel = 1280 170 | width_mult = 1.0 171 | round_nearest = 8 172 | 173 | inverted_residual_setting = [ 174 | # t, c, n, s 175 | [1, 16, 1, 1], 176 | [6, 24, 2, 2], 177 | [6, 32, 3, 2], 178 | [6, 64, 4, 2], 179 | [6, 96, 3, 1], 180 | #[6, 160, 3, 2], 181 | #[6, 320, 1, 1], 182 | ] 183 | 184 | # only check the first element, assuming user knows t,c,n,s are required 185 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 186 | raise ValueError("inverted_residual_setting should be 
non-empty " 187 | "or a 4-element list, got {}".format(inverted_residual_setting)) 188 | 189 | # building first layer 190 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 191 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 192 | features = [ConvBNReLU(4, input_channel, stride=2)] 193 | # building inverted residual blocks 194 | for t, c, n, s in inverted_residual_setting: 195 | output_channel = _make_divisible(c * width_mult, round_nearest) 196 | for i in range(n): 197 | stride = s if i == 0 else 1 198 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 199 | input_channel = output_channel 200 | 201 | self.features = nn.Sequential(*features) 202 | self.fpn_selected = [1, 3, 6, 10, 13] 203 | # weight initialization 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv2d): 206 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 207 | if m.bias is not None: 208 | nn.init.zeros_(m.bias) 209 | elif isinstance(m, nn.BatchNorm2d): 210 | nn.init.ones_(m.weight) 211 | nn.init.zeros_(m.bias) 212 | elif isinstance(m, nn.Linear): 213 | nn.init.normal_(m.weight, 0, 0.01) 214 | nn.init.zeros_(m.bias) 215 | if pretrained: 216 | self._load_pretrained_model() 217 | 218 | def _forward_impl(self, x): 219 | # This exists since TorchScript doesn't support inheritance, so the superclass method 220 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 221 | fpn_features = [] 222 | for i, f in enumerate(self.features): 223 | if i > self.fpn_selected[-1]: 224 | break 225 | x = f(x) 226 | if i in self.fpn_selected: 227 | fpn_features.append(x) 228 | 229 | c1, c2, c3, c4, c5 = fpn_features 230 | return c1, c2, c3, c4, c5 231 | 232 | 233 | def forward(self, x): 234 | return self._forward_impl(x) 235 | 236 | def _load_pretrained_model(self): 237 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') 238 | model_dict = {} 239 | state_dict = self.state_dict() 240 | for k, v in pretrain_dict.items(): 241 | if k in state_dict: 242 | model_dict[k] = v 243 | state_dict.update(model_dict) 244 | self.load_state_dict(state_dict) 245 | 246 | 247 | class MobileV2_MLSD_Large(nn.Module): 248 | def __init__(self): 249 | super(MobileV2_MLSD_Large, self).__init__() 250 | 251 | self.backbone = MobileNetV2(pretrained=False) 252 | ## A, B 253 | self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, 254 | out_c1= 64, out_c2=64, 255 | upscale=False) 256 | self.block16 = BlockTypeB(128, 64) 257 | 258 | ## A, B 259 | self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, 260 | out_c1= 64, out_c2= 64) 261 | self.block18 = BlockTypeB(128, 64) 262 | 263 | ## A, B 264 | self.block19 = BlockTypeA(in_c1=24, in_c2=64, 265 | out_c1=64, out_c2=64) 266 | self.block20 = BlockTypeB(128, 64) 267 | 268 | ## A, B, C 269 | self.block21 = BlockTypeA(in_c1=16, in_c2=64, 270 | out_c1=64, out_c2=64) 271 | self.block22 = BlockTypeB(128, 64) 272 | 273 | self.block23 = BlockTypeC(64, 16) 274 | 275 | def forward(self, x): 276 | c1, c2, c3, c4, c5 = self.backbone(x) 277 | 278 | x = self.block15(c4, c5) 279 | x = self.block16(x) 280 | 281 | x = self.block17(c3, x) 282 | x = self.block18(x) 283 | 284 | x = self.block19(c2, x) 285 | x = self.block20(x) 286 | 287 | x = self.block21(c1, x) 288 | x = self.block22(x) 289 | x = self.block23(x) 290 | x = x[:, 7:, :, :] 291 | 292 | return x -------------------------------------------------------------------------------- /models/mbv2_mlsd_tiny.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.nn import functional as F 7 | 8 | 9 | class BlockTypeA(nn.Module): 10 | def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): 11 | super(BlockTypeA, self).__init__() 12 | self.conv1 = nn.Sequential( 13 | nn.Conv2d(in_c2, out_c2, kernel_size=1), 14 | nn.BatchNorm2d(out_c2), 15 | nn.ReLU(inplace=True) 16 | ) 17 | self.conv2 = nn.Sequential( 18 | nn.Conv2d(in_c1, out_c1, kernel_size=1), 19 | nn.BatchNorm2d(out_c1), 20 | nn.ReLU(inplace=True) 21 | ) 22 | self.upscale = upscale 23 | 24 | def forward(self, a, b): 25 | b = self.conv1(b) 26 | a = self.conv2(a) 27 | b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) 28 | return torch.cat((a, b), dim=1) 29 | 30 | 31 | class BlockTypeB(nn.Module): 32 | def __init__(self, in_c, out_c): 33 | super(BlockTypeB, self).__init__() 34 | self.conv1 = nn.Sequential( 35 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(in_c), 37 | nn.ReLU() 38 | ) 39 | self.conv2 = nn.Sequential( 40 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 41 | nn.BatchNorm2d(out_c), 42 | nn.ReLU() 43 | ) 44 | 45 | def forward(self, x): 46 | x = self.conv1(x) + x 47 | x = self.conv2(x) 48 | return x 49 | 50 | class BlockTypeC(nn.Module): 51 | def __init__(self, in_c, out_c): 52 | super(BlockTypeC, self).__init__() 53 | self.conv1 = nn.Sequential( 54 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), 55 | nn.BatchNorm2d(in_c), 56 | nn.ReLU() 57 | ) 58 | self.conv2 = nn.Sequential( 59 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 60 | nn.BatchNorm2d(in_c), 61 | nn.ReLU() 62 | ) 63 | self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) 64 | 65 | def forward(self, x): 66 | x = self.conv1(x) 67 | x = self.conv2(x) 68 | x = self.conv3(x) 69 | return x 70 | 71 | def _make_divisible(v, divisor, min_value=None): 72 | """ 73 | This function is taken from the original tf repo. 74 | It ensures that all layers have a channel number that is divisible by 8 75 | It can be seen here: 76 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 77 | :param v: 78 | :param divisor: 79 | :param min_value: 80 | :return: 81 | """ 82 | if min_value is None: 83 | min_value = divisor 84 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 85 | # Make sure that round down does not go down by more than 10%. 
86 | if new_v < 0.9 * v: 87 | new_v += divisor 88 | return new_v 89 | 90 | 91 | class ConvBNReLU(nn.Sequential): 92 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 93 | self.channel_pad = out_planes - in_planes 94 | self.stride = stride 95 | #padding = (kernel_size - 1) // 2 96 | 97 | # TFLite uses slightly different padding than PyTorch 98 | if stride == 2: 99 | padding = 0 100 | else: 101 | padding = (kernel_size - 1) // 2 102 | 103 | super(ConvBNReLU, self).__init__( 104 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 105 | nn.BatchNorm2d(out_planes), 106 | nn.ReLU6(inplace=True) 107 | ) 108 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 109 | 110 | 111 | def forward(self, x): 112 | # TFLite uses different padding 113 | if self.stride == 2: 114 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 115 | #print(x.shape) 116 | 117 | for module in self: 118 | if not isinstance(module, nn.MaxPool2d): 119 | x = module(x) 120 | return x 121 | 122 | 123 | class InvertedResidual(nn.Module): 124 | def __init__(self, inp, oup, stride, expand_ratio): 125 | super(InvertedResidual, self).__init__() 126 | self.stride = stride 127 | assert stride in [1, 2] 128 | 129 | hidden_dim = int(round(inp * expand_ratio)) 130 | self.use_res_connect = self.stride == 1 and inp == oup 131 | 132 | layers = [] 133 | if expand_ratio != 1: 134 | # pw 135 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 136 | layers.extend([ 137 | # dw 138 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 139 | # pw-linear 140 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 141 | nn.BatchNorm2d(oup), 142 | ]) 143 | self.conv = nn.Sequential(*layers) 144 | 145 | def forward(self, x): 146 | if self.use_res_connect: 147 | return x + self.conv(x) 148 | else: 149 | return self.conv(x) 150 | 151 | 152 | class MobileNetV2(nn.Module): 153 | def __init__(self, pretrained=True): 154 | """ 155 | MobileNet V2 main class 156 | Args: 157 | num_classes (int): Number of classes 158 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 159 | inverted_residual_setting: Network structure 160 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 161 | Set to 1 to turn off rounding 162 | block: Module specifying inverted residual building block for mobilenet 163 | """ 164 | super(MobileNetV2, self).__init__() 165 | 166 | block = InvertedResidual 167 | input_channel = 32 168 | last_channel = 1280 169 | width_mult = 1.0 170 | round_nearest = 8 171 | 172 | inverted_residual_setting = [ 173 | # t, c, n, s 174 | [1, 16, 1, 1], 175 | [6, 24, 2, 2], 176 | [6, 32, 3, 2], 177 | [6, 64, 4, 2], 178 | #[6, 96, 3, 1], 179 | #[6, 160, 3, 2], 180 | #[6, 320, 1, 1], 181 | ] 182 | 183 | # only check the first element, assuming user knows t,c,n,s are required 184 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 185 | raise ValueError("inverted_residual_setting should be non-empty " 186 | "or a 4-element list, got {}".format(inverted_residual_setting)) 187 | 188 | # building first layer 189 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 190 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 191 | features = [ConvBNReLU(4, input_channel, stride=2)] 192 | # building inverted residual blocks 193 | for t, c, n, s in inverted_residual_setting: 194 | output_channel = 
_make_divisible(c * width_mult, round_nearest) 195 | for i in range(n): 196 | stride = s if i == 0 else 1 197 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 198 | input_channel = output_channel 199 | self.features = nn.Sequential(*features) 200 | 201 | self.fpn_selected = [3, 6, 10] 202 | # weight initialization 203 | for m in self.modules(): 204 | if isinstance(m, nn.Conv2d): 205 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 206 | if m.bias is not None: 207 | nn.init.zeros_(m.bias) 208 | elif isinstance(m, nn.BatchNorm2d): 209 | nn.init.ones_(m.weight) 210 | nn.init.zeros_(m.bias) 211 | elif isinstance(m, nn.Linear): 212 | nn.init.normal_(m.weight, 0, 0.01) 213 | nn.init.zeros_(m.bias) 214 | 215 | #if pretrained: 216 | # self._load_pretrained_model() 217 | 218 | def _forward_impl(self, x): 219 | # This exists since TorchScript doesn't support inheritance, so the superclass method 220 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 221 | fpn_features = [] 222 | for i, f in enumerate(self.features): 223 | if i > self.fpn_selected[-1]: 224 | break 225 | x = f(x) 226 | if i in self.fpn_selected: 227 | fpn_features.append(x) 228 | 229 | c2, c3, c4 = fpn_features 230 | return c2, c3, c4 231 | 232 | 233 | def forward(self, x): 234 | return self._forward_impl(x) 235 | 236 | def _load_pretrained_model(self): 237 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') 238 | model_dict = {} 239 | state_dict = self.state_dict() 240 | for k, v in pretrain_dict.items(): 241 | if k in state_dict: 242 | model_dict[k] = v 243 | state_dict.update(model_dict) 244 | self.load_state_dict(state_dict) 245 | 246 | 247 | class MobileV2_MLSD_Tiny(nn.Module): 248 | def __init__(self): 249 | super(MobileV2_MLSD_Tiny, self).__init__() 250 | 251 | self.backbone = MobileNetV2(pretrained=True) 252 | 253 | self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, 254 | out_c1= 64, out_c2=64) 255 | self.block13 = BlockTypeB(128, 64) 256 | 257 | self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, 258 | out_c1= 32, out_c2= 32) 259 | self.block15 = BlockTypeB(64, 64) 260 | 261 | self.block16 = BlockTypeC(64, 16) 262 | 263 | def forward(self, x): 264 | c2, c3, c4 = self.backbone(x) 265 | 266 | x = self.block12(c3, c4) 267 | x = self.block13(x) 268 | x = self.block14(c2, x) 269 | x = self.block15(x) 270 | x = self.block16(x) 271 | x = x[:, 7:, :, :] 272 | #print(x.shape) 273 | x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 274 | 275 | return x -------------------------------------------------------------------------------- /models/mlsd_large_512_fp32.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/models/mlsd_large_512_fp32.pth -------------------------------------------------------------------------------- /models/mlsd_tiny_512_fp32.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/models/mlsd_tiny_512_fp32.pth -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | pillow 4 | torch 5 | Flask 6 | requests 7 | 
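A minimal usage sketch (not in the repository) for the inference-time MobileV2_MLSD_Tiny defined in models/mbv2_mlsd_tiny.py above, loading the released checkpoint models/mlsd_tiny_512_fp32.pth listed just before requirements.txt. Assumptions are flagged in the comments: the 4-channel input (RGB plus a constant channel, mirroring the TFLite preprocessing) follows from the first ConvBNReLU(4, ...) layer, and the 9-channel half-resolution output follows from the final channel slice and 2x interpolation.

import torch
from models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MobileV2_MLSD_Tiny().to(device).eval()
# assumption: the released fp32 weights match this architecture exactly
model.load_state_dict(torch.load('models/mlsd_tiny_512_fp32.pth', map_location=device), strict=True)

x = torch.randn(1, 4, 512, 512, device=device)  # RGB + one extra ones channel, as in the TFLite pipeline
with torch.no_grad():
    tp_map = model(x)
print(tp_map.shape)  # expected torch.Size([1, 9, 256, 256]); channel 0 is the center map, 1-4 the displacements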
-------------------------------------------------------------------------------- /static/css/app.css: -------------------------------------------------------------------------------- 1 | #app { 2 | padding: 20px; 3 | } 4 | 5 | #result .item { 6 | padding-bottom: 20px; 7 | } 8 | 9 | .form-content-container { 10 | padding-left: 20px; 11 | } 12 | -------------------------------------------------------------------------------- /static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/static/favicon.ico -------------------------------------------------------------------------------- /templates/index_scan.html: -------------------------------------------------------------------------------- [HTML markup lost during extraction; only the visible text survives: page title and heading "MLSD demo", a form with image_url and image_data fields, and Output_image / Input_image result panels]
-------------------------------------------------------------------------------- /trt_converter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | import torch 8 | try: 9 | from torch2trt import torch2trt 10 | except ImportError: 11 | raise ImportError('Please ensure that torch2trt is installed!') 12 | 13 | import tensorrt as trt 14 | 15 | from torch.nn import functional as F 16 | 17 | from models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny 18 | from models.mbv2_mlsd_large import MobileV2_MLSD_Large 19 | 20 | from calibrator import ImageFolderCalibDataset 21 | 22 | from argparse import ArgumentParser, SUPPRESS 23 | 24 | def build_argparser(): 25 | parser = ArgumentParser(add_help=False) 26 | args = parser.add_argument_group('Options') 27 | args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.') 28 | args.add_argument("-m", "--model", help="model type", choices=['tiny', 'large'], default='tiny') 29 | args.add_argument("-e", "--engine", help="converted engine path", type=str, default=None) 30 | args.add_argument("-c", "--conversion", help="Conversion type", choices=['fp16', 'int8', 'onnx'], default='fp16') 31 | args.add_argument("-cd", "--calibration_data", help="Path to int8 calibration data", type=str, default='') 32 | args.add_argument("-cb", "--calibration_batch", help="Calibration batch size", type=int, default=32) 33 | args.add_argument("-s", "--serialize", help="Serialize trt engine to disk", action="store_true") 34 | args.add_argument("-b", "--bench", help="Toggle simple inference cost analysis", action="store_true") 35 | 36 | return parser 37 | 38 | def onnx_convert(dummy_input, model, model_path, opset=11, device='cpu'): 39 | print('converting to onnx...') 40 | 41 | out = f"{os.path.splitext(model_path)[0]}.onnx" 42 | 43 | model = model.to(device).eval() 44 | dummy_input = dummy_input.to(device) 45 | 46 | 47 | torch.onnx.export(model, 48 | dummy_input, 49 | out, 50 | verbose=True, 51 | opset_version=opset 52 | ) 53 | 54 | print(f'converted successfully at: {out}') 55 | 56 | 57 | def main(model_type='tiny', 58 | conversion='fp16', 59 | engine_path='', 60 | serialize=False, 61 | calibration_data='', 62 | calibration_batch=32, 63 | bench=False 64 | ): 65 | 66 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 67 | 68 | model_path = f'./models/mlsd_{model_type}_512_fp32.pth' 69 | 70 | model = (MobileV2_MLSD_Tiny() if model_type == 'tiny' else MobileV2_MLSD_Large()).to(device).eval() 71 | model.load_state_dict(torch.load(model_path, map_location=device), strict=True) 72 | 73 | 74 | dummy_input = torch.randn(1, 4, 512, 512).float().to(device) 75 | 76 | if conversion == 'fp16': 77 | 78 | out_path = f'./models/mlsd_{model_type}_512_trt_{conversion}.pth' 79 | 80 | 81 | model = torch2trt(model, 82 | [dummy_input], 83 | fp16_mode=True, 84 | log_level=trt.Logger.INFO, 85 | max_workspace_size=1 << 28 86 | ) 87 | 88 | print(f'\nsaving model to {out_path}\n') 89 | torch.save(model.state_dict(), out_path) 90 | 91 | 92 | elif conversion == 'int8': 93 | 94 | out_path = f'./models/mlsd_{model_type}_512_trt_{conversion}.pth' 95 | 96 | assert os.path.exists(calibration_data), 'Calibration path does not exist!'
97 | 98 | dataset = ImageFolderCalibDataset(calibration_data) 99 | model = torch2trt(model, 100 | [dummy_input], 101 | int8_mode=True, 102 | fp16_mode=True, 103 | int8_calib_dataset=dataset, 104 | int8_calib_batch_size=calibration_batch, 105 | log_level=trt.Logger.INFO, 106 | max_workspace_size=1 << 28 107 | ) 108 | 109 | print(f'\nsaving model to {out_path}\n') 110 | torch.save(model.state_dict(), out_path) 111 | 112 | elif conversion == 'onnx': 113 | onnx_convert(dummy_input, model, model_path) 114 | 115 | 116 | if serialize and conversion != 'onnx': 117 | 118 | print(f'\nsaving serialized engine to: {engine_path}\n') 119 | 120 | with open(engine_path, "wb") as f: 121 | f.write(model.engine.serialize()) 122 | 123 | if bench and conversion != 'onnx': 124 | 125 | print('Benchmarking after warmup...\n') 126 | 127 | for i in range(500): 128 | output = model(dummy_input) 129 | 130 | torch.cuda.current_stream().synchronize() 131 | t0 = time.monotonic() 132 | for i in range(100): 133 | output = model(dummy_input) 134 | 135 | it0 = time.monotonic() 136 | output = model(dummy_input) 137 | it1 = time.monotonic() 138 | 139 | torch.cuda.current_stream().synchronize() 140 | t1 = time.monotonic() 141 | fps = 100.0 / (t1 - t0) 142 | 143 | print(f'FPS: {fps:.2f}') 144 | print(f'Inference cost: {(it1-it0)*1000:.2f} ms') 145 | 146 | if __name__ == '__main__': 147 | 148 | args = build_argparser().parse_args() 149 | 150 | main(args.model, 151 | args.conversion, 152 | args.engine, 153 | args.serialize, 154 | args.calibration_data, 155 | args.calibration_batch, 156 | args.bench 157 | ) 158 | 159 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | modified by lihaoweicv 3 | pytorch version 4 | ''' 5 | 6 | ''' 7 | M-LSD 8 | Copyright 2021-present NAVER Corp.
9 | Apache License v2.0 10 | ''' 11 | 12 | import os 13 | import numpy as np 14 | import cv2 15 | import torch 16 | from torch.nn import functional as F 17 | 18 | 19 | def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): 20 | ''' 21 | tpMap: 22 | center: tpMap[1, 0, :, :] 23 | displacement: tpMap[1, 1:5, :, :] 24 | ''' 25 | b, c, h, w = tpMap.shape 26 | assert b==1, 'only support bsize==1' 27 | displacement = tpMap[:, 1:5, :, :][0] 28 | center = tpMap[:, 0, :, :] 29 | heat = torch.sigmoid(center) 30 | hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) 31 | keep = (hmax == heat).float() 32 | heat = heat * keep 33 | heat = heat.reshape(-1, ) 34 | 35 | scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) 36 | yy = torch.floor_divide(indices, w).unsqueeze(-1) 37 | xx = torch.fmod(indices, w).unsqueeze(-1) 38 | ptss = torch.cat((yy, xx),dim=-1) 39 | 40 | ptss = ptss.detach().cpu().numpy() 41 | scores = scores.detach().cpu().numpy() 42 | displacement = displacement.detach().cpu().numpy() 43 | displacement = displacement.transpose((1,2,0)) 44 | return ptss, scores, displacement 45 | 46 | 47 | def pred_lines(image, model, 48 | input_shape=[512, 512], 49 | score_thr=0.10, 50 | dist_thr=20.0): 51 | h, w, _ = image.shape 52 | h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] 53 | 54 | resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), 55 | np.ones([input_shape[0], input_shape[1], 1])], axis=-1) 56 | 57 | resized_image = resized_image.transpose((2,0,1)) 58 | batch_image = np.expand_dims(resized_image, axis=0).astype('float32') 59 | batch_image = (batch_image / 127.5) - 1.0 60 | 61 | batch_image = torch.from_numpy(batch_image).float().cuda() 62 | outputs = model(batch_image) 63 | pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) 64 | start = vmap[:, :, :2] 65 | end = vmap[:, :, 2:] 66 | dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) 67 | 68 | segments_list = [] 69 | for center, score in zip(pts, pts_score): 70 | y, x = center 71 | distance = dist_map[y, x] 72 | if score > score_thr and distance > dist_thr: 73 | disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] 74 | x_start = x + disp_x_start 75 | y_start = y + disp_y_start 76 | x_end = x + disp_x_end 77 | y_end = y + disp_y_end 78 | segments_list.append([x_start, y_start, x_end, y_end]) 79 | 80 | lines = 2 * np.array(segments_list) # 256 > 512 81 | lines[:, 0] = lines[:, 0] * w_ratio 82 | lines[:, 1] = lines[:, 1] * h_ratio 83 | lines[:, 2] = lines[:, 2] * w_ratio 84 | lines[:, 3] = lines[:, 3] * h_ratio 85 | 86 | return lines 87 | 88 | 89 | def pred_squares(image, 90 | model, 91 | input_shape=[512, 512], 92 | params={'score': 0.06, 93 | 'outside_ratio': 0.28, 94 | 'inside_ratio': 0.45, 95 | 'w_overlap': 0.0, 96 | 'w_degree': 1.95, 97 | 'w_length': 0.0, 98 | 'w_area': 1.86, 99 | 'w_center': 0.14}): 100 | ''' 101 | shape = [height, width] 102 | ''' 103 | h, w, _ = image.shape 104 | original_shape = [h, w] 105 | 106 | resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), 107 | np.ones([input_shape[0], input_shape[1], 1])], axis=-1) 108 | resized_image = resized_image.transpose((2, 0, 1)) 109 | batch_image = np.expand_dims(resized_image, axis=0).astype('float32') 110 | batch_image = (batch_image / 127.5) - 1.0 111 | 112 | batch_image = torch.from_numpy(batch_image).float().cuda() 113 | outputs = model(batch_image) 114 | 115 | 
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) 116 | start = vmap[:, :, :2] # (x, y) 117 | end = vmap[:, :, 2:] # (x, y) 118 | dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) 119 | 120 | junc_list = [] 121 | segments_list = [] 122 | for junc, score in zip(pts, pts_score): 123 | y, x = junc 124 | distance = dist_map[y, x] 125 | if score > params['score'] and distance > 20.0: 126 | junc_list.append([x, y]) 127 | disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] 128 | d_arrow = 1.0 129 | x_start = x + d_arrow * disp_x_start 130 | y_start = y + d_arrow * disp_y_start 131 | x_end = x + d_arrow * disp_x_end 132 | y_end = y + d_arrow * disp_y_end 133 | segments_list.append([x_start, y_start, x_end, y_end]) 134 | 135 | segments = np.array(segments_list) 136 | 137 | ####### post processing for squares 138 | # 1. get unique lines 139 | point = np.array([[0, 0]]) 140 | point = point[0] 141 | start = segments[:, :2] 142 | end = segments[:, 2:] 143 | diff = start - end 144 | a = diff[:, 1] 145 | b = -diff[:, 0] 146 | c = a * start[:, 0] + b * start[:, 1] 147 | 148 | d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) 149 | theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi 150 | theta[theta < 0.0] += 180 151 | hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) 152 | 153 | d_quant = 1 154 | theta_quant = 2 155 | hough[:, 0] //= d_quant 156 | hough[:, 1] //= theta_quant 157 | _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) 158 | 159 | acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') 160 | idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 161 | yx_indices = hough[indices, :].astype('int32') 162 | acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts 163 | idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices 164 | 165 | acc_map_np = acc_map 166 | # acc_map = acc_map[None, :, :, None] 167 | # 168 | # ### fast suppression using tensorflow op 169 | # acc_map = tf.constant(acc_map, dtype=tf.float32) 170 | # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) 171 | # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) 172 | # flatten_acc_map = tf.reshape(acc_map, [1, -1]) 173 | # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) 174 | # _, h, w, _ = acc_map.shape 175 | # y = tf.expand_dims(topk_indices // w, axis=-1) 176 | # x = tf.expand_dims(topk_indices % w, axis=-1) 177 | # yx = tf.concat([y, x], axis=-1) 178 | 179 | ### fast suppression using pytorch op 180 | acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) 181 | _,_, h, w = acc_map.shape 182 | max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) 183 | acc_map = acc_map * ( (acc_map == max_acc_map).float() ) 184 | flatten_acc_map = acc_map.reshape([-1, ]) 185 | 186 | scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) 187 | yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) 188 | xx = torch.fmod(indices, w).unsqueeze(-1) 189 | yx = torch.cat((yy, xx), dim=-1) 190 | 191 | yx = yx.detach().cpu().numpy() 192 | 193 | topk_values = scores.detach().cpu().numpy() 194 | indices = idx_map[yx[:, 0], yx[:, 1]] 195 | basis = 5 // 2 196 | 197 | merged_segments = [] 198 | for yx_pt, max_indice, value in zip(yx, indices, topk_values): 199 | y, x = yx_pt 200 | if max_indice == -1 or value == 0: 201 | continue 202 | 
segment_list = [] 203 | for y_offset in range(-basis, basis + 1): 204 | for x_offset in range(-basis, basis + 1): 205 | indice = idx_map[y + y_offset, x + x_offset] 206 | cnt = int(acc_map_np[y + y_offset, x + x_offset]) 207 | if indice != -1: 208 | segment_list.append(segments[indice]) 209 | if cnt > 1: 210 | check_cnt = 1 211 | current_hough = hough[indice] 212 | for new_indice, new_hough in enumerate(hough): 213 | if (current_hough == new_hough).all() and indice != new_indice: 214 | segment_list.append(segments[new_indice]) 215 | check_cnt += 1 216 | if check_cnt == cnt: 217 | break 218 | group_segments = np.array(segment_list).reshape([-1, 2]) 219 | sorted_group_segments = np.sort(group_segments, axis=0) 220 | x_min, y_min = sorted_group_segments[0, :] 221 | x_max, y_max = sorted_group_segments[-1, :] 222 | 223 | deg = theta[max_indice] 224 | if deg >= 90: 225 | merged_segments.append([x_min, y_max, x_max, y_min]) 226 | else: 227 | merged_segments.append([x_min, y_min, x_max, y_max]) 228 | 229 | # 2. get intersections 230 | new_segments = np.array(merged_segments) # (x1, y1, x2, y2) 231 | start = new_segments[:, :2] # (x1, y1) 232 | end = new_segments[:, 2:] # (x2, y2) 233 | new_centers = (start + end) / 2.0 234 | diff = start - end 235 | dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) 236 | 237 | # ax + by = c 238 | a = diff[:, 1] 239 | b = -diff[:, 0] 240 | c = a * start[:, 0] + b * start[:, 1] 241 | pre_det = a[:, None] * b[None, :] 242 | det = pre_det - np.transpose(pre_det) 243 | 244 | pre_inter_y = a[:, None] * c[None, :] 245 | inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) 246 | pre_inter_x = c[:, None] * b[None, :] 247 | inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) 248 | inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') 249 | 250 | # 3. get corner information 251 | # 3.1 get distance 252 | ''' 253 | dist_segments: 254 | | dist(0), dist(1), dist(2), ...| 255 | dist_inter_to_segment1: 256 | | dist(inter,0), dist(inter,0), dist(inter,0), ... | 257 | | dist(inter,1), dist(inter,1), dist(inter,1), ... | 258 | ... 259 | dist_inter_to_semgnet2: 260 | | dist(inter,0), dist(inter,1), dist(inter,2), ... | 261 | | dist(inter,0), dist(inter,1), dist(inter,2), ... | 262 | ... 
263 | ''' 264 | 265 | dist_inter_to_segment1_start = np.sqrt( 266 | np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] 267 | dist_inter_to_segment1_end = np.sqrt( 268 | np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] 269 | dist_inter_to_segment2_start = np.sqrt( 270 | np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] 271 | dist_inter_to_segment2_end = np.sqrt( 272 | np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] 273 | 274 | # sort ascending 275 | dist_inter_to_segment1 = np.sort( 276 | np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), 277 | axis=-1) # [n_batch, n_batch, 2] 278 | dist_inter_to_segment2 = np.sort( 279 | np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), 280 | axis=-1) # [n_batch, n_batch, 2] 281 | 282 | # 3.2 get degree 283 | inter_to_start = new_centers[:, None, :] - inter_pts 284 | deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi 285 | deg_inter_to_start[deg_inter_to_start < 0.0] += 360 286 | inter_to_end = new_centers[None, :, :] - inter_pts 287 | deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi 288 | deg_inter_to_end[deg_inter_to_end < 0.0] += 360 289 | 290 | ''' 291 | B -- G 292 | | | 293 | C -- R 294 | B : blue / G: green / C: cyan / R: red 295 | 296 | 0 -- 1 297 | | | 298 | 3 -- 2 299 | ''' 300 | # rename variables 301 | deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end 302 | # sort deg ascending 303 | deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) 304 | 305 | deg_diff_map = np.abs(deg1_map - deg2_map) 306 | # we only consider the smallest degree of intersect 307 | deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] 308 | 309 | # define available degree range 310 | deg_range = [60, 120] 311 | 312 | corner_dict = {corner_info: [] for corner_info in range(4)} 313 | inter_points = [] 314 | for i in range(inter_pts.shape[0]): 315 | for j in range(i + 1, inter_pts.shape[1]): 316 | # i, j > line index, always i < j 317 | x, y = inter_pts[i, j, :] 318 | deg1, deg2 = deg_sort[i, j, :] 319 | deg_diff = deg_diff_map[i, j] 320 | 321 | check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] 322 | 323 | outside_ratio = params['outside_ratio'] # over ratio >>> drop it! 324 | inside_ratio = params['inside_ratio'] # over ratio >>> drop it! 
325 | check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ 326 | dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ 327 | (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ 328 | dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ 329 | ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ 330 | dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ 331 | (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ 332 | dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) 333 | 334 | if check_degree and check_distance: 335 | corner_info = None 336 | 337 | if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ 338 | (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): 339 | corner_info, color_info = 0, 'blue' 340 | elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): 341 | corner_info, color_info = 1, 'green' 342 | elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): 343 | corner_info, color_info = 2, 'black' 344 | elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ 345 | (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): 346 | corner_info, color_info = 3, 'cyan' 347 | else: 348 | corner_info, color_info = 4, 'red' # we don't use it 349 | continue 350 | 351 | corner_dict[corner_info].append([x, y, i, j]) 352 | inter_points.append([x, y]) 353 | 354 | square_list = [] 355 | connect_list = [] 356 | segments_list = [] 357 | for corner0 in corner_dict[0]: 358 | for corner1 in corner_dict[1]: 359 | connect01 = False 360 | for corner0_line in corner0[2:]: 361 | if corner0_line in corner1[2:]: 362 | connect01 = True 363 | break 364 | if connect01: 365 | for corner2 in corner_dict[2]: 366 | connect12 = False 367 | for corner1_line in corner1[2:]: 368 | if corner1_line in corner2[2:]: 369 | connect12 = True 370 | break 371 | if connect12: 372 | for corner3 in corner_dict[3]: 373 | connect23 = False 374 | for corner2_line in corner2[2:]: 375 | if corner2_line in corner3[2:]: 376 | connect23 = True 377 | break 378 | if connect23: 379 | for corner3_line in corner3[2:]: 380 | if corner3_line in corner0[2:]: 381 | # SQUARE!!! 382 | ''' 383 | 0 -- 1 384 | | | 385 | 3 -- 2 386 | square_list: 387 | order: 0 > 1 > 2 > 3 388 | | x0, y0, x1, y1, x2, y2, x3, y3 | 389 | | x0, y0, x1, y1, x2, y2, x3, y3 | 390 | ... 391 | connect_list: 392 | order: 01 > 12 > 23 > 30 393 | | line_idx01, line_idx12, line_idx23, line_idx30 | 394 | | line_idx01, line_idx12, line_idx23, line_idx30 | 395 | ... 396 | segments_list: 397 | order: 0 > 1 > 2 > 3 398 | | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | 399 | | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | 400 | ... 
401 | ''' 402 | square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) 403 | connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) 404 | segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) 405 | 406 | def check_outside_inside(segments_info, connect_idx): 407 | # return 'outside or inside', min distance, cover_param, peri_param 408 | if connect_idx == segments_info[0]: 409 | check_dist_mat = dist_inter_to_segment1 410 | else: 411 | check_dist_mat = dist_inter_to_segment2 412 | 413 | i, j = segments_info 414 | min_dist, max_dist = check_dist_mat[i, j, :] 415 | connect_dist = dist_segments[connect_idx] 416 | if max_dist > connect_dist: 417 | return 'outside', min_dist, 0, 1 418 | else: 419 | return 'inside', min_dist, -1, -1 420 | 421 | top_square = None 422 | 423 | try: 424 | map_size = input_shape[0] / 2 425 | squares = np.array(square_list).reshape([-1, 4, 2]) 426 | score_array = [] 427 | connect_array = np.array(connect_list) 428 | segments_array = np.array(segments_list).reshape([-1, 4, 2]) 429 | 430 | # get degree of corners: 431 | squares_rollup = np.roll(squares, 1, axis=1) 432 | squares_rolldown = np.roll(squares, -1, axis=1) 433 | vec1 = squares_rollup - squares 434 | normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) 435 | vec2 = squares_rolldown - squares 436 | normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) 437 | inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] 438 | squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] 439 | 440 | # get square score 441 | overlap_scores = [] 442 | degree_scores = [] 443 | length_scores = [] 444 | 445 | for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): 446 | ''' 447 | 0 -- 1 448 | | | 449 | 3 -- 2 450 | 451 | # segments: [4, 2] 452 | # connects: [4] 453 | ''' 454 | 455 | ###################################### OVERLAP SCORES 456 | cover = 0 457 | perimeter = 0 458 | # check 0 > 1 > 2 > 3 459 | square_length = [] 460 | 461 | for start_idx in range(4): 462 | end_idx = (start_idx + 1) % 4 463 | 464 | connect_idx = connects[start_idx] # segment idx of segment01 465 | start_segments = segments[start_idx] 466 | end_segments = segments[end_idx] 467 | 468 | start_point = square[start_idx] 469 | end_point = square[end_idx] 470 | 471 | # check whether outside or inside 472 | start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, 473 | connect_idx) 474 | end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) 475 | 476 | cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min 477 | perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min 478 | 479 | square_length.append( 480 | dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) 481 | 482 | overlap_scores.append(cover / perimeter) 483 | ###################################### 484 | ###################################### DEGREE SCORES 485 | ''' 486 | deg0 vs deg2 487 | deg1 vs deg3 488 | ''' 489 | deg0, deg1, deg2, deg3 = degree 490 | deg_ratio1 = deg0 / deg2 491 | if deg_ratio1 > 1.0: 492 | deg_ratio1 = 1 / deg_ratio1 493 | deg_ratio2 = deg1 / deg3 494 | if deg_ratio2 > 1.0: 495 | deg_ratio2 = 1 / deg_ratio2 496 | degree_scores.append((deg_ratio1 + deg_ratio2) / 2) 497 | 
###################################### 498 | ###################################### LENGTH SCORES 499 | ''' 500 | len0 vs len2 501 | len1 vs len3 502 | ''' 503 | len0, len1, len2, len3 = square_length 504 | len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 505 | len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 506 | length_scores.append((len_ratio1 + len_ratio2) / 2) 507 | 508 | ###################################### 509 | 510 | overlap_scores = np.array(overlap_scores) 511 | overlap_scores /= np.max(overlap_scores) 512 | 513 | degree_scores = np.array(degree_scores) 514 | # degree_scores /= np.max(degree_scores) 515 | 516 | length_scores = np.array(length_scores) 517 | 518 | ###################################### AREA SCORES 519 | area_scores = np.reshape(squares, [-1, 4, 2]) 520 | area_x = area_scores[:, :, 0] 521 | area_y = area_scores[:, :, 1] 522 | correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] 523 | area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) 524 | area_scores = 0.5 * np.abs(area_scores + correction) 525 | area_scores /= (map_size * map_size) # np.max(area_scores) 526 | ###################################### 527 | 528 | ###################################### CENTER SCORES 529 | centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] 530 | # squares: [n, 4, 2] 531 | square_centers = np.mean(squares, axis=1) # [n, 2] 532 | center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) 533 | center_scores = center2center / (map_size / np.sqrt(2.0)) 534 | 535 | ''' 536 | score_w = [overlap, degree, area, center, length] 537 | ''' 538 | score_w = [0.0, 1.0, 10.0, 0.5, 1.0] 539 | score_array = params['w_overlap'] * overlap_scores \ 540 | + params['w_degree'] * degree_scores \ 541 | + params['w_area'] * area_scores \ 542 | - params['w_center'] * center_scores \ 543 | + params['w_length'] * length_scores 544 | 545 | best_square = [] 546 | 547 | sorted_idx = np.argsort(score_array)[::-1] 548 | score_array = score_array[sorted_idx] 549 | squares = squares[sorted_idx] 550 | 551 | except Exception as e: 552 | pass 553 | 554 | '''return list 555 | merged_lines, squares, scores 556 | ''' 557 | 558 | try: 559 | new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] 560 | new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] 561 | new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] 562 | new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] 563 | except: 564 | new_segments = [] 565 | 566 | try: 567 | squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] 568 | squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] 569 | except: 570 | squares = [] 571 | score_array = [] 572 | 573 | try: 574 | inter_points = np.array(inter_points) 575 | inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] 576 | inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] 577 | except: 578 | inter_points = [] 579 | 580 | return new_segments, squares, score_array, inter_points 581 | -------------------------------------------------------------------------------- /workdir/pretrained_models/mobilev2_mlsd_large_512_bsize24/best.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/workdir/pretrained_models/mobilev2_mlsd_large_512_bsize24/best.pth -------------------------------------------------------------------------------- /workdir/pretrained_models/mobilev2_mlsd_tiny_512_bsize24/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhwcv/mlsd_pytorch/ee38c9614ccf9f6af956c50963d593288cc4ae17/workdir/pretrained_models/mobilev2_mlsd_tiny_512_bsize24/best.pth --------------------------------------------------------------------------------
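Reloading a converted engine — a sketch that assumes the fp16 conversion above has already been run (python trt_converter.py -m tiny -c fp16), so that ./models/mlsd_tiny_512_trt_fp16.pth exists per the converter's out_path pattern; TRTModule is torch2trt's wrapper for such saved state dicts:

import torch
from torch2trt import TRTModule

# restore the TensorRT-backed module saved by trt_converter.py
model_trt = TRTModule()
model_trt.load_state_dict(torch.load('./models/mlsd_tiny_512_trt_fp16.pth'))

# the converter traced a 1x4x512x512 float input, so feed the same layout
dummy = torch.randn(1, 4, 512, 512).float().cuda()
with torch.no_grad():
    tp_map = model_trt(dummy)

# channel 0 of the output is the center map, channels 1:5 the displacements (see utils.py)
print(tp_map.shape)

The pred_lines / pred_squares helpers in utils.py should work with this module in place of the PyTorch model as well, since they only call model(batch_image) on a CUDA tensor.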