├── .gitignore ├── LEGAL.md ├── LICENSE ├── configs ├── hawpv2.yaml └── plnet.yaml ├── docker └── Dockerfile ├── docs ├── HAWPv2.md ├── HAWPv3.md ├── HAWPv3.train.md └── figures │ ├── dtu-24 │ ├── 000000.png │ ├── 000001.png │ ├── 000002.png │ ├── 000003.png │ ├── 000004.png │ ├── 000005.png │ ├── 000009.png │ ├── 000015.png │ └── 000045.png │ ├── plnet │ ├── plnet.png │ └── uma_feature.png │ ├── v3-BSDS │ ├── 37073.png │ ├── 42049.png │ └── 85048.png │ ├── v3-CrowdAI │ ├── 000000000190.png │ ├── 000000000210.png │ └── 000000000230.png │ └── v3-wireframe │ ├── 00037187.png │ ├── 00051510.png │ └── 00074259.png ├── downloads.sh ├── evaluation ├── EdgeEval │ ├── CSA++ │ │ ├── GNUmakefile │ │ ├── csa.cc │ │ ├── csa.hh │ │ ├── csaAssign.m │ │ ├── csaAssign_c │ │ ├── csa_defs.h │ │ ├── csa_types.h │ │ ├── sparsify.m │ │ ├── test.cc │ │ ├── test.m │ │ └── test.txt │ ├── EdgeMapEval.cc │ ├── Util │ │ ├── Array.hh │ │ ├── Exception.cc │ │ ├── Exception.hh │ │ ├── GNUmakefile │ │ ├── GNUmakefile-library │ │ ├── Lab2RGB.m │ │ ├── Matrix.cc │ │ ├── Matrix.hh │ │ ├── Point.hh │ │ ├── RGB2Lab.m │ │ ├── Random.cc │ │ ├── Random.hh │ │ ├── Sort.hh │ │ ├── String.cc │ │ ├── String.hh │ │ ├── Timer.cc │ │ ├── Timer.hh │ │ ├── distSqr.m │ │ ├── fftconv2.m │ │ ├── gethosttype │ │ ├── isum.c │ │ ├── isum.m │ │ ├── kmeansML.m │ │ ├── kofn.cc │ │ ├── kofn.hh │ │ ├── logist2.m │ │ ├── padReflect.m │ │ ├── progbar.m │ │ ├── test │ │ └── test_cc │ ├── __init__.py │ ├── build │ │ ├── temp.linux-x86_64-3.5 │ │ │ ├── CSA++ │ │ │ │ └── csa.o │ │ │ ├── EdgeMapEval.o │ │ │ ├── Util │ │ │ │ ├── Exception.o │ │ │ │ ├── Matrix.o │ │ │ │ ├── Random.o │ │ │ │ ├── String.o │ │ │ │ ├── Timer.o │ │ │ │ └── kofn.o │ │ │ ├── correspondPixels.o │ │ │ └── match.o │ │ └── temp.linux-x86_64-3.8 │ │ │ ├── CSA++ │ │ │ └── csa.o │ │ │ ├── EdgeMapEval.o │ │ │ ├── Util │ │ │ ├── Exception.o │ │ │ ├── Matrix.o │ │ │ ├── Random.o │ │ │ ├── String.o │ │ │ ├── Timer.o │ │ │ └── kofn.o │ │ │ ├── correspondPixels.o │ │ │ └── match.o │ ├── correspondPixels.cpp │ ├── correspondPixels.h │ ├── correspondPixels.pyx │ ├── include │ │ ├── Array.hh │ │ ├── Exception.hh │ │ ├── Point.hh │ │ ├── Random.hh │ │ ├── Sort.hh │ │ ├── String.hh │ │ ├── Timer.hh │ │ ├── csa.hh │ │ ├── csa_defs.h │ │ ├── csa_types.h │ │ ├── kofn.hh │ │ └── match.hh │ ├── match.cc │ └── setup.py ├── Makefile ├── RasterizeLine │ ├── __init__.py │ ├── build │ │ ├── temp.linux-x86_64-3.5 │ │ │ ├── draw.o │ │ │ └── kernel.o │ │ └── temp.linux-x86_64-3.8 │ │ │ ├── draw.o │ │ │ └── kernel.o │ ├── draw.cpp │ ├── draw.hpp │ ├── draw.pyx │ ├── kernel.cpp │ └── setup.py ├── __init__.py ├── compute_prec_recall.py ├── draw-hap.py ├── draw-json.py ├── draw-sap.py ├── eval-json.py ├── eval-junctions.py ├── eval-sap.py ├── evaluation.py ├── example_evaluation.py ├── example_rasterline.py ├── prmeter.py ├── runs │ ├── draw-hap.sh │ ├── draw-sap.sh │ ├── draw-vis-im-york.sh │ ├── draw-vis-im.sh │ ├── draw-vis-york.sh │ ├── draw-vis.sh │ ├── eval-aph.sh │ └── sap.sh └── sAPEval │ ├── __init__.py │ └── metric.py ├── hawp ├── __init__.py ├── base │ ├── __init__.py │ ├── csrc │ │ ├── __init__.py │ │ ├── binding.cpp │ │ ├── linesegment.cu │ │ └── linesegment.h │ ├── show │ │ ├── __init__.py │ │ ├── canvas.py │ │ ├── cli.py │ │ └── painters.py │ ├── utils │ │ ├── __init__.py │ │ ├── c2_model_loading.py │ │ ├── checkpoint.py │ │ ├── comm.py │ │ ├── imports.py │ │ ├── logger.py │ │ ├── metric_evaluation.py │ │ ├── metric_logger.py │ │ ├── miscellaneous.py │ │ ├── model_serialization.py │ │ ├── 
model_zoo.py │ │ └── registry.py │ └── wireframe.py ├── encoder │ ├── __init__.py │ └── hafm.py ├── fsl │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── build.py │ │ ├── multi_task_head.py │ │ ├── point_line.py │ │ ├── registry.py │ │ ├── resnets.py │ │ ├── stacked_hg.py │ │ └── stacked_point_line.py │ ├── benchmark.py │ ├── config │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── defaults.py │ │ ├── detr.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── head.py │ │ │ ├── models.py │ │ │ ├── proposal_head.py │ │ │ ├── resnets.py │ │ │ └── shg.py │ │ ├── paths_catalog.py │ │ └── solver.py │ ├── dataset │ │ ├── __init__.py │ │ ├── build.py │ │ ├── imagelist.py │ │ ├── iteration_based_batch_sampler.py │ │ ├── stream.py │ │ ├── test_dataset.py │ │ ├── train_dataset.bck.py │ │ ├── train_dataset.py │ │ └── transforms.py │ ├── model │ │ ├── __init__.py │ │ ├── build.py │ │ ├── hafm.py │ │ ├── losses.py │ │ ├── misc.py │ │ └── models.py │ ├── point_model │ │ └── point_model.pth │ ├── predict.py │ ├── solver.py │ └── train.py └── ssl │ ├── __init__.py │ ├── config │ ├── __init__.py │ ├── exports │ │ ├── wireframe-100iters.yaml │ │ └── wireframe-10iters.yaml │ ├── hawpv3-hrheat.yaml │ ├── hawpv3.yaml │ ├── project_config.py │ ├── synthetic_dataset-4k.yaml │ ├── synthetic_dataset.yaml │ ├── utils.py │ └── wireframe_official_gt.yaml │ ├── datasets │ ├── __init__.py │ ├── dataset_util.py │ ├── images_dataset.py │ ├── synthetic_dataset.py │ ├── synthetic_util.py │ ├── transforms │ │ ├── __init__.py │ │ ├── homographic_transforms.py │ │ ├── photometric_transforms.py │ │ └── utils.py │ ├── wireframe.py │ ├── wireframe_dataset.py │ └── yorkurban_dataset.py │ ├── evaluate.py │ ├── homoadp-bm.py │ ├── homoadp.py │ ├── misc │ ├── __init__.py │ ├── geometry_utils.old.py │ ├── geometry_utils.py │ ├── train_utils.py │ └── visualize_util.py │ ├── models │ ├── __init__.py │ ├── base.py │ ├── detector.py │ ├── detector_with_heatmap.py │ ├── detector_with_hrheat.py │ ├── hafm.py │ ├── heatmap_decoder.py │ ├── losses.py │ └── registry.py │ ├── predict.py │ └── train.py ├── readme.md ├── requirement.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/* 2 | **/*.DS_Store 3 | parsing/data/wireframe/images/* 4 | .vscode/* 5 | data/* 6 | data 7 | 8 | **/*.pyc 9 | **/*.so 10 | data-ssl/**/* 11 | exp-ssl/**/* -------------------------------------------------------------------------------- /LEGAL.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/LEGAL.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Nan Xue 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/hawpv2.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | NUM_WORKERS: 0 3 | DATASETS: 4 | DISTANCE_TH: 0.02 5 | IMAGE: 6 | HEIGHT: 512 7 | PIXEL_MEAN: 8 | - 109.73 9 | - 103.832 10 | - 98.681 11 | PIXEL_STD: 12 | - 22.275 13 | - 22.124 14 | - 23.229 15 | TO_255: true 16 | WIDTH: 512 17 | NUM_STATIC_NEGATIVE_LINES: 40 18 | NUM_STATIC_POSITIVE_LINES: 300 19 | AUGMENTATION: 5 20 | TARGET: 21 | HEIGHT: 128 22 | WIDTH: 128 23 | TEST: 24 | - wireframe_test 25 | TRAIN: 26 | - wireframe_train 27 | VAL: 28 | - wireframe_test 29 | ENCODER: 30 | ANG_TH: 0.0 31 | BACKGROUND_WEIGHT: 0.0 32 | DIS_TH: 2 33 | NUM_STATIC_NEG_LINES: 0 34 | NUM_STATIC_POS_LINES: 300 35 | MODEL: 36 | DEVICE: cuda 37 | FOCAL_LOSS: 38 | ALPHA: -1.0 39 | GAMMA: 0.0 40 | HEAD_SIZE: 41 | - - 3 42 | - - 1 43 | - - 1 44 | - - 2 45 | - - 2 46 | HGNETS: 47 | DEPTH: 4 48 | INPLANES: 64 49 | NUM_BLOCKS: 1 50 | NUM_FEATS: 128 51 | NUM_STACKS: 2 52 | LOI_POOLING: 53 | ACTIVATION: relu 54 | DIM_EDGE_FEATURE: 4 55 | DIM_FC: 1024 56 | DIM_JUNCTION_FEATURE: 128 57 | LAYER_NORM: false 58 | NUM_POINTS: 32 59 | TYPE: softmax 60 | LOSS_WEIGHTS: 61 | loss_aux: 1.0 62 | loss_dis: 1.0 63 | loss_jloc: 8.0 64 | loss_joff: 0.25 65 | loss_lineness: 1.0 66 | loss_md: 1.0 67 | loss_neg: 1.0 68 | loss_pos: 1.0 69 | loss_res: 1.0 70 | NAME: Hourglass 71 | OUT_FEATURE_CHANNELS: 256 72 | PARSING_HEAD: 73 | DIM_FC: 1024 74 | DIM_LOI: 128 75 | J2L_THRESHOLD: 10.0 76 | JMATCH_THRESHOLD: 1.5 77 | JUNCTION_HM_THRESHOLD: 0.008 #magic number 78 | MATCHING_STRATEGY: junction 79 | MAX_DISTANCE: 5.0 80 | N_DYN_JUNC: 300 81 | N_DYN_NEGL: 40 82 | N_DYN_OTHR: 0 83 | N_DYN_OTHR2: 300 84 | N_DYN_POSL: 300 85 | N_OUT_JUNC: 250 86 | N_OUT_LINE: 2500 87 | N_PTS0: 32 88 | N_PTS1: 8 89 | N_STC_NEGL: 40 90 | N_STC_POSL: 300 91 | USE_RESIDUAL: 1 92 | RESNETS: 93 | BASENET: resnet50 94 | PRETRAIN: true 95 | SCALE: 1.0 96 | WEIGHTS: '' 97 | MODELING_PATH: ihawp-v2 98 | OUTPUT_DIR: output/ihawp 99 | SOLVER: 100 | AMSGRAD: true 101 | BACKBONE_LR_FACTOR: 1.0 102 | BASE_LR: 0.0004 103 | BIAS_LR_FACTOR: 1 104 | CHECKPOINT_PERIOD: 1 105 | GAMMA: 0.1 106 | IMS_PER_BATCH: 6 107 | MAX_EPOCH: 30 108 | MOMENTUM: 0.9 109 | OPTIMIZER: ADAM 110 | STEPS: 111 | - 25 112 | WEIGHT_DECAY: 0.0001 113 | WEIGHT_DECAY_BIAS: 0 114 | -------------------------------------------------------------------------------- /configs/plnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | NUM_WORKERS: 0 3 | DATASETS: 4 | DISTANCE_TH: 0.02 5 | IMAGE: 6 | HEIGHT: 512 7 | PIXEL_MEAN: 8 | - 109.73 9 | - 103.832 10 | - 98.681 11 | PIXEL_STD: 12 | - 22.275 13 | - 22.124 14 | - 23.229 15 | TO_255: true 16 | WIDTH: 512 17 | NUM_STATIC_NEGATIVE_LINES: 40 18 | NUM_STATIC_POSITIVE_LINES: 300 19 | AUGMENTATION: 5 20 | TARGET: 21 | HEIGHT: 128 22 | WIDTH: 128 23 | TEST: 24 | - wireframe_test 25 | TRAIN: 26 | - wireframe_train 27 | VAL: 
28 | - wireframe_test 29 | ENCODER: 30 | ANG_TH: 0.0 31 | BACKGROUND_WEIGHT: 0.0 32 | DIS_TH: 2 33 | NUM_STATIC_NEG_LINES: 0 34 | NUM_STATIC_POS_LINES: 300 35 | MODEL: 36 | DEVICE: cuda 37 | FOCAL_LOSS: 38 | ALPHA: -1.0 39 | GAMMA: 0.0 40 | HEAD_SIZE: 41 | - - 3 42 | - - 1 43 | - - 1 44 | - - 2 45 | - - 2 46 | HGNETS: 47 | DEPTH: 4 48 | INPLANES: 64 49 | NUM_BLOCKS: 2 50 | NUM_FEATS: 128 51 | NUM_STACKS: 2 52 | LOI_POOLING: 53 | ACTIVATION: relu 54 | DIM_EDGE_FEATURE: 4 55 | # DIM_FC: 1024 56 | DIM_FC: 128 57 | DIM_JUNCTION_FEATURE: 128 58 | LAYER_NORM: false 59 | NUM_POINTS: 32 60 | TYPE: softmax 61 | LOSS_WEIGHTS: 62 | loss_aux: 1.0 # 2 63 | loss_dis: 1.0 # 2 64 | loss_jloc: 8.0 # 2 65 | loss_joff: 0.25 # 2 66 | loss_res: 1.0 # 2 67 | loss_md: 1.0 # 2 68 | loss_lineness: 1.0 # 1 69 | loss_neg: 1.0 # 1 70 | loss_pos: 1.0 # 1 71 | NAME: PointLine 72 | OUT_FEATURE_CHANNELS: 256 73 | PARSING_HEAD: 74 | DIM_FC: 1024 75 | DIM_LOI: 128 76 | J2L_THRESHOLD: 10.0 77 | JMATCH_THRESHOLD: 1.5 78 | JUNCTION_HM_THRESHOLD: 0.008 #magic number 79 | MATCHING_STRATEGY: junction 80 | MAX_DISTANCE: 5.0 81 | N_DYN_JUNC: 300 82 | N_DYN_NEGL: 40 83 | N_DYN_OTHR: 0 84 | N_DYN_OTHR2: 300 85 | N_DYN_POSL: 300 86 | N_OUT_JUNC: 250 87 | N_OUT_LINE: 2500 88 | N_PTS0: 32 89 | N_PTS1: 8 90 | N_STC_NEGL: 40 91 | N_STC_POSL: 300 92 | USE_RESIDUAL: 1 93 | RESNETS: 94 | BASENET: resnet50 95 | PRETRAIN: true 96 | SCALE: 1.0 97 | WEIGHTS: '' 98 | MODELING_PATH: ihawp-v2 99 | OUTPUT_DIR: output/ihawp 100 | SOLVER: 101 | AMSGRAD: true 102 | BACKBONE_LR_FACTOR: 1.0 103 | BASE_LR: 0.0004 104 | BIAS_LR_FACTOR: 1 105 | CHECKPOINT_PERIOD: 1 106 | GAMMA: 0.2 107 | IMS_PER_BATCH: 6 108 | MAX_EPOCH: 40 109 | MOMENTUM: 0.9 110 | OPTIMIZER: ADAM 111 | STEPS: 112 | - 35 113 | WEIGHT_DECAY: 0.0001 114 | WEIGHT_DECAY_BIAS: 0 115 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.6.0-devel-ubuntu20.04 2 | 3 | # Prevent the Ubuntu build from stopping at the time zone selection prompt.
4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | RUN apt-get update && apt-get install -y \ 7 | git \ 8 | build-essential \ 9 | vim \ 10 | openssh-server \ 11 | python3-opencv \ 12 | ca-certificates \ 13 | python3-dev \ 14 | python3-pip \ 15 | wget \ 16 | ninja-build \ 17 | mesa-common-dev \ 18 | libgl1-mesa-dev \ 19 | libglu1-mesa-dev \ 20 | xauth 21 | 22 | RUN sed -i "s/^.*X11Forwarding.*$/X11Forwarding yes/" /etc/ssh/sshd_config \ 23 | && sed -i "s/^.*X11UseLocalhost.*$/X11UseLocalhost no/" /etc/ssh/sshd_config \ 24 | && grep "^X11UseLocalhost" /etc/ssh/sshd_config || echo "X11UseLocalhost no" >> /etc/ssh/sshd_config 25 | 26 | RUN ln -sv /usr/bin/python3 /usr/bin/python 27 | 28 | RUN pip3 install numpy==1.23.4 29 | RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116 30 | 31 | RUN echo 'root:root' | chpasswd 32 | 33 | RUN echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config 34 | 35 | RUN service ssh start 36 | 37 | CMD ["/usr/sbin/sshd","-D"] 38 | -------------------------------------------------------------------------------- /docs/HAWPv2.md: -------------------------------------------------------------------------------- 1 | # HAWPv2: Learning Wireframes via Fully-Supervised Learning 2 | 3 | *The HAWPv2 code is located in the [hawp/fsl](../hawp/fsl) directory.* 4 | ## Quickstart & Evaluation 5 | - Please download the dataset and checkpoints as described in [readme.md](../readme.md). 6 | - Run the following command line(s) to evaluate the official model on the Wireframe and YorkUrban datasets. 7 | 8 |
9 | Evaluation on the Wireframe dataset. 10 | 11 | ```bash 12 | python -m hawp.fsl.benchmark configs/hawpv2.yaml \ 13 | --ckpt checkpoints/hawpv2-edb9b23f.pth \ 14 | --dataset wireframe 15 | ``` 16 |
17 | 18 |
19 | Evaluation on the YorkUrban dataset. 20 | 21 | ```bash 22 | python -m hawp.fsl.benchmark configs/hawpv2.yaml \ 23 | --ckpt checkpoints/hawpv2-edb9b23f.pth \ 24 | --dataset york 25 | ``` 26 |
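The two commands above differ only in the dataset flag; the results table below additionally sweeps the junction-heatmap threshold ``--jhm``. A minimal driver sketch for reproducing all six rows, assuming the checkpoint was fetched by ``downloads.sh`` and using only the ``--jhm`` flag exactly as it appears in the table:

```python
# Sweep the junction-heatmap threshold (--jhm) over both test datasets.
# Assumes checkpoints/hawpv2-edb9b23f.pth was downloaded via downloads.sh.
import subprocess

CKPT = "checkpoints/hawpv2-edb9b23f.pth"

for dataset in ("wireframe", "york"):
    for jhm in ("0.001", "0.005", "0.008"):
        subprocess.run(
            ["python", "-m", "hawp.fsl.benchmark", "configs/hawpv2.yaml",
             "--ckpt", CKPT, "--dataset", dataset, f"--jhm={jhm}"],
            check=True,  # stop on the first failing run
        )
```

Each invocation is independent, so the loop can also be parallelized across GPUs if several are available.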
27 | 28 | ## Evaluation Results 29 | 30 | |Dataset|sAP-5|sAP-10|sAP-15|command line|comment| 31 | |--|--|--|--|--|--| 32 | |Wireframe| 65.8 | 69.8 |71.4|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset wireframe --jhm=0.001``|jhm = 0.001| 33 | |Wireframe| 65.7 | 69.8 |71.4|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset wireframe --jhm=0.005``|jhm = 0.005| 34 | |Wireframe| 65.7 | 69.7 |71.3|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset wireframe --jhm=0.008``|jhm = 0.008 (default setting)| 35 | |YorkUrban|29.0|31.4|32.8|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset york --jhm=0.001``|jhm = 0.001| 36 | |YorkUrban|28.9|31.4|32.7|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset york --jhm=0.005``|jhm = 0.005| 37 | |YorkUrban|28.8|31.3|32.6|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset york --jhm=0.008``|jhm = 0.008 (default setting)| 38 | 39 | ## Training 40 | - Run the following command line to train HAWPv2 on the Wireframe dataset. 41 | ``` 42 | python -m hawp.fsl.train configs/hawpv2.yaml --logdir outputs 43 | ``` 44 | 45 | - The usage of [hawp.fsl.train](../hawp/fsl/train.py) is as follows: 46 | ```text 47 | HAWPv2 Training 48 | 49 | positional arguments: 50 | config path to config file 51 | 52 | optional arguments: 53 | -h, --help show this help message and exit 54 | --logdir LOGDIR 55 | --resume RESUME 56 | --clean 57 | --seed SEED 58 | --tf32 toggle on the TF32 of pytorch 59 | --dtm {True,False} toggle the deterministic option of CUDNN. This option will affect the replication of experiments 60 | 61 | ``` -------------------------------------------------------------------------------- /docs/HAWPv3.md: -------------------------------------------------------------------------------- 1 | # HAWPv3: Learning Wireframes via Self-Supervised Learning 2 | 3 | *The HAWPv3 code is located in the [hawp/ssl](../hawp/ssl) directory.* 4 | 5 | |Model Name|Comments|MD5| 6 | |---|---|---| 7 | |[hawpv3-fdc5487a.pth](https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-fdc5487a.pth)| Trained on the images of the Wireframe dataset | fdc5487a43e3d42f6b2addf79d8b930d| 8 | |[hawpv3-imagenet-03a84.pth](https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-imagenet-03a84.pth)| Trained on 100k images of the ImageNet dataset| 03a8400e9474320f2b42973d1ba19487| 9 | 10 | ### Inference on your own images 11 | 12 | - Run one of the following command lines to obtain wireframes from a HAWPv3 model. 13 |
14 | hawpv3-fdc5487a.pth 15 | 16 | python -m hawp.ssl.predict --ckpt checkpoints/hawpv3-fdc5487a.pth \ 17 | --threshold 0.05 \ 18 | --img {filename.png} 19 |
20 | 21 |
22 | hawpv3-imagenet-03a84.pth 23 | 24 | python -m hawp.ssl.predict --ckpt checkpoints/hawpv3-imagenet-03a84.pth \ 25 | --threshold 0.05 \ 26 | --img {filename.png} 27 |
28 | 29 | - A running example on the DTU-24 images: 30 | ```bash 31 | python -m hawp.ssl.predict --ckpt checkpoints/hawpv3-imagenet-03a84.pth \ 32 | --threshold 0.05 \ 33 | --img ~/datasets/DTU/scan24/image/*.png \ 34 | --saveto docs/figures/dtu-24 --ext png 35 | ``` 36 |

37 | *(image gallery: HAWPv3 wireframe predictions on the DTU-24 images produced by the command above; the rendered results are stored under docs/figures/dtu-24/)*

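The MD5 sums in the model table above can be used to verify the downloaded checkpoints before running inference. A minimal sketch; the paths assume the files were fetched into ``checkpoints/`` by ``downloads.sh``, and the hashes are copied from the table:

```python
# Verify HAWPv3 checkpoint integrity against the MD5 sums listed above.
import hashlib

EXPECTED = {
    "checkpoints/hawpv3-fdc5487a.pth": "fdc5487a43e3d42f6b2addf79d8b930d",
    "checkpoints/hawpv3-imagenet-03a84.pth": "03a8400e9474320f2b42973d1ba19487",
}

for path, want in EXPECTED.items():
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        # Hash in 1 MiB chunks so large checkpoints are not read into memory at once.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            md5.update(chunk)
    got = md5.hexdigest()
    print(path, "OK" if got == want else "MISMATCH: " + got)
```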
48 | 49 | ## Training 50 | 51 | ```bash 52 | 53 | python -m hawp.ssl.train --help 54 | usage: train.py [-h] --datacfg DATACFG --modelcfg MODELCFG --name NAME 55 | [--pretrained PRETRAINED] [--overwrite] [--tf32] 56 | [--dtm {True,False}] [--batch-size BATCH_SIZE] 57 | [--num-workers NUM_WORKERS] [--base-lr BASE_LR] 58 | [--steps STEPS [STEPS ...]] [--gamma GAMMA] [--epochs EPOCHS] 59 | [--seed SEED] [--iterations ITERATIONS] 60 | 61 | optional arguments: 62 | -h, --help show this help message and exit 63 | --datacfg DATACFG filepath of the data config 64 | --modelcfg MODELCFG filepath of the model config 65 | --name NAME the name of experiment 66 | --pretrained PRETRAINED 67 | the pretrained model 68 | --overwrite [Caution!] the option to overwrite an existed 69 | experiment 70 | --tf32 toggle on the TF32 of pytorch 71 | --dtm {True,False} toggle the deterministic option of CUDNN. This option 72 | will affect the replication of experiments 73 | 74 | training recipe: 75 | --batch-size BATCH_SIZE 76 | the batch size of training 77 | --num-workers NUM_WORKERS 78 | the number of workers for training 79 | --base-lr BASE_LR the initial learning rate 80 | --steps STEPS [STEPS ...] 81 | the steps of the scheduler 82 | --gamma GAMMA the lr decay factor 83 | --epochs EPOCHS the number of epochs for training 84 | --seed SEED the random seed for training 85 | --iterations ITERATIONS 86 | the number of training iterations 87 | 88 | ``` -------------------------------------------------------------------------------- /docs/HAWPv3.train.md: -------------------------------------------------------------------------------- 1 | # Training Recipes of HAWPv3 2 | 3 | *HAWPv3 is trained in multiple phases: a synthetic training phase followed by training on real data.* 4 | 5 | ## Step 0: Synthetic learning 6 | 7 | ``` 8 | python -m hawp.ssl.train \ 9 | --datacfg hawp/ssl/config/synthetic_dataset.yaml \ 10 | --modelcfg hawp/ssl/config/hawpv3.yaml \ 11 | --base-lr 0.0004 \ 12 | --epochs 10 \ 13 | --batch-size 6 \ 14 | --name hawpv3-round0 15 | ``` 16 | 17 | ## Step 1: Homographic Adaptation for Pseudo Wireframe Generation 18 | 19 | If you prefer the single-image mode, which minimizes the GPU memory footprint, use the following command to obtain pseudo labels: 20 | ``` 21 | python -m hawp.ssl.homoadp --metarch HAWP-heatmap \ 22 | --datacfg hawp/ssl/config/exports/wireframe-10iters.yaml \ 23 | --workdir exp-ssl/hawpv3-round0 \ 24 | --epoch 10 \ 25 | --modelcfg exp-ssl/hawpv3-round0/model.yaml \ 26 | --min_score 0.75 27 | 28 | ``` 29 | 30 | For the batch mode, use the following command: 31 | ``` 32 | python -m hawp.ssl.homoadp-bm --metarch HAWP-heatmap \ 33 | --datacfg hawp/ssl/config/exports/wireframe-10iters.yaml \ 34 | --workdir exp-ssl/hawpv3-round0 \ 35 | --epoch 10 \ 36 | --modelcfg exp-ssl/hawpv3-round0/model.yaml \ 37 | --min-score 0.75 --batch-size=16 38 | ``` 39 | *On my machine (an NVIDIA A6000), a batch size of 16 uses about 40 GB of GPU memory and takes about 43 minutes to generate wireframe labels for the 20k training images of the Wireframe dataset (Huang et al., CVPR 2018).* 40 | 41 | ### Remarks 42 | 1. After the homographic adaptation step has finished, three files are generated and stored in the directory ```data-ssl/{name}```, where ```{name}``` is the name of the last round of training. For example, if we train HAWPv3 under the name ```hawpv3-round0```, the exported data will be saved at ```data-ssl/hawpv3-round0```. 43 | 44 | 2.
For each generated wireframe and its auxiliary files, the filenames start with ```{hash}-model-{epoch:05d}```. 45 | 46 | 3. In summary, the generated datacfg ``YAML`` file is located at ``data-ssl/{name}/{hash}-model-{epoch:05d}.yaml``. 47 | ## Step 2: Learning from Real-World Images 48 | 49 | - Once we have the pseudo wireframe labels, we can train HAWPv3 on real-world images. An example usage is shown in the command line below: 50 | ``` 51 | python -m hawp.ssl.train --datacfg data-ssl/export_datasets/{name}/{hash}-model-00010.yaml --modelcfg hawp/ssl/config/hawpv3.yaml --base-lr 0.0004 --epochs 30 --name hawpv3-round1 --batch-size 6 52 | ``` 53 | 54 | - Then, we can run the homographic adaptation again with the new model checkpoints trained on real-world images, and train/fine-tune a new model to further improve the repeatability. 55 | -------------------------------------------------------------------------------- /docs/figures/dtu-24/000000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000000.png -------------------------------------------------------------------------------- /docs/figures/dtu-24/000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000001.png -------------------------------------------------------------------------------- /docs/figures/dtu-24/000002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000002.png -------------------------------------------------------------------------------- /docs/figures/dtu-24/000003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000003.png -------------------------------------------------------------------------------- /docs/figures/dtu-24/000004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000004.png -------------------------------------------------------------------------------- /docs/figures/dtu-24/000005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000005.png -------------------------------------------------------------------------------- /docs/figures/dtu-24/000009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000009.png -------------------------------------------------------------------------------- /docs/figures/dtu-24/000015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000015.png -------------------------------------------------------------------------------- /docs/figures/dtu-24/000045.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000045.png -------------------------------------------------------------------------------- /docs/figures/plnet/plnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/plnet/plnet.png -------------------------------------------------------------------------------- /docs/figures/plnet/uma_feature.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/plnet/uma_feature.png -------------------------------------------------------------------------------- /docs/figures/v3-BSDS/37073.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-BSDS/37073.png -------------------------------------------------------------------------------- /docs/figures/v3-BSDS/42049.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-BSDS/42049.png -------------------------------------------------------------------------------- /docs/figures/v3-BSDS/85048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-BSDS/85048.png -------------------------------------------------------------------------------- /docs/figures/v3-CrowdAI/000000000190.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-CrowdAI/000000000190.png -------------------------------------------------------------------------------- /docs/figures/v3-CrowdAI/000000000210.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-CrowdAI/000000000210.png -------------------------------------------------------------------------------- /docs/figures/v3-CrowdAI/000000000230.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-CrowdAI/000000000230.png -------------------------------------------------------------------------------- /docs/figures/v3-wireframe/00037187.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-wireframe/00037187.png -------------------------------------------------------------------------------- /docs/figures/v3-wireframe/00051510.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-wireframe/00051510.png -------------------------------------------------------------------------------- 
/docs/figures/v3-wireframe/00074259.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-wireframe/00074259.png -------------------------------------------------------------------------------- /downloads.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | wget https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv2/hawpv2-edb9b23f.pth -P checkpoints 4 | wget https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-ce3ae2cb.pth -P checkpoints 5 | wget https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-fdc5487a.pth -P checkpoints 6 | wget https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-imagenet-03a84.pth -P checkpoints 7 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/CSA++/GNUmakefile: -------------------------------------------------------------------------------- 1 | # use gmake! 2 | 3 | srcs := csa.cc 4 | hdrs := csa.hh csa_types.h csa_defs.h 5 | matlab := csaAssign.m sparsify.m 6 | mex := csaAssign.cc 7 | lib := libcsa.a 8 | 9 | cxxFlags := -O3 10 | mexFlags := 11 | 12 | include ../Util/GNUmakefile-library 13 | 14 | runtest: 15 | $(cxx) $(cxxFlags) -o test $(srcs) test.cc 16 | ./test 17 | 18 | clean:: 19 | -rm -f test 20 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/CSA++/csa.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "csa.hh" 3 | 4 | char* CSA::err_messages[] = 5 | { 6 | "Can't read from the input file.", 7 | "Not a correct assignment problem line.", 8 | "Error reading a node descriptor from the input.", 9 | "Error reading an arc descriptor from the input.", 10 | "Unknown line type in the input", 11 | "Inconsistent number of arcs in the input.", 12 | "Parsing noncontiguous node ID numbers not implemented.", 13 | "Can't obtain enough memory to solve this problem.", 14 | }; 15 | 16 | char* CSA::nomem_msg = "Insufficient memory.\n"; 17 | 18 | CSA::CSA (int n, int m, const int* graph) 19 | { 20 | assert(n>0); 21 | assert(m>0); 22 | assert(graph!=NULL); 23 | assert((n%2)==0); 24 | _init(n,m); 25 | main(graph); 26 | } 27 | 28 | CSA::~CSA () 29 | { 30 | _delete(); 31 | } 32 | 33 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/CSA++/csaAssign.m: -------------------------------------------------------------------------------- 1 | % [edges] = csaAssign(n,graph) 2 | % 3 | % Compute min-cost assignment with non-negative integral edge weights 4 | % using Andrew Goldberg's CSA package (precise costs version). 5 | % 6 | % INPUT 7 | % n Number of nodes in the bipartite graph (must be even). 8 | % graph 3xm matrix describing graph. 9 | % 10 | % OUTPUT 11 | % edges 3xn matrix of edges in assignment. 12 | % 13 | % You must ensure that an assignment involving all nodes exists, else 14 | % the code may hang. This is a feature of the CSA package. If your 15 | % problem does not necessarily provide such an assignment, then you 16 | % should overlay a high-cost perfect match as a safety net. 17 | % 18 | % Both graph and edges matrices have the same structure. Each column 19 | % gives a graph edge e. 
The two nodes are given by e(1) and e(2): 20 | % 21 | % e(1) < e(2) 22 | % 1 <= e(1) <= n/2 23 | % n/2 < e(2) <= n 24 | % 25 | % The edge weight is given by e(3). 26 | % 27 | % Since the output edge matrix should contain one reference to each 28 | % node, sum(sum(edges(1:2,:))) == n*(n+1)/2. 29 | % 30 | % David Martin 31 | % January, 2003 32 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/CSA++/csaAssign_c: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include "csa.hh" 5 | 6 | extern "C" { 7 | 8 | void 9 | mexFunction ( 10 | int nlhs, mxArray* plhs[], 11 | int nrhs, const mxArray* prhs[]) 12 | { 13 | // Check argument counts. 14 | if (nlhs < 1) { 15 | mexErrMsgTxt("Too few output arguments."); 16 | } 17 | if (nlhs > 1) { 18 | mexErrMsgTxt("Too many output arguments."); 19 | } 20 | if (nrhs < 2) { 21 | mexErrMsgTxt("Too few input arguments."); 22 | } 23 | if (nrhs > 2) { 24 | mexErrMsgTxt("Too many input arguments."); 25 | } 26 | 27 | // Get input arguments. 28 | const int n = (int) mxGetScalar (prhs[0]); 29 | const double* g = mxGetPr (prhs[1]); 30 | const int three = mxGetM (prhs[1]); 31 | const int m = mxGetN (prhs[1]); 32 | 33 | // Check input arguments. 34 | if (n < 1) { 35 | mexErrMsgTxt("n must be >0"); 36 | } 37 | if ((n%2) != 0) { 38 | mexErrMsgTxt("n must be even"); 39 | } 40 | if (m < 1) { 41 | mexErrMsgTxt("m must be >0"); 42 | } 43 | if (three != 3) { 44 | mexErrMsgTxt("graph matrix must be 3xM"); 45 | } 46 | 47 | // Build the input graph and check the data. 48 | int* graph = new int[m*3]; 49 | int maxc = 0; 50 | for (int i = 0; i < m; i++) { 51 | int a = (int) g[3*i+0]; 52 | int b = (int) g[3*i+1]; 53 | int c = (int) g[3*i+2]; 54 | graph[3*i+0] = a; 55 | graph[3*i+1] = b; 56 | graph[3*i+2] = c; 57 | if (a < 1 || a > n/2) { 58 | mexErrMsgTxt("edge tail not in [1,n/2]"); 59 | } 60 | if (b <= n/2 || b > n) { 61 | mexErrMsgTxt("edge head not in (n/2,n]"); 62 | } 63 | if (c < 0) { 64 | mexErrMsgTxt("edge weights must be non-negative"); 65 | } 66 | maxc = std::max(c,maxc); 67 | } 68 | 69 | // The CSA package segfaults if all the edge weights are zero. 70 | // In this case, set all the weights to one, and then later 71 | // remember to set the returned graph weights back to zero. 72 | if (maxc == 0) { 73 | for (int i = 0; i < m; i++) { 74 | graph[3*i+2] = 1; 75 | } 76 | } 77 | 78 | // Run CSA. It will either run successfully or segfault or loop 79 | // forever or return garbage. But it claims to always return a 80 | // valid result if the input is valid. The checks above try to 81 | // ensure that the input is ok, but I don't check that a perfect 82 | // match is present (which CSA requires but does not check for, 83 | // grumble, grumble). In that case, you're on your own, since 84 | // I'm not sure how to quickly check for that condition. 85 | CSA csa (n, m, graph); 86 | int e = csa.edges(); 87 | 88 | // Done with input graph. 89 | delete [] graph; 90 | graph = NULL; 91 | 92 | // Construct result. 93 | plhs[0] = mxCreateDoubleMatrix(3, e, mxREAL); 94 | double* points = mxGetPr(plhs[0]); 95 | for (int i = 0; i < e; i++) { 96 | int a, b, cost; 97 | csa.edge(i,a,b,cost); 98 | points[i*3+0] = a; 99 | points[i*3+1] = b; 100 | points[i*3+2] = (maxc==0) ? 
0 : cost; 101 | } 102 | } 103 | 104 | }; // extern "C" 105 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/CSA++/sparsify.m: -------------------------------------------------------------------------------- 1 | function graph = sparsify(sm) 2 | % function graph = sparsify(sm) 3 | % 4 | % Given a similarity matrix sm, return a sparse graph representation 5 | % suitable for csaAssign. For example, if sm is an nxn similarity 6 | % matrix, then Hungarian-style matching can be accomplished by 7 | % executing csaAssign(2*n,sparsify(sm)). 8 | % 9 | % See also csaAssign. 10 | % 11 | % David Martin 12 | % March, 2003 13 | 14 | if ndims(sm)~=2 | size(sm,1)~=size(sm,2), 15 | error('sm must be a square matrix'); 16 | end 17 | 18 | n = size(sm,1); 19 | graph = zeros(3,n*n); 20 | graph(1,:) = repmat(1:n,1,n); 21 | graph(2,:) = n + reshape(repmat(1:n,n,1),1,n*n); 22 | graph(3,:) = reshape(sm,1,n*n); 23 | 24 | 25 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/CSA++/test.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "csa.hh" 3 | 4 | int n = 6; 5 | int m = 6; 6 | static int data[] = { 7 | 1, 4, 3, 8 | 2, 5, 3, 9 | 3, 6, 3, 10 | 1, 5, 1, 11 | 2, 6, 1, 12 | 3, 4, 5, 13 | }; 14 | 15 | int 16 | main(int argc, char** argv) 17 | { 18 | for (int iter = 0; iter < 10; iter++) { 19 | CSA csa (n, m, data); 20 | for (int i = 0; i < csa.edges(); i++) { 21 | int a, b, cost; 22 | csa.edge(i,a,b,cost); 23 | fprintf (stderr, "%d %d %d\n", a, b, cost); 24 | } 25 | fprintf (stderr, "TOTAL %d\n", csa.cost()); 26 | } 27 | return 0; 28 | } 29 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/CSA++/test.m: -------------------------------------------------------------------------------- 1 | % toy test 2 | g1 = [ 1 2 3 1 2 3 ; 3 | 4 5 6 5 6 4 ; 4 | 3 3 3 1 1 5 ]; 5 | n = 6; 6 | e1 = csaAssign(n,g1) 7 | 8 | % big random test 9 | n = 1000; 10 | dg = (rand(n,n) > 0.5); 11 | m = sum(dg(:)); 12 | i = find(dg==1)' - 1; 13 | g2 = [ 1 + floor(i/n) ; 14 | 1 + mod(i,n) + n ; 15 | 1 + floor(rand(1,m)*1000) ]; 16 | tic; 17 | e2 = csaAssign(2*n,g2); 18 | toc; 19 | if sum(e2(1,:)) ~= n*(n+1)/2, error('bug'); end 20 | if sum(e2(2,:)) ~= n*(n+1)/2 + n*n, error('bug'); end 21 | if sum(sum(e2(1:2,:))) ~= 2*n*(2*n+1)/2, error('bug'); end 22 | disp('[n m cost] = '); 23 | [n m sum(e2(3,:))] 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/CSA++/test.txt: -------------------------------------------------------------------------------- 1 | p asn 6 6 2 | n 1 3 | n 2 4 | n 3 5 | a 1 4 3 6 | a 2 5 3 7 | a 3 6 3 8 | a 1 5 1 9 | a 2 6 1 10 | a 3 4 5 11 | 12 | 13 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/EdgeMapEval.cc: -------------------------------------------------------------------------------- 1 | #include "Matrix.hh" 2 | #include "match.hh" 3 | #include 4 | #include 5 | #include "correspondPixels.h" 6 | 7 | 8 | void correspondPixels(double* bmap1, 9 | double* bmap2, 10 | double* match1, 11 | double* match2, 12 | double& cost, 13 | double maxDist, 14 | double outlierCost, 15 | int height, 16 | int width) 17 | { 18 | const int rows = width; 19 | const int cols = height; 20 | 21 | const double idiag = sqrt(float(rows*rows+cols*cols)); 22 | const double oc = outlierCost*maxDist*idiag;
24 | 25 | Matrix m1,m2; 26 | cost = matchEdgeMaps( 27 | Matrix(rows,cols,bmap1), Matrix(rows,cols,bmap2), 28 | maxDist*idiag, oc, m1, m2); 29 | 30 | memcpy(match1, m1.data(), m1.numel()*sizeof(double)); 31 | memcpy(match2, m2.data(), m2.numel()*sizeof(double)); 32 | } -------------------------------------------------------------------------------- /evaluation/EdgeEval/Util/Exception.cc: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (C) 2002 David R. Martin 3 | // 4 | // This program is free software; you can redistribute it and/or 5 | // modify it under the terms of the GNU General Public License as 6 | // published by the Free Software Foundation; either version 2 of the 7 | // License, or (at your option) any later version. 8 | // 9 | // This program is distributed in the hope that it will be useful, but 10 | // WITHOUT ANY WARRANTY; without even the implied warranty of 11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 12 | // General Public License for more details. 13 | // 14 | // You should have received a copy of the GNU General Public License 15 | // along with this program; if not, write to the Free Software 16 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 17 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 18 | 19 | #include 20 | #include 21 | #include "Exception.hh" 22 | 23 | Exception::Exception (const char* msg) 24 | : _msg (strdup (msg)) 25 | { 26 | } 27 | 28 | Exception::Exception (const Exception& that) 29 | : _msg (strdup (that._msg)) 30 | { 31 | } 32 | 33 | Exception::~Exception () 34 | { 35 | free (_msg); 36 | } 37 | 38 | const char* 39 | Exception::msg () const 40 | { 41 | return _msg; 42 | } 43 | 44 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/Util/Exception.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Exception_hh__ 3 | #define __Exception_hh__ 4 | 5 | // A simple exception class that contains an error message. 6 | 7 | // Copyright (C) 2002 David R. Martin 8 | // 9 | // This program is free software; you can redistribute it and/or 10 | // modify it under the terms of the GNU General Public License as 11 | // published by the Free Software Foundation; either version 2 of the 12 | // License, or (at your option) any later version. 13 | // 14 | // This program is distributed in the hope that it will be useful, but 15 | // WITHOUT ANY WARRANTY; without even the implied warranty of 16 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 17 | // General Public License for more details. 18 | // 19 | // You should have received a copy of the GNU General Public License 20 | // along with this program; if not, write to the Free Software 21 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 22 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 23 | 24 | #include 25 | 26 | class Exception 27 | { 28 | public: 29 | 30 | // Always construct exception with a message, so we can print 31 | // a useful error/log message. 32 | Exception (const char* msg); 33 | 34 | // We need to implement the copy constructor so that rethrowing 35 | // works. 36 | Exception (const Exception& that); 37 | 38 | virtual ~Exception (); 39 | 40 | // Retrieve the message that this exception carries. 
41 | virtual const char* msg () const; 42 | 43 | protected: 44 | 45 | char* _msg; 46 | 47 | }; 48 | 49 | // write to output stream 50 | inline std::ostream& operator<< (std::ostream& out, const Exception& e) { 51 | out << e.msg(); 52 | return out; 53 | } 54 | 55 | #endif // __Exception_hh__ 56 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/Util/GNUmakefile: -------------------------------------------------------------------------------- 1 | # use gmake! 2 | 3 | srcs := Exception.cc String.cc Random.cc Timer.cc Matrix.cc kofn.cc 4 | hdrs := Exception.hh String.hh Random.hh Timer.hh Matrix.hh \ 5 | Sort.hh Point.hh Array.hh kofn.hh 6 | matlab := isum.m kmeansML.m distSqr.m fftconv2.m padReflect.m \ 7 | Lab2RGB.m RGB2Lab.m logist2.m progbar.m 8 | mex := isum.c 9 | mexLibs := -lutil 10 | 11 | lib := libutil.a 12 | cxxFlags := -O3 -DNOBLAS 13 | 14 | include ./GNUmakefile-library 15 | 16 | # test of Matrix module 17 | test: 18 | # g++ -g -Wall -o test test.cc build/ix86_linux/libutil.a -lgsl -lgslcblas -lm 19 | # g++ -g -Wall -o test test.cc build/ix86_linux/libutil.a -L/usr/mill/lib -lblas -lpgf90 -lpgf90_rpm1 -lpgf902 -lpgf90rtl -lpgftnrtl -lpgc -lm 20 | # g++ -g -Wall -o test test.cc build/ix86_linux/libutil.a -lblas -lg2c -lm 21 | # g++ -g -Wall -o test test.cc -L./build/ix86_linux -L/home/cs/dmartin/lib/$(hostType) -lutil -lf77blas -latlas -lg2c -lm 22 | g++ -g -Wall -o test test.cc build/ix86_linux/libutil.a -lblas -lm 23 | 24 | # eof 25 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/Util/Lab2RGB.m: -------------------------------------------------------------------------------- 1 | function [R, G, B] = Lab2RGB(L, a, b) 2 | % function [R, G, B] = Lab2RGB(L, a, b) 3 | % Lab2RGB takes matrices corresponding to L, a, and b in CIELab space 4 | % and transforms them into RGB. This transform is based on ITU-R 5 | % Recommendation BT.709 using the D65 white point reference. 6 | % and the error in transforming RGB -> Lab -> RGB is approximately 7 | % 10^-5. By Mark Ruzon from C code by Yossi Rubner, 23 September 1997. 8 | % Updated for MATLAB 5 28 January 1998. 9 | % Fixed a bug in conversion back to uint8 9 September 1999. 
10 | 11 | if (nargin == 1) 12 | b = L(:,:,3); 13 | a = L(:,:,2); 14 | L = L(:,:,1); 15 | end 16 | 17 | % Thresholds 18 | T1 = 0.008856; 19 | T2 = 0.206893; 20 | 21 | [M, N] = size(L); 22 | s = M * N; 23 | L = reshape(L, 1, s); 24 | a = reshape(a, 1, s); 25 | b = reshape(b, 1, s); 26 | 27 | % Compute Y 28 | fY = ((L + 16) / 116) .^ 3; 29 | YT = fY > T1; 30 | fY = (~YT) .* (L / 903.3) + YT .* fY; 31 | Y = fY; 32 | 33 | % Alter fY slightly for further calculations 34 | fY = YT .* (fY .^ (1/3)) + (~YT) .* (7.787 .* fY + 16/116); 35 | 36 | % Compute X 37 | fX = a / 500 + fY; 38 | XT = fX > T2; 39 | X = (XT .* (fX .^ 3) + (~XT) .* ((fX - 16/116) / 7.787)); 40 | 41 | % Compute Z 42 | fZ = fY - b / 200; 43 | ZT = fZ > T2; 44 | Z = (ZT .* (fZ .^ 3) + (~ZT) .* ((fZ - 16/116) / 7.787)); 45 | 46 | X = X * 0.950456; 47 | Z = Z * 1.088754; 48 | 49 | MAT = [ 3.240479 -1.537150 -0.498535; 50 | -0.969256 1.875992 0.041556; 51 | 0.055648 -0.204043 1.057311]; 52 | 53 | RGB = max(min(MAT * [X; Y; Z], 1), 0); 54 | 55 | R = reshape(RGB(1,:), M, N) * 255; 56 | G = reshape(RGB(2,:), M, N) * 255; 57 | B = reshape(RGB(3,:), M, N) * 255; 58 | 59 | if ((nargout == 1) | (nargout == 0)) 60 | R = uint8(round(cat(3,R,G,B))); 61 | end 62 | 63 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/Util/Point.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Point_hh__ 3 | #define __Point_hh__ 4 | 5 | // Simple point template classes. 6 | // Probably only make sense for intrinsic types. 7 | 8 | // 2D Points 9 | 10 | template <class T> 11 | class Point2D 12 | { 13 | public: 14 | Point2D () { x = 0; y = 0; } 15 | Point2D (T x, T y) { this->x = x; this->y = y; } 16 | T x,y; 17 | }; 18 | 19 | template <class T> 20 | inline int operator== (const Point2D<T>& a, const Point2D<T>& b) 21 | { return (a.x == b.x) && (a.y == b.y); } 22 | 23 | template <class T> 24 | inline int operator!= (const Point2D<T>& a, const Point2D<T>& b) 25 | { return (a.x != b.x) || (a.y != b.y); } 26 | 27 | typedef Point2D<int> Pixel; 28 | 29 | // 3D Points 30 | 31 | template <class T> 32 | class Point3D 33 | { 34 | public: 35 | Point3D () { x = 0; y = 0; z = 0; } 36 | Point3D (T x, T y, T z) { this->x = x; this->y = y; this->z = z; } 37 | T x,y,z; 38 | }; 39 | 40 | template <class T> 41 | inline int operator== (const Point3D<T>& a, const Point3D<T>& b) 42 | { return (a.x == b.x) && (a.y == b.y) && (a.z == b.z); } 43 | 44 | template <class T> 45 | inline int operator!= (const Point3D<T>& a, const Point3D<T>& b) 46 | { return (a.x != b.x) || (a.y != b.y) || (a.z != b.z); } 47 | 48 | typedef Point3D<int> Voxel; 49 | 50 | #endif // __Point_hh__ 51 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/Util/RGB2Lab.m: -------------------------------------------------------------------------------- 1 | function [L,a,b] = RGB2Lab(R,G,B) 2 | % function [L, a, b] = RGB2Lab(R, G, B) 3 | % RGB2Lab takes matrices corresponding to Red, Green, and Blue, and 4 | % transforms them into CIELab. This transform is based on ITU-R 5 | % Recommendation BT.709 using the D65 white point reference. 6 | % The error in transforming RGB -> Lab -> RGB is approximately 7 | % 10^-5. RGB values can be either between 0 and 1 or between 0 and 255. 8 | % By Mark Ruzon from C code by Yossi Rubner, 23 September 1997. 9 | % Updated for MATLAB 5 28 January 1998.
10 | 11 | if (nargin == 1) 12 | B = double(R(:,:,3)); 13 | G = double(R(:,:,2)); 14 | R = double(R(:,:,1)); 15 | end 16 | 17 | if ((max(max(R)) > 1.0) | (max(max(G)) > 1.0) | (max(max(B)) > 1.0)) 18 | R = R/255; 19 | G = G/255; 20 | B = B/255; 21 | end 22 | 23 | [M, N] = size(R); 24 | s = M*N; 25 | 26 | % Set a threshold 27 | T = 0.008856; 28 | 29 | RGB = [reshape(R,1,s); reshape(G,1,s); reshape(B,1,s)]; 30 | 31 | % RGB to XYZ 32 | MAT = [0.412453 0.357580 0.180423; 33 | 0.212671 0.715160 0.072169; 34 | 0.019334 0.119193 0.950227]; 35 | XYZ = MAT * RGB; 36 | 37 | X = XYZ(1,:) / 0.950456; 38 | Y = XYZ(2,:); 39 | Z = XYZ(3,:) / 1.088754; 40 | 41 | XT = X > T; 42 | YT = Y > T; 43 | ZT = Z > T; 44 | 45 | fX = XT .* X.^(1/3) + (~XT) .* (7.787 .* X + 16/116); 46 | 47 | % Compute L 48 | Y3 = Y.^(1/3); 49 | fY = YT .* Y3 + (~YT) .* (7.787 .* Y + 16/116); 50 | L = YT .* (116 * Y3 - 16.0) + (~YT) .* (903.3 * Y); 51 | 52 | fZ = ZT .* Z.^(1/3) + (~ZT) .* (7.787 .* Z + 16/116); 53 | 54 | % Compute a and b 55 | a = 500 * (fX - fY); 56 | b = 200 * (fY - fZ); 57 | 58 | L = reshape(L, M, N); 59 | a = reshape(a, M, N); 60 | b = reshape(b, M, N); 61 | 62 | if ((nargout == 1) | (nargout == 0)) 63 | L = cat(3,L,a,b); 64 | end -------------------------------------------------------------------------------- /evaluation/EdgeEval/Util/Random.cc: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "Random.hh" 8 | #include "String.hh" 9 | #include "Exception.hh" 10 | 11 | // Copyright (C) 2002 David R. Martin 12 | // 13 | // This program is free software; you can redistribute it and/or 14 | // modify it under the terms of the GNU General Public License as 15 | // published by the Free Software Foundation; either version 2 of the 16 | // License, or (at your option) any later version. 17 | // 18 | // This program is distributed in the hope that it will be useful, but 19 | // WITHOUT ANY WARRANTY; without even the implied warranty of 20 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 21 | // General Public License for more details. 22 | // 23 | // You should have received a copy of the GNU General Public License 24 | // along with this program; if not, write to the Free Software 25 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 26 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 
27 | 28 | Random Random::rand; 29 | 30 | Random::Random () 31 | { 32 | reseed (0); 33 | } 34 | 35 | Random::Random (u_int64_t seed) 36 | { 37 | reseed (seed); 38 | } 39 | 40 | Random::Random (Random& that) 41 | { 42 | u_int64_t a = that.ui32 (); 43 | u_int64_t b = that.ui32 (); 44 | u_int64_t seed = (a << 32) | b; 45 | _init (seed); 46 | } 47 | 48 | void 49 | Random::reset () 50 | { 51 | _init (_seed); 52 | } 53 | 54 | void 55 | Random::reseed (u_int64_t seed) 56 | { 57 | if (seed == 0) { 58 | struct timeval t; 59 | gettimeofday (&t, NULL); 60 | u_int64_t a = (t.tv_usec >> 3) & 0xffff; 61 | u_int64_t b = t.tv_sec & 0xffff; 62 | u_int64_t c = (t.tv_sec >> 16) & 0xffff; 63 | seed = a | (b << 16) | (c << 32); 64 | } 65 | _init (seed); 66 | } 67 | 68 | void 69 | Random::_init (u_int64_t seed) 70 | { 71 | _seed = seed & 0xffffffffffffull; 72 | _xsubi[0] = (seed >> 0) & 0xffff; 73 | _xsubi[1] = (seed >> 16) & 0xffff; 74 | _xsubi[2] = (seed >> 32) & 0xffff; 75 | } 76 | 77 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/Util/Random.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Random_hh__ 3 | #define __Random_hh__ 4 | 5 | // Copyright (C) 2002 David R. Martin 6 | // 7 | // This program is free software; you can redistribute it and/or 8 | // modify it under the terms of the GNU General Public License as 9 | // published by the Free Software Foundation; either version 2 of the 10 | // License, or (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, but 13 | // WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 | // General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program; if not, write to the Free Software 19 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 20 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | // All random numbers are generated from a single seed. This is true 28 | // even when private random streams (seperate from the global 29 | // Random::rand stream) are spawned from existing streams, since the new 30 | // streams are seeded automatically from the parent's random stream. 31 | // Any random stream can be reset so that a sequence of random values 32 | // can be replayed. 33 | 34 | // If seed==0, then the seed is generated from the system clock. 35 | 36 | class Random 37 | { 38 | public: 39 | 40 | static Random rand; 41 | 42 | // These are defined in as the limits of int, but 43 | // here we need the limits of int32_t. 44 | static const int32_t int32_max = 2147483647; 45 | static const int32_t int32_min = -int32_max-1; 46 | static const u_int32_t u_int32_max = 4294967295u; 47 | 48 | // Seed from the system clock. 49 | Random (); 50 | 51 | // Specify seed. 52 | // If zero, seed from the system clock. 53 | Random (u_int64_t seed); 54 | 55 | // Spawn off a new random stream seeded from the parent's stream. 56 | Random (Random& that); 57 | 58 | // Restore initial seed so we can replay a random sequence. 59 | void reset (); 60 | 61 | // Set the seed. 62 | // If zero, seed from the system clock. 
63 | void reseed (u_int64_t seed); 64 | 65 | // double in [0..1) or [a..b) 66 | inline double fp (); 67 | inline double fp (double a, double b); 68 | 69 | // 32-bit signed integer in [-2^31,2^31) or [a..b] 70 | inline int32_t i32 (); 71 | inline int32_t i32 (int32_t a, int32_t b); 72 | 73 | // 32-bit unsigned integer in [0,2^32) or [a..b] 74 | inline u_int32_t ui32 (); 75 | inline u_int32_t ui32 (u_int32_t a, u_int32_t b); 76 | 77 | protected: 78 | 79 | void _init (u_int64_t seed); 80 | 81 | // The original seed for this random stream. 82 | u_int64_t _seed; 83 | 84 | // The current state for this random stream. 85 | u_int16_t _xsubi[3]; 86 | 87 | }; 88 | 89 | inline u_int32_t 90 | Random::ui32 () 91 | { 92 | return ui32(0,u_int32_max); 93 | } 94 | 95 | inline u_int32_t 96 | Random::ui32 (u_int32_t a, u_int32_t b) 97 | { 98 | assert (a <= b); 99 | double x = fp (); 100 | return (u_int32_t) floor (x * ((double)b - (double)a + 1) + a); 101 | } 102 | 103 | inline int32_t 104 | Random::i32 () 105 | { 106 | return i32(int32_min,int32_max); 107 | } 108 | 109 | inline int32_t 110 | Random::i32 (int32_t a, int32_t b) 111 | { 112 | assert (a <= b); 113 | double x = fp (); 114 | return (int32_t) floor (x * ((double)b - (double)a + 1) + a); 115 | } 116 | 117 | inline double 118 | Random::fp () 119 | { 120 | return erand48 (_xsubi); 121 | } 122 | 123 | inline double 124 | Random::fp (double a, double b) 125 | { 126 | assert (a < b); 127 | return erand48 (_xsubi) * (b - a) + a; 128 | } 129 | 130 | #endif // __Random_hh__ 131 | 132 | -------------------------------------------------------------------------------- /evaluation/EdgeEval/Util/Timer.cc: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (C) 2002 David R. Martin 3 | // 4 | // This program is free software; you can redistribute it and/or 5 | // modify it under the terms of the GNU General Public License as 6 | // published by the Free Software Foundation; either version 2 of the 7 | // License, or (at your option) any later version. 8 | // 9 | // This program is distributed in the hope that it will be useful, but 10 | // WITHOUT ANY WARRANTY; without even the implied warranty of 11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 12 | // General Public License for more details. 13 | // 14 | // You should have received a copy of the GNU General Public License 15 | // along with this program; if not, write to the Free Software 16 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 17 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 18 | 19 | #include 20 | #include 21 | #include 22 | #include "Timer.hh" 23 | 24 | typedef unsigned long long uint64; 25 | 26 | void 27 | Timer::_compute () 28 | { 29 | // Compute elapsed time. 30 | long sec = _elapsed_stop.tv_sec - _elapsed_start.tv_sec; 31 | long usec = _elapsed_stop.tv_usec - _elapsed_start.tv_usec; 32 | if (usec < 0) { 33 | sec -= 1; 34 | usec += 1000000; 35 | } 36 | _elapsed += (double) sec + usec / 1e6; 37 | 38 | // Compute CPU user and system times. 39 | _user += (double) (_cpu_stop.tms_utime - _cpu_start.tms_utime) 40 | / sysconf(_SC_CLK_TCK); 41 | _system += (double) (_cpu_stop.tms_stime - _cpu_start.tms_stime) 42 | / sysconf(_SC_CLK_TCK); 43 | } 44 | 45 | // Convert time in seconds into a nice human-friendly format: h:mm:ss.ss 46 | // Return a pointer to a static buffer.
const char*
Timer::formatTime (double sec, int precision)
{
    static char buf[128];

    // Limit range of precision for safety and sanity.
    precision = (precision < 0) ? 0 : precision;
    precision = (precision > 9) ? 9 : precision;
    uint64 base = 1;
    for (int digit = 0; digit < precision; digit++) { base *= 10; }

    bool neg = (sec < 0);
    uint64 ticks = (uint64) rint (fabs (sec) * base);
    uint64 rsec = ticks / base; // Rounded seconds.
    uint64 frac = ticks % base;

    uint64 h = rsec / 3600;
    uint64 m = (rsec / 60) % 60;
    uint64 s = rsec % 60;

    sprintf (buf, "%s%llu:%02llu:%02llu",
             neg ? "-" : "", h, m, s);

    if (precision > 0) {
        static char fmt[10];
        sprintf (fmt, ".%%0%dlld", precision);
        sprintf (buf + strlen (buf), fmt, frac);
    }

    return buf;
}

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/Timer.hh:
--------------------------------------------------------------------------------

#ifndef __Timer_hh__
#define __Timer_hh__

// Copyright (C) 2002 David R. Martin
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
// 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

#include <sys/time.h>
#include <sys/times.h>
#include <unistd.h>
#include <stdio.h>
#include <assert.h>

class Timer
{
public:

    inline Timer ();
    inline ~Timer ();

    inline void start ();
    inline void stop ();
    inline void reset ();

    // All times are in seconds.
    inline double cpu ();
    inline double user ();
    inline double system ();
    inline double elapsed ();

    // Convert time in seconds into a nice human-friendly format: h:mm:ss.ss
    // Precision is the number of digits after the decimal.
    // Return a pointer to a static buffer.
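    // A minimal usage sketch (illustrative, not from the original sources):
    //   Timer t;
    //   t.start (); /* ... work ... */ t.stop ();
    //   printf ("%s elapsed\n", Timer::formatTime (t.elapsed ()));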
    static const char* formatTime (double sec, int precision = 2);

private:

    void _compute ();

    enum State { stopped, running };

    State _state;

    struct timeval _elapsed_start;
    struct timeval _elapsed_stop;
    double _elapsed;

    struct tms _cpu_start;
    struct tms _cpu_stop;
    double _user;
    double _system;
};

Timer::Timer ()
{
    reset ();
}

Timer::~Timer ()
{
}

void
Timer::reset ()
{
    _state = stopped;
    _elapsed = _user = _system = 0;
}

void
Timer::start ()
{
    assert (_state == stopped);
    _state = running;
    gettimeofday (&_elapsed_start, NULL);
    times (&_cpu_start);
}

void
Timer::stop ()
{
    assert (_state == running);
    gettimeofday (&_elapsed_stop, NULL);
    times (&_cpu_stop);
    _compute ();
    _state = stopped;
}

double
Timer::cpu ()
{
    assert (_state == stopped);
    return _user + _system;
}

double
Timer::user ()
{
    assert (_state == stopped);
    return _user;
}

double
Timer::system ()
{
    assert (_state == stopped);
    return _system;
}

double
Timer::elapsed ()
{
    assert (_state == stopped);
    return _elapsed;
}

#endif // __Timer_hh__

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/distSqr.m:
--------------------------------------------------------------------------------
function z = distSqr(x,y)
% function z = distSqr(x,y)
%
% Return matrix of all-pairs squared distances between the vectors
% in the columns of x and y.
%
% INPUTS
%   x   dxn matrix of vectors
%   y   dxm matrix of vectors
%
% OUTPUTS
%   z   nxm matrix of squared distances
%
% This routine is faster when m<n than when m>n.
%
% David Martin
% March 2003

% Based on dist2.m code,
% Copyright (c) Christopher M Bishop, Ian T Nabney (1996, 1997)

if size(x,1)~=size(y,1),
    error('size(x,1)~=size(y,1)');
end

[d,n] = size(x);
[d,m] = size(y);

% z = repmat(sum(x.^2)',1,m) ...
%     + repmat(sum(y.^2),n,1) ...
%     - 2*x'*y;

z = x'*y;
x2 = sum(x.^2)';
y2 = sum(y.^2);
for i = 1:m,
    z(:,i) = x2 + y2(i) - 2*z(:,i);
end
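% Quick check (illustrative): for column points x = [0 1; 0 0] and y = [0; 1],
% distSqr(x,y) returns [1; 2] (squared distances from (0,0) and (1,0) to (0,1)).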
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/fftconv2.m:
--------------------------------------------------------------------------------
function fim = fftconv2(im,f)
% function fim = fftconv2(im,f)
%
% Convolution using FFT.
%
% David R. Martin
% March 2003

% wrap the filter around the origin and pad with zeros
padf = zeros(size(im));
r = floor(size(f,1)/2);
padf(1:r+1,1:r+1) = f(r+1:end,r+1:end);
padf(1:r,end-r+1:end) = f(r+2:end,1:r);
padf(end-r+1:end,1:r) = f(1:r,r+2:end);
padf(end-r+1:end,end-r+1:end) = f(1:r,1:r);

% magic
fftim = fft2(im);
fftf = fft2(padf);
fim = real(ifft2(fftim.*fftf));

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/gethosttype:
--------------------------------------------------------------------------------
#!/bin/sh
#
# Create a canonical system name "machine_system" for use in
# bin and lib paths.
#
# If a name cannot be found, then "unknown" is the output,
# and the exit status is 1.
#

uname_machine=`(uname -m) 2> /dev/null` || machine=unknown
uname_system=`(uname -s) 2> /dev/null` || system=unknown

case $uname_machine in
    i*86)
        machine=ix86;;
    ia64)
        machine=ia64;;
    IP27)
        machine=mips;;
    sun4*)
        machine=sparc;;
    x86_64)
        machine=x86_64;;
    *)
        machine=unknown;;
esac

if [ $machine = "unknown" ]; then
    echo unknown
    exit 1
fi

case $uname_system in
    Linux)
        system=linux;;
    IRIX*)
        system=irix;;
    SunOS)
        system=solaris;;
    NetBSD)
        system=netbsd;;
    *)
        system=unknown;;
esac

if [ $system = "unknown" ]; then
    echo unknown
    exit 1
fi

echo ${machine}_${system}
exit 0

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/isum.c:
--------------------------------------------------------------------------------

#include <math.h>
#include <string.h>
#include <mex.h>

// If you have trouble mexifying this file, then consider using
// the MATLAB code (isum.m) instead.  So long as your MATLAB
// version is at least 6.5, you won't suffer too much of a
// performance penalty.

void
mexFunction (
    int nlhs, mxArray* plhs[],
    int nrhs, const mxArray* prhs[])
{
    // Check number of arguments.
    if (nlhs < 1) {
        mexErrMsgTxt("Too few output arguments.");
    }
    if (nlhs > 1) {
        mexErrMsgTxt("Too many output arguments.");
    }
    if (nrhs < 3) {
        mexErrMsgTxt("Too few input arguments.");
    }
    if (nrhs > 3) {
        mexErrMsgTxt("Too many input arguments.");
    }

    const double* x = mxGetPr(prhs[0]);
    const double* idx = mxGetPr(prhs[1]);
    int nbins = (int)mxGetScalar(prhs[2]);
    if (nbins < 0) { nbins = 0; }

    // Check arguments.
    const int n = mxGetNumberOfElements(prhs[0]);
    if (n != mxGetNumberOfElements(prhs[1])) {
        mexErrMsgTxt("x and idx must be the same size");
    }

    // Do the reduction.
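    // Worked example (illustrative): x = {1,2,3,4}, idx = {1,2,1,3}, nbins = 3
    // yields acc = {4, 2, 4}; indices outside [1,nbins] are skipped.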
    plhs[0] = mxCreateDoubleMatrix(nbins,1,mxREAL);
    double* acc = mxGetPr(plhs[0]);
    memset(acc,0,nbins*sizeof(*acc));
    for (int i = 0; i < n; i++) {
        int v = (int)idx[i];
        if (v < 1) { continue; }
        if (v > nbins) { continue; }
        acc[v-1] += x[i];
    }
}

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/isum.m:
--------------------------------------------------------------------------------
function acc = isum(x,idx,nbins)
% function acc = isum(x,idx,nbins)
%
% Indexed sum reduction, where acc(i) contains the sum of
% x(find(idx==i)).
%
% The mex version is 300x faster in R12, and 4x faster in R13.  As far
% as I can tell, there is no way to do this efficiently in matlab R12.
%
% David R. Martin
% March 2003

acc = zeros(nbins,1);
for i = 1:numel(x),
    if idx(i)<1, continue; end
    if idx(i)>nbins, continue; end
    acc(idx(i)) = acc(idx(i)) + x(i);
end

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/kofn.cc:
--------------------------------------------------------------------------------

#include "Random.hh"
#include "kofn.hh"

// O(n) implementation.
static void
_kOfN_largeK (int k, int n, int* values)
{
    assert (k > 0);
    assert (k <= n);
    int j = 0;
    for (int i = 0; i < n; i++) {
        double prob = (double) (k - j) / (n - i);
        assert (prob <= 1);
        double x = Random::rand.fp ();
        if (x < prob) {
            values[j++] = i;
        }
    }
    assert (j == k);
}

// O(k*lg(k)) implementation; constant factor is about 2x the constant
// factor for the O(n) implementation.
static void
_kOfN_smallK (int k, int n, int* values)
{
    assert (k > 0);
    assert (k <= n);
    if (k == 1) {
        values[0] = Random::rand.i32 (0, n - 1);
        return;
    }
    int leftN = n / 2;
    int rightN = n - leftN;
    int leftK = 0;
    int rightK = 0;
    for (int i = 0; i < k; i++) {
        int x = Random::rand.i32 (0, n - i - 1);
        if (x < leftN - leftK) {
            leftK++;
        } else {
            rightK++;
        }
    }
    if (leftK > 0) { _kOfN_smallK (leftK, leftN, values); }
    if (rightK > 0) { _kOfN_smallK (rightK, rightN, values + leftK); }
    for (int i = leftK; i < k; i++) {
        values[i] += leftN;
    }
}

// Return k randomly selected integers from the interval [0,n), in
// increasing sorted order.
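// e.g. (illustrative) kOfN (3, 10, v) might fill v with {2, 5, 9}:
// three distinct draws from [0,10), returned in increasing order.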
void
kOfN (int k, int n, int* values)
{
    assert (k >= 0);
    assert (n >= 0);
    if (k == 0) { return; }
    static double log2 = log (2);
    double klogk = k * log (k) / log2;
    if (klogk < n / 2) {
        _kOfN_smallK (k, n, values);
    } else {
        _kOfN_largeK (k, n, values);
    }
}

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/kofn.hh:
--------------------------------------------------------------------------------

#ifndef __kofn_hh__
#define __kofn_hh__

void kOfN (int k, int n, int* values);

#endif // __kofn_hh__

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/logist2.m:
--------------------------------------------------------------------------------
function [beta,p,lli] = logist2(y,x,w)
% [beta,p,lli] = logist2(y,x)
%
% 2-class logistic regression.
%
% INPUT
%   y     Nx1 column vector of 0|1 class assignments
%   x     NxK matrix of input vectors as rows
%   [w]   Nx1 vector of sample weights
%
% OUTPUT
%   beta  Kx1 column vector of model coefficients
%   p     Nx1 column vector of fitted class 1 posteriors
%   lli   log likelihood
%
% Class 1 posterior is 1 / (1 + exp(-x*beta))
%
% David Martin
% April 16, 2002

% Copyright (C) 2002 David R. Martin
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful, but
% WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
% General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
% 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

error(nargchk(2,3,nargin));

% check inputs
if size(y,2) ~= 1,
    error('Input y not a column vector.');
end
if size(y,1) ~= size(x,1),
    error('Input x,y sizes mismatched.');
end

% get sizes
[N,k] = size(x);

% if sample weights weren't specified, set them to 1
if nargin < 3,
    w = 1;
end

% normalize sample weights so max is 1
w = w / max(w);

% initial guess for beta: all zeros
beta = zeros(k,1);

% Newton-Raphson via IRLS,
% taken from Hastie/Tibshirani/Friedman Section 4.4.
iter = 0;
lli = 0;
while 1==1,
    iter = iter + 1;

    % fitted probabilities
    p = 1 ./ (1 + exp(-x*beta));

    % log likelihood
    lli_prev = lli;
    lli = sum( w .* (y.*log(p+eps) + (1-y).*log(1-p+eps)) );

    % least-squares weights
    wt = w .* p .* (1-p);

    % derivatives of likelihood w.r.t. beta
    deriv = x'*(w.*(y-p));

    % Hessian of likelihood w.r.t. beta
    % hessian = x'Wx, where W=diag(w)
    % Do it this way to be memory efficient and fast.
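    % Equivalent (but more memory-hungry) vectorized form, for reference:
    %   hess = -(x' * (repmat(wt,1,k) .* x));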
    hess = zeros(k,k);
    for i = 1:k,
        wxi = wt .* x(:,i);
        for j = i:k,
            hij = wxi' * x(:,j);
            hess(i,j) = -hij;
            hess(j,i) = -hij;
        end
    end

    % make sure Hessian is well conditioned
    if (rcond(hess) < eps),
        error(['Stopped at iteration ' num2str(iter) ...
               ' because Hessian is poorly conditioned.']);
        break;
    end;

    % Newton-Raphson update step
    step = hess\deriv;
    beta = beta - step;

    % termination criterion based on derivatives
    tol = 1e-6;
    if abs(deriv'*step/k) < tol, break; end;

    % termination criterion based on log likelihood
    % tol = 1e-4;
    % if abs((lli-lli_prev)/(lli+lli_prev)) < 0.5*tol, break; end;
end;

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/padReflect.m:
--------------------------------------------------------------------------------
function [impad] = padReflect(im,r)
% function [impad] = padReflect(im,r)
%
% Pad an image with a border of size r, and reflect the image into
% the border.
%
% David R. Martin
% March 2003

impad = zeros(size(im)+2*r);
impad(r+1:end-r,r+1:end-r) = im; % middle
impad(1:r,r+1:end-r) = flipud(im(1:r,:)); % top
impad(end-r+1:end,r+1:end-r) = flipud(im(end-r+1:end,:)); % bottom
impad(r+1:end-r,1:r) = fliplr(im(:,1:r)); % left
impad(r+1:end-r,end-r+1:end) = fliplr(im(:,end-r+1:end)); % right
impad(1:r,1:r) = flipud(fliplr(im(1:r,1:r))); % top-left
impad(1:r,end-r+1:end) = flipud(fliplr(im(1:r,end-r+1:end))); % top-right
impad(end-r+1:end,1:r) = flipud(fliplr(im(end-r+1:end,1:r))); % bottom-left
impad(end-r+1:end,end-r+1:end) = flipud(fliplr(im(end-r+1:end,end-r+1:end))); % bottom-right
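% e.g. (illustrative) padReflect(im,2) on a 4x4 image returns an 8x8 array
% whose 2-pixel border mirrors the adjacent image content.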
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/progbar.m:
--------------------------------------------------------------------------------
function progbar(i,n,w)
% function progbar(i,n,w)
%
% Display a textual progress bar.
%
% INPUTS
%   i       Iteration number.
%   n       Number of iterations.
%   [w=50]  Width of bar.
%
% EXAMPLE
%
%   progbar(0,n);
%   for i = 1:n,
%     compute();
%     progbar(i,n);
%   end
%
% David R. Martin
% April 2002

if nargin<3, w=50; end
w = min(w,n);

if i==0,
    fwrite(2,'[');
    for c = 1:w, fwrite(2,'.'); end
    fwrite(2,']');
    for c = 1:w+1, fwrite(2,sprintf('\b')); end
    return
end

if mod(i,n/w) <= mod(i-1,n/w),
    fwrite(2,'=');
end

if i==n,
    fprintf(2,'\n');
end

--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/test:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/Util/test

--------------------------------------------------------------------------------
/evaluation/EdgeEval/__init__.py:
--------------------------------------------------------------------------------
from .EdgeMapEval import correspond

__all__ = ['correspond']

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/CSA++/csa.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/CSA++/csa.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/EdgeMapEval.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/EdgeMapEval.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Exception.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Exception.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Matrix.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Matrix.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Random.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Random.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/String.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/String.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Timer.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Timer.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/kofn.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/kofn.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/correspondPixels.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/correspondPixels.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/match.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/match.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/CSA++/csa.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/CSA++/csa.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/EdgeMapEval.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/EdgeMapEval.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Exception.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Exception.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Matrix.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Matrix.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Random.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Random.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/String.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/String.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Timer.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Timer.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/kofn.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/kofn.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/correspondPixels.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/correspondPixels.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/match.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/match.o

--------------------------------------------------------------------------------
/evaluation/EdgeEval/correspondPixels.h:
--------------------------------------------------------------------------------
void correspondPixels(double* bmap1,      /*input*/
                      double* bmap2,      /*input*/
                      double* match1,     /*output*/
                      double* match2,     /*output*/
                      double& cost,       /*output*/
                      double maxDist,     /*parameters*/
                      double outlierCost, /*parameters*/
                      int height,         /*image size*/
                      int width           /*image size*/);

--------------------------------------------------------------------------------
/evaluation/EdgeEval/correspondPixels.pyx:
--------------------------------------------------------------------------------
import numpy as np
cimport numpy as np

assert sizeof(int) == sizeof(np.int32_t)

cdef extern from "correspondPixels.h":
    void correspondPixels(double* bmap1,
                          double* bmap2,
                          double* match1,
                          double* match2,
                          double& cost,
                          double maxDist,
                          double outlierCost,
                          int height,
                          int width)

def correspond(np.ndarray[double,ndim=2] bmap1,
               np.ndarray[double,ndim=2] bmap2,
               double maxDist = 0.001,
               double outlierCost = 100):

    assert bmap1.shape[0] == bmap2.shape[0] and bmap1.shape[1] == bmap2.shape[1]

    cdef int height = bmap1.shape[0]
    cdef int width = bmap1.shape[1]

    cdef np.ndarray[np.double_t, ndim=2] \
        match1 = np.zeros((height,width), dtype=np.double)

    cdef np.ndarray[np.double_t, ndim=2] \
        match2 = np.zeros((height,width), dtype=np.double)
    cdef double cost = 0
    correspondPixels(<double*> bmap1.data,
                     <double*> bmap2.data,
                     <double*> match1.data,
                     <double*> match2.data,
                     cost,
                     maxDist,
                     outlierCost,
                     height,
                     width)

    return match1, match2, cost
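A minimal sketch of calling the compiled wrapper above once it is built; the array
contents here are made up purely for illustration, and only the signature shown in
the .pyx file is assumed:

    import numpy as np
    from evaluation.EdgeEval import correspond

    bmap1 = np.zeros((8, 8), dtype=np.double)  # predicted boundary map
    bmap2 = np.zeros((8, 8), dtype=np.double)  # ground-truth boundary map
    bmap1[4, 4] = 1.0
    bmap2[4, 5] = 1.0
    match1, match2, cost = correspond(bmap1, bmap2, maxDist=0.1)
    print((match1 > 0).sum(), cost)  # matched pixels in bmap1, assignment cost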
--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/Exception.hh:
--------------------------------------------------------------------------------

#ifndef __Exception_hh__
#define __Exception_hh__

// A simple exception class that contains an error message.

// Copyright (C) 2002 David R. Martin
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
// 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

#include <iostream>

class Exception
{
public:

    // Always construct exception with a message, so we can print
    // a useful error/log message.
    Exception (const char* msg);

    // We need to implement the copy constructor so that rethrowing
    // works.
    Exception (const Exception& that);

    virtual ~Exception ();

    // Retrieve the message that this exception carries.
    virtual const char* msg () const;

protected:

    char* _msg;

};

// write to output stream
inline std::ostream& operator<< (std::ostream& out, const Exception& e) {
    out << e.msg();
    return out;
}

#endif // __Exception_hh__

--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/Point.hh:
--------------------------------------------------------------------------------

#ifndef __Point_hh__
#define __Point_hh__

// Simple point template classes.
// Probably only make sense for intrinsic types.

// 2D Points

template <class T>
class Point2D
{
public:
    Point2D () { x = 0; y = 0; }
    Point2D (T x, T y) { this->x = x; this->y = y; }
    T x,y;
};

template <class T>
inline int operator== (const Point2D<T>& a, const Point2D<T>& b)
{ return (a.x == b.x) && (a.y == b.y); }

template <class T>
inline int operator!= (const Point2D<T>& a, const Point2D<T>& b)
{ return (a.x != b.x) || (a.y != b.y); }

typedef Point2D<int> Pixel;

// 3D Points

template <class T>
class Point3D
{
public:
    Point3D () { x = 0; y = 0; z = 0; }
    Point3D (T x, T y, T z) { this->x = x; this->y = y; this->z = z; }
    T x,y,z;
};

template <class T>
inline int operator== (const Point3D<T>& a, const Point3D<T>& b)
{ return (a.x == b.x) && (a.y == b.y) && (a.z == b.z); }

template <class T>
inline int operator!= (const Point3D<T>& a, const Point3D<T>& b)
{ return (a.x != b.x) || (a.y != b.y) || (a.z != b.z); }

typedef Point3D<int> Voxel;

#endif // __Point_hh__
--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/Timer.hh:
--------------------------------------------------------------------------------

#ifndef __Timer_hh__
#define __Timer_hh__

// Copyright (C) 2002 David R. Martin
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
// 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

#include <sys/time.h>
#include <sys/times.h>
#include <unistd.h>
#include <stdio.h>
#include <assert.h>

class Timer
{
public:

    inline Timer ();
    inline ~Timer ();

    inline void start ();
    inline void stop ();
    inline void reset ();

    // All times are in seconds.
    inline double cpu ();
    inline double user ();
    inline double system ();
    inline double elapsed ();

    // Convert time in seconds into a nice human-friendly format: h:mm:ss.ss
    // Precision is the number of digits after the decimal.
    // Return a pointer to a static buffer.
    static const char* formatTime (double sec, int precision = 2);

private:

    void _compute ();

    enum State { stopped, running };

    State _state;

    struct timeval _elapsed_start;
    struct timeval _elapsed_stop;
    double _elapsed;

    struct tms _cpu_start;
    struct tms _cpu_stop;
    double _user;
    double _system;
};

Timer::Timer ()
{
    reset ();
}

Timer::~Timer ()
{
}

void
Timer::reset ()
{
    _state = stopped;
    _elapsed = _user = _system = 0;
}

void
Timer::start ()
{
    assert (_state == stopped);
    _state = running;
    gettimeofday (&_elapsed_start, NULL);
    times (&_cpu_start);
}

void
Timer::stop ()
{
    assert (_state == running);
    gettimeofday (&_elapsed_stop, NULL);
    times (&_cpu_stop);
    _compute ();
    _state = stopped;
}

double
Timer::cpu ()
{
    assert (_state == stopped);
    return _user + _system;
}

double
Timer::user ()
{
    assert (_state == stopped);
    return _user;
}

double
Timer::system ()
{
    assert (_state == stopped);
    return _system;
}

double
Timer::elapsed ()
{
    assert (_state == stopped);
    return _elapsed;
}

#endif // __Timer_hh__

--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/kofn.hh:
--------------------------------------------------------------------------------

#ifndef __kofn_hh__
#define __kofn_hh__

void kOfN (int k, int n, int* values);

#endif // __kofn_hh__

--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/match.hh:
--------------------------------------------------------------------------------

#ifndef __match_hh__
#define __match_hh__

class Matrix;

// returns the cost of the assignment
double matchEdgeMaps (
    const Matrix& bmap1, const Matrix& bmap2,
    double maxDist, double outlierCost,
    Matrix& match1, Matrix& match2);

#endif // __match_hh__

--------------------------------------------------------------------------------
/evaluation/EdgeEval/setup.py:
--------------------------------------------------------------------------------
import os
from setuptools import setup
from Cython.Distutils import build_ext
from distutils.extension import Extension
import subprocess
import numpy as np
import glob

try:
    NP_INCLUDE = np.get_include()
except AttributeError:
    NP_INCLUDE = np.get_numpy_include()

CSAPP_INCLUDE = os.path.abspath('CSA++')

INCLUDE = [os.path.abspath('include'),
           CSAPP_INCLUDE,
           NP_INCLUDE,
           os.path.abspath('Util')]
# SOURCES = ['CSA++/csa.cc','match.cc','correspondPixels.cc']
SOURCES = ['EdgeMapEval.cc','CSA++/csa.cc','match.cc','correspondPixels.pyx'] + glob.glob('Util/*.cc')
# import pdb
# pdb.set_trace()

class custom_build_ext(build_ext):
    def build_extensions(self):
        build_ext.build_extensions(self)

ext_modules = [
    Extension("EdgeMapEval",
              sources=SOURCES,
              include_dirs=INCLUDE,
              language='c++',
              extra_compile_args=["-std=c++11"],
              )
]

setup(
    ext_modules = ext_modules,
    cmdclass = {
        'build_ext': custom_build_ext
    }
)

--------------------------------------------------------------------------------
/evaluation/Makefile:
--------------------------------------------------------------------------------
all:
	cd EdgeEval; python setup.py build_ext --inplace; cd ..

--------------------------------------------------------------------------------
/evaluation/RasterizeLine/__init__.py:
--------------------------------------------------------------------------------
from .draw import drawfn
__all__ = ['drawfn']

--------------------------------------------------------------------------------
/evaluation/RasterizeLine/build/temp.linux-x86_64-3.5/draw.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/RasterizeLine/build/temp.linux-x86_64-3.5/draw.o

--------------------------------------------------------------------------------
/evaluation/RasterizeLine/build/temp.linux-x86_64-3.5/kernel.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/RasterizeLine/build/temp.linux-x86_64-3.5/kernel.o

--------------------------------------------------------------------------------
/evaluation/RasterizeLine/build/temp.linux-x86_64-3.8/draw.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/RasterizeLine/build/temp.linux-x86_64-3.8/draw.o

--------------------------------------------------------------------------------
/evaluation/RasterizeLine/build/temp.linux-x86_64-3.8/kernel.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/RasterizeLine/build/temp.linux-x86_64-3.8/kernel.o
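A minimal sketch of using the RasterizeLine module once its extension is built
(the segment coordinates here are made up for illustration; only the `drawfn`
signature declared in draw.pyx below is assumed):

    import numpy as np
    from evaluation.RasterizeLine import drawfn

    # one segment from (2,3) to (20,15), rows are x1,y1,x2,y2
    lines = np.ascontiguousarray([[2.0, 3.0, 20.0, 15.0]], dtype=np.float32)
    linemap = drawfn(lines, 32, 32)   # 32x32 raster map with 1.0 on the segment
    print(int(linemap.sum()))         # number of pixels touched by the segment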
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/draw.hpp:
--------------------------------------------------------------------------------
void _draw(const float* lines,
           double* map,
           int height, int width,
           int nlines);

--------------------------------------------------------------------------------
/evaluation/RasterizeLine/draw.pyx:
--------------------------------------------------------------------------------
import numpy as np
cimport numpy as np

assert sizeof(int) == sizeof(np.int32_t)

cdef extern from "draw.hpp":
    void _draw(const float* lines, double* map, int height, int width, int nlines);


def drawfn(np.ndarray[float, ndim=2] lines,
           int height, int width):

    cdef np.ndarray[np.double_t, ndim=2] linemap = np.zeros((height,width),
                                                            dtype=np.double)
    cdef int nlines = lines.shape[0]
    _draw(<const float*> lines.data, <double*> linemap.data,
          height, width, nlines)


    return linemap

--------------------------------------------------------------------------------
/evaluation/RasterizeLine/kernel.cpp:
--------------------------------------------------------------------------------
#include "draw.hpp"
#include "math.h"
void _draw(const float* lines,
           double* map,
           int height, int width,
           int nlines)
{
    for(int k = 0; k < nlines; ++k){
        float x1 = lines[4*k];
        float y1 = lines[4*k+1];
        float x2 = lines[4*k+2];
        float y2 = lines[4*k+3];

        // number of samples along the segment, at roughly one per pixel
        float vn = ceil(sqrt((x1-x2)*(x1-x2)+(y1-y2)*(y1-y2)));
        if (vn < 2.0f) { vn = 2.0f; } // avoid dividing by zero below for (near-)degenerate segments

        float dx = (x2-x1)/(vn-1.0);
        float dy = (y2-y1)/(vn-1.0);

        for(int j = 0; j < (int) vn; ++j){
            int xx = round(x1+(float)j*dx);
            int yy = round(y1+(float)j*dy);
            if (xx>=0 && xx<=width-1 && yy>=0 && yy<=height-1){
                int index = yy*width + xx;
                map[index] = 1.0;
            }
        }
    }
}

--------------------------------------------------------------------------------
/evaluation/RasterizeLine/setup.py:
--------------------------------------------------------------------------------
import os
from setuptools import setup
from Cython.Distutils import build_ext
from distutils.extension import Extension
import subprocess
import numpy as np
import glob

try:
    NP_INCLUDE = np.get_include()
except AttributeError:
    NP_INCLUDE = np.get_numpy_include()

class custom_build_ext(build_ext):
    def build_extensions(self):
        build_ext.build_extensions(self)

ext_modules = [
    Extension("draw",
              sources=['kernel.cpp','draw.pyx'],
              include_dirs=[NP_INCLUDE],
              language='c++',
              extra_compile_args=["-std=c++11"],
              )
]

setup(
    ext_modules = ext_modules,
    cmdclass = {
        'build_ext': custom_build_ext
    }
)

--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
from .evaluation import LSDEvaluator

__all__ = ["LSDEvaluator"]

--------------------------------------------------------------------------------
/evaluation/compute_prec_recall.py:
--------------------------------------------------------------------------------
from .EdgeEval import correspond

--------------------------------------------------------------------------------
/evaluation/draw-hap.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import argparse
import numpy as np
import json
import matplotlib as mpl
import scipy.io as sio
from scipy import interpolate

mpl.rcParams.update({"font.size": 18})
# mpl.font_manager._rebuild()

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path',nargs='+', type=str)
    parser.add_argument('--threshold', type=int, choices = [5,10,15], default=10)
    parser.add_argument('--dest', default = None, type=str, help='the destination of the saved figure')

    args = parser.parse_args()

    evaluation_results = []
    legends = []
    for path in args.path:
        result = sio.loadmat(path)
        evaluation_results.append(result)
        legends.append(result['label'])
    f_scores = np.linspace(0.2,0.9,num=8).tolist()
    for f_score in f_scores:
        x = np.linspace(0.01,1)
        y = f_score*x/(2*x-f_score)
        l, = plt.plot(x[y >= 0], y[y >= 0], color=[0,0.5,0], alpha=0.3)
        plt.annotate("f={0:0.1}".format(f_score), xy=(0.9, y[45] + 0.02), alpha=0.4,fontsize=10)

    plt.rc('legend',fontsize=14)
    plt.grid(True)
    plt.axis([0.0, 1.0, 0.0, 1.0])
    plt.xticks(np.arange(0, 1.0, step=0.1))
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.yticks(np.arange(0, 1.0, step=0.1))

    for label, result in zip(legends,evaluation_results):
        label = label.item()
        precision = result['precision'][0].flatten()
        recall = result['recall'][0].flatten()
        idx = np.isfinite(precision)*np.isfinite(recall)
        precision = precision[idx]
        recall = recall[idx]

        x = np.arange(0,1,0.01)*recall[-1]
        f = interpolate.interp1d(recall,precision,kind='cubic',bounds_error=False,fill_value=precision[0])
        y = f(x)
        T = 0.005

        print(result['f'].item())


        #sap_head = "sAP$^{%d}$"%(args.threshold)
        if 'HAWP' in label:
            label_ = '[F={:.1f}] {}(Ours)'.format(result['f'].item()*100,label)
        else:
            label_ = '[F={:.1f}] {}'.format(result['f'].item()*100,label)
        #plt.plot(recall,precision,'-',label=label_sap,linewidth=3)
        # plt.plot(recall,precision,'-')
        plt.plot(x,y,'-',label=label_,linewidth=3)
    plt.legend(loc='best')
    title = "PR Curve for AP$H$"
    plt.title(title)
    if args.dest:
        plt.savefig(args.dest, bbox_inches='tight')
    else:
        plt.show()

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/evaluation/draw-json.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import cv2
import numpy as np
import argparse
import os
import os.path as osp
import json
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--pred',type=str,required=True,help='the json file for the wireframe or line segment predictions')
    parser.add_argument('--benchmark', type=str, choices = ['wireframe','york'], required=True)
    parser.add_argument('--threshold', type=float, default=0.5)
    parser.add_argument('--cmp', default='g', choices = ['g','l'])
    parser.add_argument('--dest', type=str, required=True)
    parser.add_argument('--width', default=0, type=int,)
    parser.add_argument('--height', default=0, type=int,)
    parser.add_argument('--topk', default=-1, type=int)
    parser.add_argument('--fname', default=None, type=str)

    args = parser.parse_args()
    with open(args.pred,'r') as fin:
        results = json.load(fin)

    if args.benchmark == 'wireframe':
        args.images = 'data/wireframe/images'
    elif args.benchmark == 'york':
        args.images = 'data/york/images'

    os.makedirs(args.dest,exist_ok=True)

    if args.fname is not None:
        results = [r for r in results if r['filename'] == args.fname]
    for result in results:
        fname = result['filename']

        image = cv2.imread(osp.join(args.images,fname))[:,:,::-1]
        ori_shape = image.shape

        lines = np.array(result['lines_pred'])
        score = np.array(result['lines_score'])
        if result['width']!=ori_shape[1] or result['height']!=ori_shape[0]:
            sx = float(ori_shape[1]/result['width'])
            sy = float(ori_shape[0]/result['height'])
            sxy = np.array([sx,sy,sx,sy]).reshape(-1,4)
            lines = lines*sxy

        if args.cmp == 'g':
            sort_arg = np.argsort(score)[::-1]
            lines = lines[sort_arg]
            score = score[sort_arg]
            idx = score>args.threshold
        else:
            sort_arg = np.argsort(score)
            lines = lines[sort_arg]
            score = score[sort_arg]
            idx = score<args.threshold

        if args.topk > 0:
            idx = np.arange(idx.shape[0])<args.topk

--------------------------------------------------------------------------------
/evaluation/draw-sap.py:
--------------------------------------------------------------------------------
    f_scores = np.linspace(0.2,0.9,num=8).tolist()
    for f_score in f_scores:
        x = np.linspace(0.01,1)
        y = f_score*x/(2*x-f_score)
        l, = plt.plot(x[y >= 0], y[y >= 0], color=[0,0.5,0], alpha=0.3)
        plt.annotate("f={0:0.1}".format(f_score), xy=(0.9, y[45] + 0.02), alpha=0.4,fontsize=10)

    plt.rc('legend',fontsize=14)
    plt.grid(True)
    plt.axis([0.0, 1.0, 0.0, 1.0])
    plt.xticks(np.arange(0, 1.0, step=0.1))
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.yticks(np.arange(0, 1.0, step=0.1))

    for label, result in zip(legends,evaluation_results):
        precision = result['precision']
        recall = result['recall']
        sap_head = "sAP$^{%d}$"%(args.threshold)
        if 'HAWP' in label:
            label_sap = '[{}={:.1f}] {}(Ours)'.format(sap_head,result['sAP'],label)
        else:
            label_sap = '[{}={:.1f}] {}'.format(sap_head,result['sAP'],label)
        plt.plot(recall,precision,'-',label=label_sap,linewidth=3)
    plt.legend(loc='best')
    title = "PR Curve for sAP$^{%d}$"%(int(args.threshold))
    plt.title(title)
    if args.dest:
        plt.savefig(args.dest, bbox_inches='tight')
    else:
        plt.show()

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/evaluation/evaluation.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy.io as sio
import glob
import os
import os.path as osp
from .prmeter import PrecisionRecallMeter
import multiprocessing
from tqdm import tqdm

class LSDEvaluator(object):
    def __init__(self, root, thresholds, cmp='g', height=0,width=0):
        self.root = root
        self.filenames = glob.glob(osp.join(root,'*.mat'))
        self.thresholds = thresholds
        self.meter = PrecisionRecallMeter(self.thresholds,cmp=cmp)
        self.height = height
        self.width = width

    def eval_for_image(self, index):
        mat = sio.loadmat(self.filenames[index])
        gt = mat['gt']
        pred = mat['pred']
        height = mat['height'].item()
        width = mat['width'].item()
        # import pdb
        # pdb.set_trace()
        if self.height>0 and self.width>0:
            sx = float(self.width/width)
            sy = float(self.height/height)
            scale = np.array([sx,
                              sy,
                              sx,
                              sy],dtype=np.float32)
            scale = scale.reshape((1,4))
            gt*=scale
            pred[:,:4] = pred[:,:4]*scale
            return self.meter(pred,gt,self.height,self.width)
        else:
            return self.meter(pred, gt, height, width)

    def __call__(self, num_workers=16, per_image = True):
        # self.eval_for_image(0)
        with multiprocessing.Pool(num_workers) as p:
            self.results = results = list(tqdm(p.imap(self.eval_for_image,
                range(len(self.filenames))), total=len(self.filenames)))
        if per_image:
            self.precisions = np.concatenate([r['p'][:,None] for r in self.results],axis=1)
            self.recalls = np.concatenate([r['r'][:, None] for r in self.results], axis=1)

            self.average_precisions = np.mean(self.precisions,axis=1)
            self.average_recalls = np.mean(self.recalls, axis=1)
            self.fmeasure = 2*self.average_precisions*self.average_recalls/(self.average_recalls+self.average_precisions)

            return {'precisions':self.precisions, 'recalls': self.recalls,
                    'avg_precision':self.average_precisions,
                    'avg_recall': self.average_recalls,
                    'avg_fmeasure': self.fmeasure,
                    'filenames': self.filenames}
        else:
            sumtp = sum(res['tp'] for res in results)
            sumfp = sum(res['fp'] for res in results)
            sumgt = sum(res['gt'] for res in results)
            # import pdb
            # pdb.set_trace()
            # rcs = sorted(sumtp/sumgt)
            # prs = sorted(sumtp/np.maximum(sumtp+sumfp,1e-9))[::-1]
            rcs = sumtp/sumgt
            prs = sumtp/np.maximum(sumtp+sumfp,1e-9)
            # temp = np.concatenate(([0],prs))
            # idx = np.where((temp[1:]-temp[:-1])>0)[0]
            # rcs = rcs[idx]
            # prs = prs[idx]

            return {'avg_precision': prs, 'avg_recall':rcs}

--------------------------------------------------------------------------------
/evaluation/example_evaluation.py:
--------------------------------------------------------------------------------
import scipy.io as sio
import matplotlib.pyplot as plt
from evaluation.RasterizeLine import drawfn
import numpy as np
from evaluation.prmeter import PrecisionRecallMeter
import glob
import os.path as osp
from tqdm import tqdm
import multiprocessing
if __name__ == '__main__':

    files = glob.glob('../outputs/afm_box_b4/wireframe_afm/*.mat')
    meter = PrecisionRecallMeter([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])

    # results = [None]*len(files)
    def eval_on_image(i):
        mat = sio.loadmat(files[i])
        gt = mat['gt']
        pred = mat['pred']
        height = mat['height']
        width = mat['width']

        return meter(pred,gt,height,width)

    # for i in range(len(files)):
    #     eval_on_image(i)
    # import pdb
    # pdb.set_trace()
    with multiprocessing.Pool(32) as p:
        results = list(tqdm(p.imap(eval_on_image, range(len(files))), total=len(files)))
    precisions = np.concatenate([r['p'][:,None] for r in results],axis=1)
    recalls = np.concatenate([r['r'][:,None] for r in results],axis=1)

    import pdb
    pdb.set_trace()

    # mat = sio.loadmat('../outputs/afm_box_b4/wireframe_afm/00031546.mat')

--------------------------------------------------------------------------------
/evaluation/example_rasterline.py:
--------------------------------------------------------------------------------
import scipy.io as sio
import matplotlib.pyplot as plt
from evaluation.RasterizeLine import drawfn
import numpy as np

if __name__ == '__main__':
    mat = sio.loadmat('../outputs/afm_box_b4/wireframe_afm/00037266.mat')
    gt = mat['gt']
    pred = mat['pred']
    height = mat['height']
    width = mat['width']
    G = drawfn(np.ascontiguousarray(gt),height,width)
    P = drawfn(np.ascontiguousarray(pred[:,:4]), height,width)
    import pdb
    pdb.set_trace()

--------------------------------------------------------------------------------
/evaluation/prmeter.py:
--------------------------------------------------------------------------------
from .EdgeEval import correspond
from .RasterizeLine import drawfn
import numpy as np

# cmp_dict = {
#     'g': lambda a,b: a>b,
#     'l': lambda a,b: a<b,
#     'e': lambda a,b: a==b
# }

class PrecisionRecallMeter(object):
    def __init__(self, thresholds, cmp='g', maxDist=0.01):
        self.thresholds = thresholds
        self.cmp = getattr(self, 'cmp_'+cmp)
        self.maxDist = maxDist

    @staticmethod
    def cmp_g(a,b):
        return a>b

    @staticmethod
    def cmp_e(a,b):
        return a==b

    @staticmethod
    def cmp_l(a,b):
        return a<b

    def __call__(self, pred, gt, height, width):
        gt_map = drawfn(np.ascontiguousarray(gt),height,width)
        score = pred[:,-1]

        n = len(self.thresholds)
        cntR, sumR = np.zeros(n), np.zeros(n)
        cntP, sumP = np.zeros(n), np.zeros(n)
        tp, fp, gt = np.zeros(n), np.zeros(n), np.zeros(n)
        recalls, precisions = np.zeros(n), np.zeros(n)

        for i,t in enumerate(self.thresholds):
            idx = np.where(score>t)[0]
            pred_t = np.ascontiguousarray(pred[idx,:4])
            pred_map = drawfn(pred_t,height,width)

            matchE, matchG, _ = correspond(pred_map, gt_map,self.maxDist)
            cntR[i] = np.sum(matchG>0)
            sumR[i] = np.sum(gt_map>0)
            cntP[i] = np.sum(matchE>0)
            sumP[i] = np.sum(pred_map>0)

            matchE = np.array(matchE>0,dtype=np.float32)

            tp[i] = matchE.sum()
            fp[i] = np.sum(pred_map) - matchE.sum()
            gt[i] = gt_map.sum()
            #fp: sumP-cntR
            #gt: sumR
            #tp: cntP

            recalls[i] = cntR[i] / (sumR[i]+1e-15)
            precisions[i] = cntP[i] / (sumP[i]+1e-15)

        fscore = 2*recalls*precisions/(recalls+precisions+1e-6)

        return {'p':precisions, 'r':recalls, 'f':fscore,
                'tp': tp,
                'fp': fp,
                'gt': gt,
                'sumR': sumR,
                'cntR': cntR,
                'sumP': sumP,
                'cntP': cntP}

--------------------------------------------------------------------------------
/evaluation/runs/draw-hap.sh:
--------------------------------------------------------------------------------
python -m evaluation.draw-hap --path \
    precomputed-results/benchmark/DWP-wireframe.mat \
    precomputed-results/benchmark/LCNN-wireframe.mat \
    precomputed-results/benchmark/LETR-R101-wireframe-aph.mat \
    precomputed-results/benchmark/LETR-R50-wireframe-aph.mat \
    precomputed-results/benchmark/FClip-HG2-LB-wireframe-aph.mat \
    precomputed-results/benchmark/afmpp-wireframe-aph.mat \
    precomputed-results/benchmark/HAWPv1-wireframe.mat \
    outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test-aph.mat \
    --dest figures/APH-wireframe.pdf


python -m evaluation.draw-hap --path \
    precomputed-results/benchmark/DWP-york.mat \
    precomputed-results/benchmark/LCNN-york.mat \
    precomputed-results/benchmark/LETR-R101-york-aph.mat \
    precomputed-results/benchmark/LETR-R50-york-aph.mat \
    precomputed-results/benchmark/FClip-HG2-LB-york-aph.mat \
    precomputed-results/benchmark/afmpp-york-aph.mat \
    precomputed-results/benchmark/HAWPv1-york.mat \
    outputs/ihawp-train-rot-v2-full/220625-162909/york_test-aph.mat \
    --dest figures/APH-york.pdf
# precomputed-results/benchmark/DWP-wireframe.mat \

--------------------------------------------------------------------------------
/evaluation/runs/draw-sap.sh:
--------------------------------------------------------------------------------
# python -m evaluation.draw-sap --path \
#     outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json.sap \
#     precomputed-results/benchmark/HAWPv1-wireframe.json.sap \
#     precomputed-results/benchmark/FClip-HG2-LB-wireframe.json.sap \
#     precomputed-results/benchmark/LETR-R101-wireframe.json.sap \
#     precomputed-results/benchmark/LETR-R50-wireframe.json.sap \
#
#     precomputed-results/benchmark/LCNN-wireframe.json.sap \
#     --threshold=10 \
#     --dest figures/wireframe-sAP-10.pdf

python -m evaluation.draw-sap --path \
    outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json.sap \
    precomputed-results/benchmark/HAWPv1-wireframe.json.sap \
    precomputed-results/benchmark/FClip-HG2-LB-wireframe.json.sap \
    precomputed-results/benchmark/LETR-R101-wireframe.json.sap \
    precomputed-results/benchmark/LETR-R50-wireframe.json.sap \
    precomputed-results/benchmark/LCNN-wireframe.json.sap \
    --threshold=5 \
    --dest figures/wireframe-sAP-05.pdf
# outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json.sap \
#precomputed-results/benchmark/HAWPv1-wireframe.json.sap \

python -m evaluation.draw-sap --path \
    precomputed-results/benchmark/LETR-R101-york.json.sap \
    precomputed-results/benchmark/LETR-R50-york.json.sap \
    precomputed-results/benchmark/FClip-HG2-LB-york.json.sap \
    precomputed-results/benchmark/HAWPv1-york.json.sap \
    outputs/ihawp-train-rot-v2-full/220625-162909/york_test.json.sap \
    --threshold=5 \
    --dest figures/york-sAP-05.pdf
# precomputed-results/benchmark/LCNN-york.json.sap \

--------------------------------------------------------------------------------
/evaluation/runs/draw-vis-im-york.sh:
--------------------------------------------------------------------------------
# FNAME='00031811.png'
# FNAME='00110785.png'
FNAME='00255368.png'
python -m evaluation.draw-json \
    --pred precomputed-results/benchmark/LETR-R101-wireframe.json \
    --benchmark wireframe \
    --topk 50 \
    --fname $FNAME \
    --dest precomputed-results/vis-fsl-sel/LETR-R101

python -m evaluation.draw-json \
    --pred outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json \
    --benchmark wireframe \
    --topk 50 \
    --fname $FNAME \
    --dest precomputed-results/vis-fsl-sel/HAWPv2


python -m evaluation.draw-json \
    --pred precomputed-results/benchmark/afmpp-wireframe.json \
    --benchmark wireframe \
    --threshold=0.2 \
    --fname $FNAME \
    --cmp=l \
    --dest precomputed-results/vis-fsl-sel/afmpp

python -m evaluation.draw-json \
    --pred precomputed-results/benchmark/FClip-HG2-LB-wireframe.json \
    --benchmark wireframe \
    --fname $FNAME \
    --topk 50 \
    --dest precomputed-results/vis-fsl-sel/FClip


--------------------------------------------------------------------------------
/evaluation/runs/draw-vis-im.sh:
--------------------------------------------------------------------------------
FNAME='00031811.png'
FNAME='00110785.png'
# FNAME='00255368.png'
# FNAME='00031546.png'
# FNAME='00031811.png'
# FNAME='00034439.png'
FNAME='00053549.png'

python -m evaluation.draw-json \
    --pred precomputed-results/benchmark/HAWPv1-wireframe.json \
    --benchmark wireframe \
    --threshold 0.97 \
    --fname $FNAME \
    --dest precomputed-results/vis-fsl-sel/HAWPv1

python -m evaluation.draw-json \
    --pred precomputed-results/benchmark/LETR-R101-wireframe.json \
    --benchmark wireframe \
    --topk 50 \
    --fname $FNAME \
    --dest precomputed-results/vis-fsl-sel/LETR-R101

python -m evaluation.draw-json \
    --pred outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json \
    --benchmark wireframe \
    --threshold 0.9 \
    --fname $FNAME \
    --dest precomputed-results/vis-fsl-sel/HAWPv2
precomputed-results/vis-fsl-sel/HAWPv2
29 | 
30 | 
31 | python -m evaluation.draw-json \
32 | --pred precomputed-results/benchmark/afmpp-wireframe.json \
33 | --benchmark wireframe \
34 | --threshold=0.2 \
35 | --fname $FNAME \
36 | --cmp=l \
37 | --dest precomputed-results/vis-fsl-sel/afmpp
38 | 
39 | python -m evaluation.draw-json \
40 | --pred precomputed-results/benchmark/FClip-HG2-LB-wireframe.json \
41 | --benchmark wireframe \
42 | --fname $FNAME \
43 | --topk 50 \
44 | --dest precomputed-results/vis-fsl-sel/FClip
45 | 
46 | 
--------------------------------------------------------------------------------
/evaluation/runs/draw-vis-york.sh:
--------------------------------------------------------------------------------
1 | python -m evaluation.draw-json \
2 | --pred precomputed-results/benchmark/LETR-R101-york.json \
3 | --benchmark york \
4 | --topk 50 \
5 | --dest precomputed-results/vis-fsl-york/LETR-R101
6 | 
7 | python -m evaluation.draw-json \
8 | --pred outputs/ihawp-train-rot-v2-full/220625-162909/york_test.json \
9 | --benchmark york \
10 | --threshold 0.9 \
11 | --dest precomputed-results/vis-fsl-york/HAWPv2
12 | 
13 | 
14 | python -m evaluation.draw-json \
15 | --pred precomputed-results/benchmark/HAWPv1-york.json \
16 | --benchmark york \
17 | --threshold 0.97 \
18 | --dest precomputed-results/vis-fsl-york/HAWPv1
19 | 
20 | python -m evaluation.draw-json \
21 | --pred precomputed-results/benchmark/afmpp-york.json \
22 | --benchmark york \
23 | --threshold=0.2 \
24 | --cmp=l \
25 | --dest precomputed-results/vis-fsl-york/afmpp
26 | 
27 | python -m evaluation.draw-json \
28 | --pred precomputed-results/benchmark/FClip-HG2-LB-york.json \
29 | --benchmark york \
30 | --topk 100 \
31 | --dest precomputed-results/vis-fsl-york/FClip
32 | 
33 | 
--------------------------------------------------------------------------------
/evaluation/runs/draw-vis.sh:
--------------------------------------------------------------------------------
1 | python -m evaluation.draw-json \
2 | --pred precomputed-results/benchmark/LETR-R101-wireframe.json \
3 | --benchmark wireframe \
4 | --topk 50 \
5 | --dest precomputed-results/vis-fsl/LETR-R101
6 | 
7 | python -m evaluation.draw-json \
8 | --pred outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json \
9 | --benchmark wireframe \
10 | --threshold 0.97 \
11 | --dest precomputed-results/vis-fsl/HAWPv2
12 | # --topk 50
13 | 
14 | python -m evaluation.draw-json \
15 | --pred precomputed-results/benchmark/HAWPv1-wireframe.json \
16 | --benchmark wireframe \
17 | --threshold 0.97 \
18 | --dest precomputed-results/vis-fsl/HAWPv1
19 | 
20 | 
21 | python -m evaluation.draw-json \
22 | --pred precomputed-results/benchmark/afmpp-wireframe.json \
23 | --benchmark wireframe \
24 | --threshold=0.2 \
25 | --cmp=l \
26 | --dest precomputed-results/vis-fsl/afmpp
27 | 
28 | python -m evaluation.draw-json \
29 | --pred precomputed-results/benchmark/FClip-HG2-LB-wireframe.json \
30 | --benchmark wireframe \
31 | --topk 50 \
32 | --dest precomputed-results/vis-fsl/FClip
33 | 
34 | 
--------------------------------------------------------------------------------
/evaluation/runs/eval-aph.sh:
--------------------------------------------------------------------------------
1 | python -m evaluation.eval-json --pred precomputed-results/benchmark/FClip-HG2-LB-wireframe.json --benchmark wireframe --label "F-Clip-HG2-LB" --nthreads=8 --thresholds 0.1 0.2 0.25 0.27 0.3 0.315 0.33 0.345 0.36 0.38 0.4 0.42 0.45 0.47 0.49 0.5 0.52 0.54 0.56 0.58
2 | python -m evaluation.eval-json --pred precomputed-results/benchmark/FClip-HG2-LB-york.json --benchmark york --label "F-Clip-HG2-LB" --nthreads=8 --thresholds 0.1 0.2 0.25 0.27 0.3 0.315 0.33 0.345 0.36 0.38 0.4 0.42 0.45 0.47 0.49 0.5 0.52 0.54 0.56 0.58
3 | 
--------------------------------------------------------------------------------
/evaluation/runs/sap.sh:
--------------------------------------------------------------------------------
1 | #LETR-R50
2 | python -m evaluation.eval-sap --pred precomputed-results/benchmark/LETR-R50-wireframe.json --benchmark wireframe --label LETR-R50
3 | python -m evaluation.eval-sap --pred precomputed-results/benchmark/LETR-R50-york.json --benchmark york --label LETR-R50
4 | #LETR-R101
5 | python -m evaluation.eval-sap --pred precomputed-results/benchmark/LETR-R101-wireframe.json --benchmark wireframe --label LETR-R101
6 | python -m evaluation.eval-sap --pred precomputed-results/benchmark/LETR-R101-york.json --benchmark york --label LETR-R101
7 | 
8 | 
9 | #HAWPv2
10 | python -m evaluation.eval-sap --pred outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json --benchmark wireframe --label HAWPv2
11 | python -m evaluation.eval-sap --pred outputs/ihawp-train-rot-v2-full/220625-162909/york_test.json --benchmark york --label HAWPv2
12 | 
--------------------------------------------------------------------------------
/evaluation/sAPEval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/sAPEval/__init__.py
--------------------------------------------------------------------------------
/evaluation/sAPEval/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io as sio
3 | import os.path as osp
4 | import glob
5 | 
6 | def ap(tp, fp):
7 |     recall = tp
8 |     precision = tp / np.maximum(tp + fp, 1e-9)
9 | 
10 |     recall = np.concatenate(([0.0], recall, [1.0]))
11 |     precision = np.concatenate(([0.0], precision, [0.0]))
12 | 
13 |     for i in range(precision.size - 1, 0, -1):
14 |         precision[i - 1] = max(precision[i - 1], precision[i])
15 |     i = np.where(recall[1:] != recall[:-1])[0]
16 |     return np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
17 | 
18 | def msTPFP(lines_pred, lines_gt, threshold = 5):
19 |     x1_pred = lines_pred[:,:2]
20 |     x2_pred = lines_pred[:,2:4]
21 |     x1_gt = lines_gt[:,:2]
22 |     x2_gt = lines_gt[:,2:]
23 |     diff1_1 = ((x1_pred[:,None]-x1_gt)**2).sum(-1)
24 |     diff1_2 = ((x1_pred[:,None]-x2_gt)**2).sum(-1)
25 | 
26 |     diff2_1 = ((x2_pred[:, None] - x1_gt) ** 2).sum(-1)
27 |     diff2_2 = ((x2_pred[:, None] - x2_gt) ** 2).sum(-1)
28 | 
29 |     diff = np.minimum(diff1_1+diff2_2, diff1_2+diff2_1)
30 | 
31 |     choice = np.argmin(diff,1)
32 |     dist = np.min(diff,1)
33 |     hit = np.zeros(len(lines_gt), bool)
34 |     tp = np.zeros(len(lines_pred), float)
35 |     fp = np.zeros(len(lines_pred), float)
36 |     for i in range(len(lines_pred)):
37 |         if dist[i] < threshold and not hit[choice[i]]:
38 |             hit[choice[i]] = True
39 |             tp[i] = 1
40 |         else:
41 |             fp[i] = 1
42 | 
43 |     return tp,fp
44 | 
45 | 
46 | if __name__ == '__main__':
47 | 
48 |     path = osp.join('outputs','afmbox_R50-FPN-AFM-512','wireframe','*.mat')
49 | 
50 |     files = glob.glob(path)
51 | 
52 |     tps, fps, scores = [],[],[]
53 |     n_gt = 0
54 | 
55 |     aps = []
56 |     for f in files:
57 |         mat = sio.loadmat(f)
58 | 
59 |         height = mat['height'].item()
60 |         width = mat['width'].item()
61 |         lines_pred = mat['pred']
62 |         lines_gt = mat['gt']
63 | 
64 |         lines_pred[:, 0] *= 128 / width
65 |         lines_pred[:, 2] *= 128 / width
66 |         lines_pred[:, 1] *= 128 / height
67 |         lines_pred[:, 3] *= 128 / height
68 | 
69 |         lines_gt[:, 0] *= 128 / width
70 |         lines_gt[:, 2] *= 128 / width
71 |         lines_gt[:, 1] *= 128 / height
72 |         lines_gt[:, 3] *= 128 / height
73 | 
74 |         pred_score = lines_pred[:,4]
75 | 
76 | 
77 |         n_gt += len(lines_gt)
78 | 
79 |         tp,fp = msTPFP(lines_pred,lines_gt,10)
80 |         # import pdb
81 |         # pdb.set_trace()
82 |         # tp = np.cumsum(tp)/len(lines_gt)
83 |         # fp = np.cumsum(fp)/len(lines_gt)
84 | 
85 |         aps += [ap(tp,fp)]
86 |         tps.append(tp)
87 |         fps.append(fp)
88 |         scores.append(pred_score)
89 | 
90 |     tps = np.concatenate(tps)
91 |     fps = np.concatenate(fps)
92 |     scores = np.concatenate(scores)
93 |     index = np.argsort(-scores)  # rank detections by descending confidence
94 |     tp = np.cumsum(tps[index])/n_gt
95 |     fp = np.cumsum(fps[index])/n_gt
96 |     import pdb
97 |     pdb.set_trace()
98 | 
99 | 
--------------------------------------------------------------------------------
/hawp/__init__.py:
--------------------------------------------------------------------------------
1 | # from .models import HAWP
2 | from . import base
3 | from . import fsl
4 | from . import ssl
--------------------------------------------------------------------------------
/hawp/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .csrc import _C
2 | from . import utils
3 | from .utils.comm import to_device
4 | from .utils.logger import setup_logger
5 | from .utils.metric_logger import MetricLogger
6 | from .utils.miscellaneous import save_config
7 | from .wireframe import WireframeGraph
8 | 
9 | __all__ = [
10 |     "_C",
11 |     "utils",
12 |     "to_device",
13 |     "setup_logger",
14 |     "MetricLogger",
15 |     "save_config",
16 |     "WireframeGraph",
17 | ]
--------------------------------------------------------------------------------
/hawp/base/csrc/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.utils.cpp_extension import load
2 | import glob
3 | import os.path as osp
4 | 
5 | __this__ = osp.dirname(__file__)
6 | 
7 | try:
8 |     _C = load(name='_C',sources=[
9 |         osp.join(__this__,'binding.cpp'),
10 |         osp.join(__this__,'linesegment.cu'),
11 |     ]
12 |     )
13 | except Exception:  # fall back gracefully if JIT compilation is unavailable
14 |     _C = None
15 | 
16 | __all__ = ["_C"]
17 | 
18 | #_C = load(name='base._C', sources=['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'])
19 | 
--------------------------------------------------------------------------------
/hawp/base/csrc/binding.cpp:
--------------------------------------------------------------------------------
1 | #include "linesegment.h"
2 | 
3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4 |     m.def("encodels", &encodels, "Encoding line segments to maps");
5 | }
--------------------------------------------------------------------------------
/hawp/base/csrc/linesegment.h:
--------------------------------------------------------------------------------
1 | // #pragma once
2 | #include <torch/extension.h>
3 | 
4 | std::tuple<at::Tensor, at::Tensor> lsencode_cuda(
5 |     const at::Tensor& lines,
6 |     const int input_height,
7 |     const int input_width,
8 |     const int height,
9 |     const int width,
10 |     const int num_lines);
11 | 
12 | std::tuple<at::Tensor, at::Tensor> encodels(
13 |     const at::Tensor& lines,
14 |     const int input_height,
15 |     const int input_width,
16 |     const int height,
17 |     const int width,
18 |     const int num_lines)
19 | {
20 |     return lsencode_cuda(lines,
21 |         input_height,
22 |         input_width,
23 |         height,
24 |         width,
25 |         num_lines);
26 | }
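// Usage sketch from the Python side: `encodels` is exposed through the JIT
// extension compiled in hawp/base/csrc/__init__.py. The argument list follows
// the declaration above; the (N, 4) x1,y1,x2,y2 tensor layout is an assumption.
//
//   import torch
//   from hawp.base import _C                          # None if the CUDA build failed
//   lines = torch.rand(16, 4, device='cuda') * 512    # hypothetical endpoints
//   maps = _C.encodels(lines, 512, 512, 128, 128, lines.size(0))
//   # `maps` is the Python view of the returned tuple of encoded map tensors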
-------------------------------------------------------------------------------- /hawp/base/show/__init__.py: -------------------------------------------------------------------------------- 1 | from .canvas import Canvas, image_canvas, canvas 2 | from .painters import HAWPainter 3 | from .cli import cli, configure -------------------------------------------------------------------------------- /hawp/base/show/cli.py: -------------------------------------------------------------------------------- 1 | # from hawp.config import defaults 2 | import logging 3 | 4 | from .canvas import Canvas 5 | from .painters import HAWPainter 6 | import matplotlib 7 | LOG = logging.getLogger(__name__) 8 | 9 | def cli(parser): 10 | group = parser.add_argument_group('show') 11 | 12 | assert not Canvas.show 13 | group.add_argument('--show', default=False,action='store_true', 14 | help='show every plot, i.e., call matplotlib show()') 15 | 16 | group.add_argument('--edge-threshold', default=None, type=float, 17 | help='show the wireframe edges whose confidences are greater than [edge_threshold]') 18 | group.add_argument('--out-ext', default='png', type=str, 19 | help='save the plot in specific format') 20 | def configure(args): 21 | Canvas.show = args.show 22 | Canvas.out_file_extension = args.out_ext 23 | if args.edge_threshold is not None: 24 | HAWPainter.confidence_threshold = args.edge_threshold -------------------------------------------------------------------------------- /hawp/base/show/painters.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | try: 8 | import matplotlib 9 | import matplotlib.animation 10 | import matplotlib.collections 11 | import matplotlib.patches 12 | except ImportError: 13 | matplotlib = None 14 | 15 | 16 | LOG = logging.getLogger(__name__) 17 | 18 | 19 | class HAWPainter: 20 | line_width = None 21 | marker_size = None 22 | confidence_threshold = 0.05 23 | 24 | def __init__(self): 25 | 26 | if self.line_width is None: 27 | self.line_width = 2 28 | 29 | if self.marker_size is None: 30 | self.marker_size = max(1, int(self.line_width * 0.5)) 31 | 32 | def draw_wireframe(self, ax, wireframe, *, 33 | edge_color = None, vertex_color = None): 34 | if wireframe is None: 35 | return 36 | 37 | if edge_color is None: 38 | edge_color = 'b' 39 | if vertex_color is None: 40 | vertex_color = 'c' 41 | 42 | line_segments = wireframe['lines_pred'][wireframe['lines_score']>self.confidence_threshold] 43 | 44 | if isinstance(line_segments, torch.Tensor): 45 | line_segments = line_segments.cpu().numpy() 46 | 47 | # line_segments = wireframe.line_segments(threshold=self.confidence_threshold) 48 | # line_segments = line_segments.cpu().numpy() 49 | ax.plot([line_segments[:,0],line_segments[:,2]],[line_segments[:,1],line_segments[:,3]],'-',color=edge_color) 50 | ax.plot(line_segments[:,0],line_segments[:,1],'.',color=vertex_color) 51 | ax.plot(line_segments[:,2],line_segments[:,3],'.', 52 | color=vertex_color) 53 | -------------------------------------------------------------------------------- /hawp/base/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /hawp/base/utils/imports.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
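# Usage sketch (hypothetical path): import_file below loads a module from an
# explicit location without touching sys.path, e.g.
#   cfg_module = import_file("user_config", "/tmp/config.py", make_importable=True)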
2 | import torch
3 | 
4 | import importlib
5 | import importlib.util
6 | import sys
7 | 
8 | def import_file(module_name, file_path, make_importable=False):
9 |     spec = importlib.util.spec_from_file_location(module_name, file_path)
10 |     module = importlib.util.module_from_spec(spec)
11 |     spec.loader.exec_module(module)
12 |     if make_importable:
13 |         sys.modules[module_name] = module
14 |     return module
15 | 
--------------------------------------------------------------------------------
/hawp/base/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import logging
3 | import os
4 | import sys
5 | from pythonjsonlogger import jsonlogger
6 | 
7 | 
8 | def setup_logger(name, save_dir, out_file='log.txt', json_format=False):
9 |     logger = logging.getLogger(name)
10 |     logger.setLevel(logging.DEBUG)
11 |     ch = logging.StreamHandler(stream=sys.stdout)
12 |     ch.setLevel(logging.DEBUG)
13 |     if json_format:
14 |         formatter = jsonlogger.JsonFormatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
15 |     else:
16 |         formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
17 |     ch.setFormatter(formatter)
18 |     logger.addHandler(ch)
19 | 
20 |     if save_dir:
21 |         fh = logging.FileHandler(os.path.join(save_dir, out_file))
22 |         fh.setLevel(logging.DEBUG)
23 |         fh.setFormatter(formatter)
24 |         logger.addHandler(fh)
25 | 
26 |     return logger
27 | 
--------------------------------------------------------------------------------
/hawp/base/utils/metric_evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def msTPFP(line_pred, line_gt, threshold):
4 |     line_pred = line_pred.reshape(-1, 2, 2)[:, :, ::-1]
5 |     line_gt = line_gt.reshape(-1, 2, 2)[:, :, ::-1]
6 |     diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1)
7 |     diff = np.minimum(
8 |         diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
9 |     )
10 | 
11 |     choice = np.argmin(diff, 1)
12 |     dist = np.min(diff, 1)
13 |     hit = np.zeros(len(line_gt), bool)
14 |     tp = np.zeros(len(line_pred), float)
15 |     fp = np.zeros(len(line_pred), float)
16 |     for i in range(len(line_pred)):
17 |         if dist[i] < threshold and not hit[choice[i]]:
18 |             hit[choice[i]] = True
19 |             tp[i] = 1
20 |         else:
21 |             fp[i] = 1
22 |     return tp, fp
23 | 
24 | 
25 | def TPFP(lines_dt, lines_gt, threshold):
26 |     lines_dt = lines_dt.reshape(-1,2,2)[:,:,::-1]
27 |     lines_gt = lines_gt.reshape(-1,2,2)[:,:,::-1]
28 |     diff = ((lines_dt[:, None, :, None] - lines_gt[:, None]) ** 2).sum(-1)
29 |     diff = np.minimum(
30 |         diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
31 |     )
32 | 
33 |     # diff1 = ((lines_dt[:, None, :2] - lines_gt[:, :2]) ** 2).sum(-1)
34 |     # diff2 = ((lines_dt[:, None, 2:] - lines_gt[:, 2:]) ** 2).sum(-1)
35 |     # diff3 = ((lines_dt[:, None, :2] - lines_gt[:, 2:]) ** 2).sum(-1)
36 |     # diff4 = ((lines_dt[:, None, 2:] - lines_gt[:, :2]) ** 2).sum(-1)
37 |     # import pdb
38 |     # pdb.set_trace()
39 |     # diff = np.minimum(diff1+diff2, diff3+diff4)
40 |     choice = np.argmin(diff,1)
41 |     dist = np.min(diff,1)
42 |     hit = np.zeros(len(lines_gt), bool)
43 |     tp = np.zeros(len(lines_dt), float)
44 |     fp = np.zeros(len(lines_dt), float)
45 | 
46 |     for i in range(lines_dt.shape[0]):
47 |         if dist[i] < threshold and not hit[choice[i]]:
48 |             hit[choice[i]] = True
49 |             tp[i] = 1
50 |         else:
51 |             fp[i] = 1
52 |     return tp, fp
53 | 
54 | def AP(tp, fp):
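    # Worked micro-example of the interpolated AP computed below, assuming the
    # caller passes tp as the cumulative recall and fp as the cumulative
    # false-positive count, both ordered by descending detection score:
    #   tp = [0.5, 1.0], fp = [0, 1]       ->  precision = [1.0, 0.5]
    #   padded:   recall = [0, 0.5, 1, 1],     precision = [0, 1.0, 0.5, 0]
    #   envelope:                              precision = [1.0, 1.0, 0.5, 0]
    #   AP = (0.5-0)*1.0 + (1.0-0.5)*0.5 = 0.75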
recall = tp 56 | precision = tp/np.maximum(tp+fp, 1e-9) 57 | 58 | recall = np.concatenate(([0.0], recall, [1.0])) 59 | precision = np.concatenate(([0.0], precision, [0.0])) 60 | 61 | 62 | 63 | for i in range(precision.size - 1, 0, -1): 64 | precision[i - 1] = max(precision[i - 1], precision[i]) 65 | i = np.where(recall[1:] != recall[:-1])[0] 66 | 67 | ap = np.sum((recall[i + 1] - recall[i]) * precision[i + 1]) 68 | 69 | return ap 70 | -------------------------------------------------------------------------------- /hawp/base/utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import defaultdict 3 | from collections import deque 4 | 5 | import torch 6 | 7 | 8 | class SmoothedValue(object): 9 | """Track a series of values and provide access to smoothed values over a 10 | window or the global series average. 11 | """ 12 | 13 | def __init__(self, window_size=20): 14 | self.deque = deque(maxlen=window_size) 15 | self.series = [] 16 | self.total = 0.0 17 | self.count = 0 18 | 19 | def update(self, value): 20 | self.deque.append(value) 21 | self.series.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def avg(self): 32 | d = torch.tensor(list(self.deque)) 33 | return d.mean().item() 34 | 35 | @property 36 | def global_avg(self): 37 | return self.total / self.count 38 | 39 | 40 | class MetricLogger(object): 41 | def __init__(self, delimiter="\t"): 42 | self.meters = defaultdict(SmoothedValue) 43 | self.delimiter = delimiter 44 | 45 | def update(self, **kwargs): 46 | for k, v in kwargs.items(): 47 | if isinstance(v, torch.Tensor): 48 | v = v.item() 49 | assert isinstance(v, (float, int)) 50 | self.meters[k].update(v) 51 | 52 | def __getattr__(self, attr): 53 | if attr in self.meters: 54 | return self.meters[attr] 55 | if attr in self.__dict__: 56 | return self.__dict__[attr] 57 | raise AttributeError("'{}' object has no attribute '{}'".format( 58 | type(self).__name__, attr)) 59 | 60 | def __str__(self): 61 | loss_str = [] 62 | keys = sorted(self.meters) 63 | # for name, meter in self.meters.items(): 64 | for name in keys: 65 | meter = self.meters[name] 66 | loss_str.append( 67 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) 68 | ) 69 | return self.delimiter.join(loss_str) 70 | 71 | def tensorborad(self, iteration, writter, phase='train'): 72 | for name, meter in self.meters.items(): 73 | if 'loss' in name: 74 | # writter.add_scalar('average/{}'.format(name), meter.avg, iteration) 75 | writter.add_scalar('{}/global/{}'.format(phase,name), meter.global_avg, iteration) 76 | # writter.add_scalar('median/{}'.format(name), meter.median, iteration) 77 | 78 | -------------------------------------------------------------------------------- /hawp/base/utils/miscellaneous.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import json 3 | import logging 4 | import os 5 | 6 | 7 | def mkdir(path): 8 | try: 9 | os.makedirs(path) 10 | except OSError as e: 11 | if e.errno != errno.EEXIST: 12 | raise 13 | 14 | 15 | def save_config(cfg, path): 16 | with open(path, 'w') as f: 17 | f.write(cfg.dump()) 18 | -------------------------------------------------------------------------------- /hawp/base/utils/model_zoo.py: 
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import os
3 | import sys
4 | 
5 | try:
6 |     from torch.hub import _download_url_to_file
7 |     from torch.hub import urlparse
8 |     from torch.hub import HASH_REGEX
9 | except ImportError:
10 |     from torch.utils.model_zoo import _download_url_to_file
11 |     from torch.utils.model_zoo import urlparse
12 |     from torch.utils.model_zoo import HASH_REGEX
13 | 
14 | from .comm import is_main_process,synchronize
15 | 
16 | 
17 | # very similar to https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py
18 | # but with a few improvements and modifications
19 | def cache_url(url, model_dir=None, progress=True):
20 |     r"""Loads the Torch serialized object at the given URL.
21 |     If the object is already present in `model_dir`, it's deserialized and
22 |     returned. The filename part of the URL should follow the naming convention
23 |     ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
24 |     digits of the SHA256 hash of the contents of the file. The hash is used to
25 |     ensure unique names and to verify the contents of the file.
26 |     The default value of `model_dir` is ``$TORCH_HOME/models`` where
27 |     ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
28 |     overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
29 |     Args:
30 |         url (string): URL of the object to download
31 |         model_dir (string, optional): directory in which to save the object
32 |         progress (bool, optional): whether or not to display a progress bar to stderr
33 |     Example:
34 |         >>> cached_file = maskrcnn_benchmark.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
35 |     """
36 |     if model_dir is None:
37 |         torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
38 |         model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
39 |     if not os.path.exists(model_dir):
40 |         os.makedirs(model_dir)
41 |     parts = urlparse(url)
42 |     filename = os.path.basename(parts.path)
43 |     if filename == "model_final.pkl":
44 |         # workaround as pre-trained Caffe2 models from Detectron have all the same filename
45 |         # so make the full path the filename by replacing / with _
46 |         filename = parts.path.replace("/", "_")
47 |     cached_file = os.path.join(model_dir, filename)
48 |     if not os.path.exists(cached_file) and is_main_process():
49 |         sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
50 |         hash_prefix = HASH_REGEX.search(filename)
51 |         if hash_prefix is not None:
52 |             hash_prefix = hash_prefix.group(1)
53 |             # workaround: Caffe2 models don't have a hash, but follow the R-50 convention,
54 |             # which matches the hash PyTorch uses. So we skip the hash matching
55 |             # if the hash_prefix is less than 6 characters
56 |             if len(hash_prefix) < 6:
57 |                 hash_prefix = None
58 |         _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
59 |     synchronize()
60 |     return cached_file
61 | 
--------------------------------------------------------------------------------
/hawp/base/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | 
3 | 
4 | def _register_generic(module_dict, module_name, module):
5 |     assert module_name not in module_dict
6 |     module_dict[module_name] = module
7 | 
8 | 
9 | class Registry(dict):
10 |     '''
11 |     A helper class for managing registering modules, it extends a dictionary
12 |     and provides a register function.
13 | 
14 |     E.g. creating a registry:
15 |         some_registry = Registry({"default": default_module})
16 | 
17 |     There are two ways of registering new modules:
18 |     1): normal way is just calling register function:
19 |         def foo():
20 |             ...
21 |         some_registry.register("foo_module", foo)
22 |     2): used as decorator when declaring the module:
23 |         @some_registry.register("foo_module")
24 |         @some_registry.register("foo_module_nickname")
25 |         def foo():
26 |             ...
27 | 
28 |     Access of module is just like using a dictionary, e.g.:
29 |         f = some_registry["foo_module"]
30 |     '''
31 |     def __init__(self, *args, **kwargs):
32 |         super(Registry, self).__init__(*args, **kwargs)
33 | 
34 |     def register(self, module_name, module=None):
35 |         # used as function call
36 |         if module is not None:
37 |             _register_generic(self, module_name, module)
38 |             return
39 | 
40 |         # used as decorator
41 |         def register_fn(fn):
42 |             _register_generic(self, module_name, fn)
43 |             return fn
44 | 
45 |         return register_fn
46 | 
--------------------------------------------------------------------------------
/hawp/base/wireframe.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import torch
5 | import json
6 | 
7 | class WireframeGraph:
8 |     def __init__(self,
9 |                  vertices: torch.Tensor,
10 |                  v_confidences: torch.Tensor,
11 |                  edges: torch.Tensor,
12 |                  edge_weights: torch.Tensor,
13 |                  frame_width: int,
14 |                  frame_height: int):
15 |         self.vertices = vertices
16 |         self.v_confidences = v_confidences
17 |         self.edges = edges
18 |         self.weights = edge_weights
19 |         self.frame_width = frame_width
20 |         self.frame_height = frame_height
21 | 
22 |     @classmethod
23 |     def xyxy2indices(cls,junctions, lines):
24 |         # junctions: (N,2)
25 |         # lines: (M,4)
26 |         # return: (M,2)
27 |         dist1 = torch.norm(junctions[None,:,:]-lines[:,None,:2],dim=-1)
28 |         dist2 = torch.norm(junctions[None,:,:]-lines[:,None,2:],dim=-1)
29 |         idx1 = torch.argmin(dist1,dim=-1)
30 |         idx2 = torch.argmin(dist2,dim=-1)
31 |         return torch.stack((idx1,idx2),dim=-1)
32 |     @classmethod
33 |     def load_json(cls, fname):
34 |         with open(fname,'r') as f:
35 |             data = json.load(f)
36 | 
37 | 
38 |         vertices = torch.tensor(data['vertices'])
39 |         v_confidences = torch.tensor(data['vertices-score'])
40 |         edges = torch.tensor(data['edges'])
41 |         edge_weights = torch.tensor(data['edges-weights'])
42 |         height = data['height']
43 |         width = data['width']
44 | 
45 |         return WireframeGraph(vertices,v_confidences,edges,edge_weights,width,height)
46 | 
47 |     @property
48 |     def is_empty(self):
49 |         for key, val in self.__dict__.items():
50 |             if val is None:
51 |                 return True
52 |         return False
53 | 
54 |     @property
55 |     def num_vertices(self):
56 |         if self.is_empty:
57 |             return 0
58 |         return self.vertices.shape[0]
59 | 
60 |     @property
61 |     def num_edges(self):
62 |         if self.is_empty:
63 |             return 0
64 |         return self.edges.shape[0]
65 | 
66 | 
67 |     def line_segments(self, threshold = 0.05, device=None, to_np=False):
68 |         is_valid = self.weights>threshold
69 |         p1 = self.vertices[self.edges[is_valid,0]]
70 |         p2 = self.vertices[self.edges[is_valid,1]]
71 |         ps = self.weights[is_valid]
72 | 
73 |         lines = torch.cat((p1,p2,ps[:,None]),dim=-1)
74 |         if device is
not None: 75 | lines = lines.to(device) 76 | if to_np: 77 | lines = lines.cpu().numpy() 78 | 79 | return lines 80 | # if device != self.device: 81 | 82 | def rescale(self, image_width, image_height): 83 | scale_x = float(image_width)/float(self.frame_width) 84 | scale_y = float(image_height)/float(self.frame_height) 85 | 86 | self.vertices[:,0] *= scale_x 87 | self.vertices[:,1] *= scale_y 88 | self.frame_width = image_width 89 | self.frame_height = image_height 90 | 91 | def jsonize(self): 92 | return { 93 | 'vertices': self.vertices.cpu().tolist(), 94 | 'vertices-score': self.v_confidences.cpu().tolist(), 95 | 'edges': self.edges.cpu().tolist(), 96 | 'edges-weights': self.weights.cpu().tolist(), 97 | 'height': self.frame_height, 98 | 'width': self.frame_width, 99 | } 100 | def __repr__(self) -> str: 101 | return "WireframeGraph\n"+\ 102 | "Vertices: {}\n".format(self.num_vertices)+\ 103 | "Edges: {}\n".format(self.num_edges,) + \ 104 | "Frame size (HxW): {}x{}".format(self.frame_height,self.frame_width) 105 | 106 | #graph = WireframeGraph() 107 | if __name__ == "__main__": 108 | graph = WireframeGraph.load_json('NeuS/public_data/bmvs_clock/hawp/000.json') 109 | print(graph) 110 | -------------------------------------------------------------------------------- /hawp/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .hafm import HAFMencoder -------------------------------------------------------------------------------- /hawp/fsl/__init__.py: -------------------------------------------------------------------------------- 1 | from . import config 2 | from . import backbones 3 | from . import dataset 4 | from . import model -------------------------------------------------------------------------------- /hawp/fsl/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_backbone -------------------------------------------------------------------------------- /hawp/fsl/backbones/build.py: -------------------------------------------------------------------------------- 1 | from .registry import MODELS 2 | from .stacked_hg import HourglassNet, Bottleneck2D 3 | from .multi_task_head import MultitaskHead 4 | from .resnets import ResNets 5 | from .point_line import PointLineNet 6 | # from .point_line_new import PointLineNet 7 | from .stacked_point_line import StackPointLine 8 | 9 | @MODELS.register("Hourglass") 10 | def build_hg(cfg, **kwargs): 11 | inplanes = cfg.MODEL.HGNETS.INPLANES 12 | num_feats = cfg.MODEL.OUT_FEATURE_CHANNELS//2 13 | depth = cfg.MODEL.HGNETS.DEPTH 14 | num_stacks = cfg.MODEL.HGNETS.NUM_STACKS 15 | num_blocks = cfg.MODEL.HGNETS.NUM_BLOCKS 16 | head_size = cfg.MODEL.HEAD_SIZE 17 | 18 | out_feature_channels = cfg.MODEL.OUT_FEATURE_CHANNELS 19 | 20 | if kwargs.get('gray_scale',False): 21 | input_channels = 1 22 | else: 23 | input_channels = 3 24 | num_class = sum(sum(head_size, [])) 25 | model = HourglassNet( 26 | input_channels=input_channels, 27 | block=Bottleneck2D, 28 | inplanes = inplanes, 29 | num_feats= num_feats, 30 | depth=depth, 31 | head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size), 32 | num_stacks = num_stacks, 33 | num_blocks = num_blocks, 34 | num_classes = num_class) 35 | 36 | model.out_feature_channels = out_feature_channels 37 | 38 | return model 39 | 40 | 41 | 42 | @MODELS.register("PointLine") 43 | def build_hg(cfg, **kwargs): 44 | inplanes = cfg.MODEL.HGNETS.INPLANES 45 | num_feats = 
cfg.MODEL.OUT_FEATURE_CHANNELS//2
46 |     depth = cfg.MODEL.HGNETS.DEPTH
47 |     num_stacks = cfg.MODEL.HGNETS.NUM_STACKS
48 |     num_blocks = cfg.MODEL.HGNETS.NUM_BLOCKS
49 |     head_size = cfg.MODEL.HEAD_SIZE
50 | 
51 |     out_feature_channels = cfg.MODEL.OUT_FEATURE_CHANNELS
52 | 
53 |     if kwargs.get('gray_scale',False):
54 |         input_channels = 1
55 |     else:
56 |         input_channels = 3
57 |     num_class = sum(sum(head_size, []))
58 |     model = PointLineNet(
59 |         head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size))
60 | 
61 |     model.out_feature_channels = out_feature_channels
62 | 
63 |     return model
64 | 
65 | 
66 | 
67 | @MODELS.register("StackPointLine")
68 | def build_hg(cfg, **kwargs):
69 |     inplanes = cfg.MODEL.HGNETS.INPLANES
70 |     num_feats = cfg.MODEL.OUT_FEATURE_CHANNELS//2
71 |     depth = cfg.MODEL.HGNETS.DEPTH
72 |     num_stacks = cfg.MODEL.HGNETS.NUM_STACKS
73 |     num_blocks = cfg.MODEL.HGNETS.NUM_BLOCKS
74 |     head_size = cfg.MODEL.HEAD_SIZE
75 | 
76 |     out_feature_channels = cfg.MODEL.OUT_FEATURE_CHANNELS
77 | 
78 |     if kwargs.get('gray_scale',False):
79 |         input_channels = 1
80 |     else:
81 |         input_channels = 3
82 |     num_class = sum(sum(head_size, []))
83 |     model = StackPointLine(
84 |         input_channels=input_channels,
85 |         block=Bottleneck2D,
86 |         inplanes = inplanes,
87 |         num_feats= num_feats,
88 |         depth=depth,
89 |         head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),
90 |         num_stacks = num_stacks,
91 |         num_blocks = num_blocks,
92 |         num_classes = num_class)
93 | 
94 |     model.out_feature_channels = out_feature_channels
95 | 
96 |     return model
97 | 
98 | 
99 | # @MODELS.register("ResNets")
100 | # def build_resnet(cfg):
101 | #     head_size = cfg.MODEL.HEAD_SIZE
102 | 
103 | #     num_class = sum(sum(head_size,[]))
104 | #     model = ResNets(cfg.MODEL.RESNETS.BASENET,head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),num_class=num_class,pretrain=cfg.MODEL.RESNETS.PRETRAIN)
105 | 
106 | #     model.out_feature_channels = 128
107 | #     return model
108 | 
109 | def build_backbone(cfg, **kwargs):
110 |     assert cfg.MODEL.NAME in MODELS, \
111 |         "cfg.MODEL.NAME: {} is not registered in registry".format(cfg.MODEL.NAME)
112 | 
113 |     return MODELS[cfg.MODEL.NAME](cfg, **kwargs)
114 | 
--------------------------------------------------------------------------------
/hawp/fsl/backbones/multi_task_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | class MultitaskHead(nn.Module):
4 |     def __init__(self, input_channels, num_class, head_size):
5 |         super(MultitaskHead, self).__init__()
6 | 
7 |         m = int(input_channels / 4)
8 |         heads = []
9 |         for output_channels in sum(head_size, []):
10 |             heads.append(
11 |                 nn.Sequential(
12 |                     nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
13 |                     nn.ReLU(inplace=True),
14 |                     nn.Conv2d(m, output_channels, kernel_size=1),
15 |                 )
16 |             )
17 |         self.heads = nn.ModuleList(heads)
18 |         assert num_class == sum(sum(head_size, []))
19 | 
20 |     def forward(self, x):
21 |         return torch.cat([head(x) for head in self.heads], dim=1)
22 | 
23 | 
24 | class AngleDistanceHead(nn.Module):
25 |     def __init__(self, input_channels, num_class, head_size):
26 |         super(AngleDistanceHead, self).__init__()
27 | 
28 |         m = int(input_channels/4)
29 | 
30 |         heads = []
31 |         for output_channels in sum(head_size, []):
32 |             if output_channels != 2:
33 |                 heads.append(
34 |                     nn.Sequential(
35 |                         nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
36 |                         nn.ReLU(inplace=True),
37 |                         nn.Conv2d(m, output_channels, kernel_size=1),
38 |                     )
39 | ) 40 | else: 41 | heads.append( 42 | nn.Sequential( 43 | nn.Conv2d(input_channels, m, kernel_size=3, padding=1), 44 | nn.ReLU(inplace=True), 45 | CosineSineLayer(m) 46 | ) 47 | ) 48 | self.heads = nn.ModuleList(heads) 49 | assert num_class == sum(sum(head_size, [])) 50 | def forward(self, x): 51 | return torch.cat([head(x) for head in self.heads], dim=1) -------------------------------------------------------------------------------- /hawp/fsl/backbones/registry.py: -------------------------------------------------------------------------------- 1 | from hawp.base.utils.registry import Registry 2 | 3 | MODELS = Registry() -------------------------------------------------------------------------------- /hawp/fsl/backbones/resnets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | class ResNets(nn.Module): 6 | RESNET_TEMPLATES = { 7 | 'resnet18':torchvision.models.resnet18, 8 | 'resnet34':torchvision.models.resnet34, 9 | 'resnet50':torchvision.models.resnet50, 10 | 'resnet101':torchvision.models.resnet101, 11 | } 12 | def __init__(self,basenet, head, num_class, pretrain=True,): 13 | super(ResNets, self).__init__() 14 | assert basenet in ResNets.RESNET_TEMPLATES 15 | 16 | basenet_fn = ResNets.RESNET_TEMPLATES.get(basenet) 17 | 18 | model = basenet_fn(pretrain) 19 | 20 | self.conv1 = nn.Conv2d(3,64,7,2,3) 21 | self.bn1 = nn.BatchNorm2d(64) 22 | self.relu = nn.ReLU(True) 23 | self.maxpool = nn.MaxPool2d(3,2,1) 24 | 25 | self.layer1 = model.layer1 26 | self.layer2 = model.layer2 27 | self.layer3 = model.layer3 28 | self.layer4 = model.layer4 29 | 30 | self.pixel_shuffle = nn.PixelShuffle(4) 31 | self.hafm_predictor = head(128,num_class) 32 | # self.hafm_predictor = nn.Sequential(nn.Conv2d(2048,512,3,1,1),nn.ReLU(True),nn.Conv2d(512,5,1,1,0)) 33 | def forward(self, images): 34 | x = self.conv1(images) 35 | x = self.relu(self.bn1(x)) 36 | x = self.maxpool(x) 37 | 38 | x = self.layer1(x) 39 | x = self.layer2(x) 40 | x = self.layer3(x) 41 | x = self.layer4(x) 42 | 43 | x = self.pixel_shuffle(x) 44 | 45 | return [self.hafm_predictor(x)], x 46 | 47 | if __name__ == "__main__": 48 | model = ResNets('resnet50') 49 | 50 | inp = torch.zeros((1,3,512,512)) 51 | 52 | model(inp) 53 | 54 | 55 | -------------------------------------------------------------------------------- /hawp/fsl/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import cfg 2 | 3 | __all__ = ['cfg'] -------------------------------------------------------------------------------- /hawp/fsl/config/dataset.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | # ---------------------------------------------------------------------------- # 3 | # Dataset options 4 | # ---------------------------------------------------------------------------- # 5 | DATASETS = CN() 6 | DATASETS.TRAIN = ("wireframe_train",) 7 | DATASETS.VAL = ("wireframe_test",) 8 | DATASETS.TEST = ("wireframe_test",) 9 | DATASETS.IMAGE = CN() 10 | DATASETS.IMAGE.HEIGHT = 512 11 | DATASETS.IMAGE.WIDTH = 512 12 | 13 | DATASETS.IMAGE.PIXEL_MEAN = [109.730, 103.832, 98.681] 14 | DATASETS.IMAGE.PIXEL_STD = [22.275, 22.124, 23.229] 15 | DATASETS.IMAGE.TO_255 = True 16 | DATASETS.TARGET = CN() 17 | DATASETS.TARGET.HEIGHT= 128 18 | DATASETS.TARGET.WIDTH = 128 19 | DATASETS.AUGMENTATION = 4 20 | DATASETS.DISTANCE_TH = 0.02 21 | 
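# Per-image static line-sampling budgets below; the same 300/40 split appears
# as ENCODER.NUM_STATIC_POS_LINES / NUM_STATIC_NEG_LINES in defaults.py and as
# N_STC_POSL / N_STC_NEGL in the parsing-head config.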
DATASETS.NUM_STATIC_POSITIVE_LINES = 300 22 | DATASETS.NUM_STATIC_NEGATIVE_LINES = 40 23 | 24 | # 25 | -------------------------------------------------------------------------------- /hawp/fsl/config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | from .models import MODELS 3 | from .dataset import DATASETS 4 | from .solver import SOLVER 5 | from .detr import DETR 6 | cfg = CN() 7 | 8 | cfg.ENCODER = CN() 9 | cfg.ENCODER.DIS_TH = 5 10 | cfg.ENCODER.ANG_TH = 0.1 11 | cfg.ENCODER.NUM_STATIC_POS_LINES = 300 12 | cfg.ENCODER.NUM_STATIC_NEG_LINES = 40 13 | cfg.ENCODER.BACKGROUND_WEIGHT = 0.0 14 | cfg.MODELING_PATH = 'hawp' 15 | cfg.MODEL = MODELS 16 | cfg.DATASETS = DATASETS 17 | cfg.SOLVER = SOLVER 18 | 19 | cfg.DATALOADER = CN() 20 | cfg.DATALOADER.NUM_WORKERS = 8 21 | cfg.OUTPUT_DIR = "outputs/dev" 22 | -------------------------------------------------------------------------------- /hawp/fsl/config/detr.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | 4 | DETR = CN() 5 | 6 | DETR.backbone = 'resnet50' 7 | DETR.dilation = False #dilated conv, DC5 for DETR 8 | DETR.position_embedding = 'sine' 9 | DETR.lr_backbone = 1e-5 10 | DETR.enc_layers = 6 11 | DETR.dec_layers = 6 12 | DETR.dim_feedforward = 2048 13 | DETR.hidden_dim = 256 14 | DETR.dropout = 0.1 15 | DETR.nheads = 8 16 | DETR.num_queries = 1000 17 | DETR.pre_norm = False 18 | DETR.eos_coef = 0.1 #"Relative classification weight of the no-object class" 19 | 20 | DETR.no_aux_loss = False 21 | DETR.set_cost_class = 1.0 22 | DETR.set_cost_lines = 5.0 23 | 24 | 25 | -------------------------------------------------------------------------------- /hawp/fsl/config/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import MODELS 2 | 3 | __all__ = ['MODELS'] -------------------------------------------------------------------------------- /hawp/fsl/config/models/head.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | PARSING_HEAD = CN() 4 | 5 | PARSING_HEAD.MAX_DISTANCE = 5.0 6 | 7 | PARSING_HEAD.N_STC_POSL = 300 8 | PARSING_HEAD.N_STC_NEGL = 40 9 | 10 | PARSING_HEAD.MATCHING_STRATEGY = 'junction' #junction or line_adjusted 11 | PARSING_HEAD.N_DYN_JUNC = 300 12 | PARSING_HEAD.N_DYN_POSL = 300 13 | PARSING_HEAD.N_DYN_NEGL = 300 14 | PARSING_HEAD.N_DYN_OTHR = 0 15 | PARSING_HEAD.N_DYN_OTHR2 = 300 16 | 17 | PARSING_HEAD.N_PTS0 = 32 18 | PARSING_HEAD.N_PTS1 = 8 19 | 20 | PARSING_HEAD.DIM_LOI = 128 21 | PARSING_HEAD.DIM_FC = 1024 22 | PARSING_HEAD.USE_RESIDUAL = 1 23 | PARSING_HEAD.N_OUT_JUNC = 250 24 | PARSING_HEAD.N_OUT_LINE = 2500 25 | PARSING_HEAD.JMATCH_THRESHOLD = 1.5 26 | PARSING_HEAD.J2L_THRESHOLD = 1000.0 27 | PARSING_HEAD.JUNCTION_HM_THRESHOLD = 0.008 28 | #INFERENCE FLAGS 29 | #0, only use junctions to yield line segments 30 | #1, only use learned angles to yield line segments 31 | #2, match line segment proposals with junctions 32 | # PARSING_HEAD.INFERENCE = 0 -------------------------------------------------------------------------------- /hawp/fsl/config/models/models.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | from .shg import HGNETS 3 | from .resnets import RESNETS 4 | from .head import PARSING_HEAD 5 | 6 | MODELS = CN() 7 | 
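# NAME picks the backbone factory out of hawp.fsl.backbones.registry.MODELS;
# the choices registered in build.py are "Hourglass", "PointLine" and
# "StackPointLine".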
8 | MODELS.NAME = "Hourglass" 9 | MODELS.HGNETS = HGNETS 10 | MODELS.RESNETS = RESNETS 11 | MODELS.DEVICE = "cuda" 12 | MODELS.WEIGHTS = "" 13 | MODELS.HEAD_SIZE = [[3], [1], [1], [2], [2]] 14 | MODELS.OUT_FEATURE_CHANNELS = 256 15 | 16 | MODELS.LOSS_WEIGHTS = CN(new_allowed=True) 17 | MODELS.PARSING_HEAD = PARSING_HEAD 18 | MODELS.SCALE = 1.0 19 | 20 | MODELS.USE_LINE_HEATMAP = False 21 | MODELS.USE_HR_JUNCTION = False 22 | 23 | MODELS.LOI_POOLING = CN() 24 | MODELS.LOI_POOLING.USE_INIT_LINES = True 25 | MODELS.LOI_POOLING.NUM_POINTS = 32 26 | MODELS.LOI_POOLING.DIM_EDGE_FEATURE = 16 27 | MODELS.LOI_POOLING.DIM_JUNCTION_FEATURE = 128 28 | MODELS.LOI_POOLING.DIM_FC = 1024 29 | MODELS.LOI_POOLING.TYPE = 'softmax' 30 | MODELS.LOI_POOLING.LAYER_NORM = False 31 | MODELS.LOI_POOLING.ACTIVATION = 'relu' 32 | 33 | MODELS.FOCAL_LOSS = CN() 34 | MODELS.FOCAL_LOSS.ALPHA = -1.0 35 | MODELS.FOCAL_LOSS.GAMMA = 0.0 -------------------------------------------------------------------------------- /hawp/fsl/config/models/proposal_head.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | PROPOSAL_HEAD = CN() 4 | 5 | PROPOSAL_HEAD.NUM_PROPOSALS = 512 6 | PROPOSAL_HEAD.ANGULAR_THRESHOLD = 0.01 #TODO: REMOVE 7 | PROPOSAL_HEAD.NUM_SAMPLE_POINTS = 32 8 | PROPOSAL_HEAD.POSITIVE_DIS_TH = 10.0 9 | PROPOSAL_HEAD.NEGATIVE_DIS_TH = 15.0 10 | PROPOSAL_HEAD.LOWEST_SCORE_TH = 0.05 11 | PROPOSAL_HEAD.POST_NMS_TH = 10.0 12 | PROPOSAL_HEAD.USE_1DPOOLING = False 13 | PROPOSAL_HEAD.MIN_DISTANCE = 0.0 14 | PROPOSAL_HEAD.MAX_DISTANCE = 1.5 15 | PROPOSAL_HEAD.NUM_DIS_PROPOSAL = 9 16 | PROPOSAL_HEAD.USE_EDGE = False 17 | PROPOSAL_HEAD.SHARE_WITH_JUNCTION_FEATURE = False 18 | 19 | 20 | PROPOSAL_HEAD.NUM_DYNAMIC_JUNCTIONS = 300 21 | PROPOSAL_HEAD.NUM_DYNAMIC_POSITIVE_LINES = 300 22 | PROPOSAL_HEAD.NUM_DYNAMIC_NEGATIVE_LINES = 80 23 | PROPOSAL_HEAD.NUM_DYNAMIC_OTHER_LINES = 600 24 | 25 | # PROPOSAL_HEAD.LOSS_WEIGHTS = CN() 26 | # PROPOSAL_HEAD.LOSS_WEIGHTS.AFM = 1.0 27 | # PROPOSAL_HEAD.LOSS_WEIGHTS.PROPOSAL = 1.0 28 | -------------------------------------------------------------------------------- /hawp/fsl/config/models/resnets.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | RESNETS = CN() 4 | RESNETS.BASENET = 'resnet50' 5 | RESNETS.PRETRAIN = True 6 | 7 | # RESNETS.NUM_GROUPS = 1 8 | # RESNETS.WIDTH_PER_GROUP = 64 9 | # #RESNETS.STRIDE_IN_1X1 = True 10 | 11 | # # RESNETS.TRANS_FUNC = "BottleneckWithFixedBatchNorm" 12 | # RESNETS.NORM_FUNC = 'BatchNorm' 13 | # RESNETS.STEM_FUNC = "Stem" 14 | # RESNETS.STEM_OUT_CHANNELS = 64 15 | # RESNETS.STEM_KERNEL_SIZE = 7 16 | # RESNETS.STEM_RETURN_FEATURE = False 17 | # RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4 18 | # RESNETS.RES2_OUT_CHANNELS = 256 19 | 20 | 21 | # RESNETS.STAGE_SPECS = (3,4,6,3) # ResNet 50 22 | 23 | -------------------------------------------------------------------------------- /hawp/fsl/config/models/shg.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | HGNETS = CN() 4 | 5 | HGNETS.DEPTH = 4 6 | HGNETS.NUM_STACKS = 2 7 | HGNETS.NUM_BLOCKS = 1 8 | 9 | HGNETS.INPLANES = 64 10 | HGNETS.NUM_FEATS = 128 11 | -------------------------------------------------------------------------------- /hawp/fsl/config/paths_catalog.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
os.path as osp 3 | 4 | class DatasetCatalog(object): 5 | 6 | DATA_DIR = osp.abspath(osp.join(osp.dirname(__file__), 7 | '..','..','..','data')) 8 | 9 | DATASETS = { 10 | 'wireframe_train': { 11 | 'img_dir': 'wireframe/images', 12 | 'ann_file': 'wireframe/train.json', 13 | }, 14 | 'wireframe_train-pseudo': { 15 | 'img_dir': 'wireframe-pseudo/images', 16 | 'ann_file': 'wireframe-pseudo/train.json', 17 | }, 18 | 'wireframe_train-syn-export': { 19 | 'img_dir': 'wireframe-syn-export/images', 20 | 'ann_file': 'wireframe-syn-export/train.json', 21 | }, 22 | 'wireframe_train-syn-export-1': { 23 | 'img_dir': 'wireframe-syn-export-ep30-iter100-th075/images', 24 | 'ann_file': 'wireframe-syn-export-ep30-iter100-th075/train.json', 25 | }, 26 | 'wireframe_test1': { 27 | 'img_dir': 'wireframe/images', 28 | 'ann_file': 'wireframe/overfit.json', 29 | }, 30 | 'synthetic_train': { 31 | 'img_dir': 'synthetic-shapes/images', 32 | 'ann_file': 'synthetic-shapes/train.json', 33 | }, 34 | 'synthetic_test': { 35 | 'img_dir': 'synthetic-shapes/images', 36 | 'ann_file': 'synthetic-shapes/test.json', 37 | }, 38 | 'cities_train': { 39 | 'img_dir': 'cities/images', 40 | 'ann_file': 'cities/train.json', 41 | }, 42 | 'cities_test': { 43 | 'img_dir': 'cities/images', 44 | 'ann_file': 'cities/test.json', 45 | }, 46 | 'wireframe_test': { 47 | 'img_dir': 'wireframe/images', 48 | 'ann_file': 'wireframe/test.json', 49 | }, 50 | 'york_test': { 51 | 'img_dir': 'york/images', 52 | 'ann_file': 'york/test.json', 53 | }, 54 | 'coco_train-val2017': { 55 | 'img_dir': 'coco/val2017', 56 | 'ann_file': 'coco/coco-wf-val.json', 57 | }, 58 | 'coco_test-val2017': { 59 | 'img_dir': 'coco/val2017', 60 | 'ann_file': 'coco/coco-wf-val.json', 61 | } 62 | } 63 | 64 | @staticmethod 65 | def get(name): 66 | assert name in DatasetCatalog.DATASETS 67 | data_dir = DatasetCatalog.DATA_DIR 68 | attrs = DatasetCatalog.DATASETS[name] 69 | 70 | args = dict( 71 | root = osp.join(data_dir,attrs['img_dir']), 72 | ann_file = osp.join(data_dir,attrs['ann_file']) 73 | ) 74 | 75 | if 'train' in name: 76 | return dict(factory="TrainDataset",args=args) 77 | if 'test' in name and 'ann_file' in attrs: 78 | return dict(factory="TestDatasetWithAnnotations", 79 | args=args) 80 | raise NotImplementedError() -------------------------------------------------------------------------------- /hawp/fsl/config/solver.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | 4 | SOLVER = CN() 5 | SOLVER.IMS_PER_BATCH = 6 6 | SOLVER.MAX_EPOCH = 30 7 | SOLVER.OPTIMIZER = "ADAM" 8 | SOLVER.BASE_LR = 0.01 9 | SOLVER.BACKBONE_LR_FACTOR=1.0 10 | SOLVER.BIAS_LR_FACTOR = 1 11 | 12 | SOLVER.MOMENTUM = 0.9 13 | SOLVER.WEIGHT_DECAY = 0.0002 14 | SOLVER.WEIGHT_DECAY_BIAS = 0 15 | SOLVER.GAMMA = 0.1 16 | 17 | SOLVER.STEPS = (25,) 18 | SOLVER.CHECKPOINT_PERIOD = 1 19 | SOLVER.AMSGRAD = False 20 | -------------------------------------------------------------------------------- /hawp/fsl/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_dataset import TrainDataset 2 | from . import transforms 3 | from .build import build_train_dataset, build_test_dataset 4 | from .test_dataset import TestDatasetWithAnnotations -------------------------------------------------------------------------------- /hawp/fsl/dataset/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .transforms import * 3 | from . 
import train_dataset 4 | from ..config.paths_catalog import DatasetCatalog 5 | from . import test_dataset 6 | 7 | def build_transform(cfg): 8 | transforms = Compose( 9 | [ResizeImage(cfg.DATASETS.IMAGE.HEIGHT, 10 | cfg.DATASETS.IMAGE.WIDTH), 11 | ToTensor(), 12 | Normalize(cfg.DATASETS.IMAGE.PIXEL_MEAN, 13 | cfg.DATASETS.IMAGE.PIXEL_STD, 14 | cfg.DATASETS.IMAGE.TO_255) 15 | ] 16 | ) 17 | 18 | if cfg.MODEL.NAME == "PointLine": 19 | transforms = Compose( 20 | [ResizeImage(cfg.DATASETS.IMAGE.HEIGHT, 21 | cfg.DATASETS.IMAGE.WIDTH), 22 | ToTensor() 23 | ] 24 | ) 25 | 26 | return transforms 27 | def build_train_dataset(cfg): 28 | assert len(cfg.DATASETS.TRAIN) == 1 29 | name = cfg.DATASETS.TRAIN[0] 30 | dargs = DatasetCatalog.get(name) 31 | 32 | factory = getattr(train_dataset,dargs['factory']) 33 | args = dargs['args'] 34 | args['augmentation'] = cfg.DATASETS.AUGMENTATION 35 | args['transform'] = Compose( 36 | [Resize(cfg.DATASETS.IMAGE.HEIGHT, 37 | cfg.DATASETS.IMAGE.WIDTH, 38 | cfg.DATASETS.TARGET.HEIGHT, 39 | cfg.DATASETS.TARGET.WIDTH), 40 | ToTensor(), 41 | Normalize(cfg.DATASETS.IMAGE.PIXEL_MEAN, 42 | cfg.DATASETS.IMAGE.PIXEL_STD, 43 | cfg.DATASETS.IMAGE.TO_255)]) 44 | 45 | if cfg.MODEL.NAME == "PointLine": 46 | args['transform'] = Compose( 47 | [Resize(cfg.DATASETS.IMAGE.HEIGHT, 48 | cfg.DATASETS.IMAGE.WIDTH, 49 | cfg.DATASETS.TARGET.HEIGHT, 50 | cfg.DATASETS.TARGET.WIDTH), 51 | ToTensor()]) 52 | 53 | 54 | dataset = factory(**args) 55 | 56 | dataset = torch.utils.data.DataLoader(dataset, 57 | batch_size=cfg.SOLVER.IMS_PER_BATCH, 58 | collate_fn=train_dataset.collate_fn, 59 | shuffle = True, 60 | num_workers = cfg.DATALOADER.NUM_WORKERS) 61 | return dataset 62 | 63 | def build_test_dataset(cfg): 64 | transforms = Compose( 65 | [ResizeImage(cfg.DATASETS.IMAGE.HEIGHT, 66 | cfg.DATASETS.IMAGE.WIDTH), 67 | ToTensor(), 68 | Normalize(cfg.DATASETS.IMAGE.PIXEL_MEAN, 69 | cfg.DATASETS.IMAGE.PIXEL_STD, 70 | cfg.DATASETS.IMAGE.TO_255) 71 | ] 72 | ) 73 | 74 | if cfg.MODEL.NAME == "PointLine": 75 | transforms = Compose( 76 | [ResizeImage(cfg.DATASETS.IMAGE.HEIGHT, 77 | cfg.DATASETS.IMAGE.WIDTH), 78 | ToTensor() 79 | ] 80 | ) 81 | 82 | datasets = [] 83 | for name in cfg.DATASETS.TEST: 84 | dargs = DatasetCatalog.get(name) 85 | factory = getattr(test_dataset,dargs['factory']) 86 | args = dargs['args'] 87 | args['transform'] = transforms 88 | dataset = factory(**args) 89 | dataset = torch.utils.data.DataLoader( 90 | dataset, batch_size = 1, 91 | collate_fn = dataset.collate_fn, 92 | num_workers = cfg.DATALOADER.NUM_WORKERS, 93 | ) 94 | datasets.append((name,dataset)) 95 | return datasets 96 | -------------------------------------------------------------------------------- /hawp/fsl/dataset/imagelist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | try: 9 | import cv2 # pylint: disable=import-error 10 | except ImportError: 11 | cv2 = None 12 | 13 | import PIL 14 | try: 15 | import PIL.ImageGrab 16 | except ImportError: 17 | pass 18 | 19 | try: 20 | import mss 21 | except ImportError: 22 | mss = None 23 | 24 | import os 25 | import os.path as osp 26 | LOG = logging.getLogger(__name__) 27 | 28 | class ToTensor(object): 29 | def __call__(self,image): 30 | tensor = torch.from_numpy(image).float()/255.0 31 | tensor = tensor.permute((2,0,1)).contiguous() 32 | return tensor 33 | 34 | class Normalize(object): 35 | def __init__(self, mean, std): 36 | 
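        # mean/std are per-channel statistics; __call__ below applies
        # (x - mean[c]) / std[c] to each channel of the CHW tensor
        # produced by ToTensor above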
self.mean = mean 37 | self.std = std 38 | def __call__(self, image): 39 | image[0] = (image[0]-self.mean[0])/self.std[0] 40 | image[1] = (image[1]-self.mean[1])/self.std[1] 41 | image[2] = (image[2]-self.mean[2])/self.std[2] 42 | return image 43 | 44 | # pylint: disable=abstract-method 45 | class ImageList(torch.utils.data.Dataset): 46 | horizontal_flip = None 47 | rotate = None 48 | crop = None 49 | scale = 1.0 50 | start_frame = None 51 | start_msec = None 52 | max_frames = None 53 | 54 | def __init__(self, source, *, 55 | input_size, 56 | transforms, 57 | ): 58 | super().__init__() 59 | 60 | self.source = source 61 | self.input_size = input_size 62 | self.transforms = transforms 63 | 64 | self.filenames = sorted(os.listdir(source)) 65 | 66 | def __len__(self): 67 | return len(self.filenames) 68 | 69 | # pylint: disable=unsubscriptable-object 70 | def preprocessing(self, image): 71 | if self.scale != 1.0: 72 | image = cv2.resize(image, None, fx=self.scale, fy=self.scale) 73 | LOG.debug('resized image size: %s', image.shape) 74 | if self.horizontal_flip: 75 | image = image[:, ::-1] 76 | if self.crop: 77 | if self.crop[0]: 78 | image = image[:, self.crop[0]:] 79 | if self.crop[1]: 80 | image = image[self.crop[1]:, :] 81 | if self.crop[2]: 82 | image = image[:, :-self.crop[2]] 83 | if self.crop[3]: 84 | image = image[:-self.crop[3], :] 85 | if self.rotate == 'left': 86 | image = np.swapaxes(image, 0, 1) 87 | image = np.flip(image, axis=0) 88 | elif self.rotate == 'right': 89 | image = np.swapaxes(image, 0, 1) 90 | image = np.flip(image, axis=1) 91 | elif self.rotate == '180': 92 | image = np.flip(image, axis=0) 93 | image = np.flip(image, axis=1) 94 | 95 | meta = { 96 | 'width': image.shape[1], 97 | 'height': image.shape[0], 98 | } 99 | 100 | processed_image = self.transforms(image) 101 | 102 | return image, processed_image, meta 103 | 104 | def __getitem__(self, id): 105 | fname = osp.join(self.source,self.filenames[id]) 106 | 107 | image = cv2.imread(fname) 108 | meta = { 109 | 'width': image.shape[1], 110 | 'height': image.shape[0], 111 | } 112 | meta['frame_i'] = id 113 | meta['filename'] = '{:05d}'.format(id) 114 | processed_image = self.transforms(image) 115 | return image, processed_image, meta 116 | -------------------------------------------------------------------------------- /hawp/fsl/dataset/iteration_based_batch_sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import BatchSampler 2 | 3 | 4 | class IterationBasedBatchSampler(BatchSampler): 5 | """ 6 | Wraps a BatchSampler, resampling from it until 7 | a specified number of iterations have been sampled 8 | """ 9 | 10 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 11 | self.batch_sampler = batch_sampler 12 | self.num_iterations = num_iterations 13 | self.start_iter = start_iter 14 | 15 | def __iter__(self): 16 | iteration = self.start_iter 17 | while iteration <= self.num_iterations: 18 | # if the underlying sampler has a set_epoch method, like 19 | # DistributedSampler, used for making each process see 20 | # a different split of the dataset, then set it 21 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 22 | self.batch_sampler.sampler.set_epoch(iteration) 23 | for batch in self.batch_sampler: 24 | iteration += 1 25 | if iteration > self.num_iterations: 26 | break 27 | yield batch 28 | 29 | def __len__(self): 30 | return self.num_iterations -------------------------------------------------------------------------------- 
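A minimal wiring sketch for IterationBasedBatchSampler; the toy dataset, batch
size, and iteration count below are placeholders, not values from this repo:

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset
from hawp.fsl.dataset.iteration_based_batch_sampler import IterationBasedBatchSampler

dataset = TensorDataset(torch.arange(100.0))               # 100 toy samples
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=True)
batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=1000)
loader = DataLoader(dataset, batch_sampler=batch_sampler)  # len(loader) == 1000
# iterating `loader` re-samples the 100 items until 1000 batches are produced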
/hawp/fsl/dataset/test_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.utils.data.dataloader import default_collate 4 | import json 5 | import copy 6 | from PIL import Image 7 | from skimage import io 8 | import os 9 | import os.path as osp 10 | import numpy as np 11 | import cv2 12 | class TestDatasetWithAnnotations(Dataset): 13 | ''' 14 | Format of the annotation file 15 | annotations[i] has the following dict items: 16 | - filename # of the input image, str 17 | - height # of the input image, int 18 | - width # of the input image, int 19 | - lines # of the input image, list of list, N*4 20 | - junc # of the input image, list of list, M*2 21 | ''' 22 | 23 | def __init__(self, root, ann_file, transform = None): 24 | self.root = root 25 | with open(ann_file, 'r') as _: 26 | self.annotations = json.load(_) 27 | self.transform = transform 28 | 29 | def __len__(self): 30 | return len(self.annotations) 31 | 32 | def __getitem__(self, idx): 33 | ann = copy.deepcopy(self.annotations[idx]) 34 | # image = Image.open(osp.join(self.root,ann['filename'])).convert('RGB') 35 | # image = io.imread(osp.join(self.root,ann['filename']))#.astype(float)[:,:,:3] 36 | image = cv2.imread(osp.join(self.root,ann['filename']), cv2.IMREAD_GRAYSCALE).astype(float) 37 | # image = cv2.imread(osp.join(self.root,ann['filename']))[:,:,::-1] 38 | if len(image.shape) == 2: 39 | image = np.stack((image,image,image),axis=-1) 40 | 41 | image = image.astype(float)[:,:,:3] 42 | # image = io.imread(osp.join(self.root,ann['filename'])).astype(float)[:,:,:3] 43 | for key, _type in (['junc',np.float32], 44 | ['lines', np.float32]): 45 | ann[key] = np.array(ann[key],dtype=_type) 46 | 47 | if self.transform is not None: 48 | return self.transform(image,ann) 49 | return image, ann 50 | def image(self, idx): 51 | ann = copy.deepcopy(self.annotations[idx]) 52 | image = Image.open(osp.join(self.root,ann['filename'])).convert('RGB') 53 | return image 54 | @staticmethod 55 | def collate_fn(batch): 56 | return (default_collate([b[0] for b in batch]), 57 | [b[1] for b in batch]) -------------------------------------------------------------------------------- /hawp/fsl/dataset/train_dataset.bck.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | import os.path as osp 5 | import json 6 | import cv2 7 | from skimage import io 8 | from PIL import Image 9 | import numpy as np 10 | import random 11 | from torch.utils.data.dataloader import default_collate 12 | from torch.utils.data.dataloader import DataLoader 13 | import matplotlib.pyplot as plt 14 | from torchvision.transforms import functional as F 15 | import copy 16 | class TrainDataset(Dataset): 17 | def __init__(self, root, ann_file, transform = None): 18 | self.root = root 19 | with open(ann_file,'r') as _: 20 | self.annotations = json.load(_) 21 | self.transform = transform 22 | 23 | def __getitem__(self, idx_): 24 | # print(idx_) 25 | idx = idx_%len(self.annotations) 26 | reminder = idx_//len(self.annotations) 27 | # idx = 0 28 | # reminder = 0 29 | ann = copy.deepcopy(self.annotations[idx]) 30 | ann['reminder'] = reminder 31 | # image = io.imread(osp.join(self.root,ann['filename']))#.astype(float)[:,:,:3] 32 | # if len(image.shape) == 2: 33 | # image = np.stack((image,image,image),axis=-1) 34 | 35 | # image = image.astype(float)[:,:,:3] 36 | 37 | image = 
io.imread(osp.join(self.root,ann['filename']), as_gray=True).astype(float)#[:,:,:3] 38 | if len(image.shape) == 2: 39 | image = np.concatenate([image[...,None],image[...,None],image[...,None]],axis=-1) 40 | else: 41 | image = image[:,:,:3] 42 | 43 | # image = Image.open(osp.join(self.root,ann['filename'])).convert('RGB') 44 | for key,_type in (['junctions',np.float32], 45 | ['edges_positive',np.int32], 46 | ['edges_negative',np.int32]): 47 | ann[key] = np.array(ann[key],dtype=_type) 48 | 49 | width = ann['width'] 50 | height = ann['height'] 51 | if reminder == 1: 52 | image = image[:,::-1,:] 53 | # image = F.hflip(image) 54 | ann['junctions'][:,0] = width-ann['junctions'][:,0] 55 | elif reminder == 2: 56 | # image = F.vflip(image) 57 | image = image[::-1,:,:] 58 | ann['junctions'][:,1] = height-ann['junctions'][:,1] 59 | elif reminder == 3: 60 | # image = F.vflip(F.hflip(image)) 61 | image = image[::-1,::-1,:] 62 | ann['junctions'][:,0] = width-ann['junctions'][:,0] 63 | ann['junctions'][:,1] = height-ann['junctions'][:,1] 64 | else: 65 | pass 66 | 67 | if self.transform is not None: 68 | return self.transform(image,ann) 69 | return image, ann 70 | 71 | def __len__(self): 72 | return len(self.annotations)*4 73 | # return 1000 74 | 75 | def collate_fn(batch): 76 | return (default_collate([b[0] for b in batch]), 77 | [b[1] for b in batch]) -------------------------------------------------------------------------------- /hawp/fsl/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /hawp/fsl/model/build.py: -------------------------------------------------------------------------------- 1 | from . import models 2 | 3 | 4 | def build_model(cfg): 5 | model = models.WireframeDetector(cfg) 6 | 7 | return model -------------------------------------------------------------------------------- /hawp/fsl/model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | def cross_entropy_loss_for_junction(logits, positive): 4 | nlogp = -F.log_softmax(logits, dim=1) 5 | 6 | loss = (positive * nlogp[:, None, 1] + (1 - positive) * nlogp[:, None, 0]) 7 | 8 | return loss.mean() 9 | 10 | def focal_loss_for_junction(logits, positive, gamma=2.0): 11 | prob = F.softmax(logits, 1) 12 | ce_loss = F.cross_entropy(logits, positive, reduction='none') 13 | p_t = prob[:,1:]*positive + prob[:,:1]*(1-positive) 14 | loss = ce_loss * ((1-p_t)**gamma) 15 | 16 | return loss.mean() 17 | 18 | def sigmoid_l1_loss(logits, targets, offset = 0.0, mask=None): 19 | logp = torch.sigmoid(logits) + offset 20 | loss = torch.abs(logp-targets) 21 | 22 | if mask is not None: 23 | w = mask.mean(3, True).mean(2,True) 24 | w[w==0] = 1 25 | loss = loss*(mask/w) 26 | 27 | return loss.mean() 28 | 29 | 30 | def sigmoid_focal_loss( 31 | inputs: torch.Tensor, 32 | targets: torch.Tensor, 33 | alpha: float = -1, 34 | gamma: float = 2, 35 | reduction: str = "none", 36 | ) -> torch.Tensor: 37 | """ 38 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 39 | Args: 40 | inputs: A float tensor of arbitrary shape. 41 | The predictions for each example. 42 | targets: A float tensor with the same shape as inputs. Stores the binary 43 | classification label for each element in inputs 44 | (0 for the negative class and 1 for the positive class). 
45 | alpha: (optional) Weighting factor in range (0,1) to balance 46 | positive vs negative examples. Default = -1 (no weighting). 47 | gamma: Exponent of the modulating factor (1 - p_t) to 48 | balance easy vs hard examples. 49 | reduction: 'none' | 'mean' | 'sum' 50 | 'none': No reduction will be applied to the output. 51 | 'mean': The output will be averaged. 52 | 'sum': The output will be summed. 53 | Returns: 54 | Loss tensor with the reduction option applied. 55 | """ 56 | inputs = inputs.float() 57 | targets = targets.float() 58 | p = torch.sigmoid(inputs) 59 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 60 | p_t = p * targets + (1 - p) * (1 - targets) 61 | loss = ce_loss * ((1 - p_t) ** gamma) 62 | 63 | if alpha >= 0: 64 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 65 | loss = alpha_t * loss 66 | 67 | if reduction == "mean": 68 | loss = loss.mean() 69 | elif reduction == "sum": 70 | loss = loss.sum() 71 | 72 | return loss -------------------------------------------------------------------------------- /hawp/fsl/model/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | def non_maximum_suppression(a): 7 | ap = F.max_pool2d(a, 3, stride=1, padding=1) 8 | mask = (a == ap).float().clamp(min=0.0) 9 | 10 | return a * mask 11 | 12 | def get_junctions(jloc, joff, topk = 300, th=0): 13 | height, width = jloc.size(1), jloc.size(2) 14 | jloc = jloc.reshape(-1) 15 | joff = joff.reshape(2, -1) 16 | 17 | 18 | scores, index = torch.topk(jloc, k=topk) 19 | # y = (index // width).float() + torch.gather(joff[1], 0, index) + 0.5 20 | y = torch.div(index,width,rounding_mode='trunc').float()+ torch.gather(joff[1], 0, index) + 0.5 21 | x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5 22 | 23 | junctions = torch.stack((x, y)).t() 24 | 25 | return junctions[scores>th], scores[scores>th] 26 | 27 | def plot_lines(lines, scale=1.0, color = 'red', **kwargs): 28 | if isinstance(lines, np.ndarray): 29 | plt.plot([lines[:,0]*scale,lines[:,2]*scale],[lines[:,1]*scale,lines[:,3]*scale],color=color,linestyle='-') 30 | else: 31 | lines_np = lines.detach().cpu().numpy() 32 | plt.plot([lines_np[:,0]*scale,lines_np[:,2]*scale],[lines_np[:,1]*scale,lines_np[:,3]*scale],color=color,linestyle='-') 33 | 34 | -------------------------------------------------------------------------------- /hawp/fsl/point_model/point_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/hawp/fsl/point_model/point_model.pth -------------------------------------------------------------------------------- /hawp/fsl/solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def make_optimizer(cfg, model, loss_reducer = None): 4 | 5 | params = [] 6 | 7 | for key, value in model.named_parameters(): 8 | if not value.requires_grad: 9 | continue 10 | 11 | lr=cfg.SOLVER.BASE_LR 12 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 13 | # if 'md_predictor' in key or 'st_predictor' in key or 'ed_predictor' in key: 14 | # lr = cfg.SOLVER.BASE_LR*100.0 15 | 16 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 17 | 18 | if loss_reducer is not None: 19 | for key, value in loss_reducer.named_parameters(): 20 | params += [{"params": [value], 
"lr": lr, "weight_decay": 0}] 21 | 22 | 23 | if cfg.SOLVER.OPTIMIZER == 'SGD': 24 | optimizer = torch.optim.SGD(params, 25 | cfg.SOLVER.BASE_LR, 26 | momentum=cfg.SOLVER.MOMENTUM, 27 | weight_decay=cfg.SOLVER.WEIGHT_DECAY) 28 | elif cfg.SOLVER.OPTIMIZER == 'ADAM': 29 | optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, 30 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 31 | amsgrad=cfg.SOLVER.AMSGRAD) 32 | elif cfg.SOLVER.OPTIMIZER == 'ADAMW': 33 | optimizer = torch.optim.AdamW(params, 34 | cfg.SOLVER.BASE_LR,betas=(0.9,0.999), 35 | eps=1e-8, 36 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 37 | amsgrad=cfg.SOLVER.AMSGRAD, 38 | ) 39 | else: 40 | raise NotImplementedError() 41 | return optimizer 42 | 43 | def make_lr_scheduler(cfg,optimizer): 44 | return torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=cfg.SOLVER.STEPS,gamma=cfg.SOLVER.GAMMA) -------------------------------------------------------------------------------- /hawp/ssl/__init__.py: -------------------------------------------------------------------------------- 1 | from . import config, datasets, models 2 | -------------------------------------------------------------------------------- /hawp/ssl/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .project_config import Config 2 | from .utils import * -------------------------------------------------------------------------------- /hawp/ssl/config/exports/wireframe-100iters.yaml: -------------------------------------------------------------------------------- 1 | ### General dataset parameters 2 | dataset_name: "wireframe" 3 | add_augmentation_to_all_splits: False 4 | gray_scale: True 5 | # Ground truth source ('official' or path to the exported h5 dataset.) 6 | # gt_source_train: "" # Fill with your own export file 7 | # gt_source_test: "" # Fill with your own export file 8 | # Return type: (1) single (to train the detector only) 9 | # or (2) paired_desc (to train the detector + descriptor) 10 | return_type: "single" 11 | random_seed: 0 12 | 13 | ### Descriptor training parameters 14 | # Number of points extracted per line 15 | max_num_samples: 10 16 | # Max number of training line points extracted in the whole image 17 | max_pts: 1000 18 | # Min distance between two points on a line (in pixels) 19 | min_dist_pts: 10 20 | # Small jittering of the sampled points during training 21 | jittering: 0 22 | 23 | ### Data preprocessing configuration 24 | preprocessing: 25 | resize: [512, 512] 26 | blur_size: 11 27 | augmentation: 28 | random_scaling: 29 | enable: True 30 | range: [0.7, 1.5] 31 | photometric: 32 | enable: true 33 | primitives: ['random_brightness', 'random_contrast', 34 | 'additive_speckle_noise', 'additive_gaussian_noise', 35 | 'additive_shade', 'motion_blur' ] 36 | params: 37 | random_brightness: {brightness: 0.2} 38 | random_contrast: {contrast: [0.3, 1.5]} 39 | additive_gaussian_noise: {stddev_range: [0, 10]} 40 | additive_speckle_noise: {prob_range: [0, 0.0035]} 41 | additive_shade: 42 | transparency_range: [-0.5, 0.5] 43 | kernel_size_range: [100, 150] 44 | motion_blur: {max_kernel_size: 3} 45 | random_order: True 46 | homographic: 47 | enable: true 48 | params: 49 | translation: true 50 | rotation: true 51 | scaling: true 52 | perspective: true 53 | scaling_amplitude: 0.2 54 | perspective_amplitude_x: 0.2 55 | perspective_amplitude_y: 0.2 56 | patch_ratio: 0.85 57 | max_angle: 1.57 58 | allow_artifacts: true 59 | valid_border_margin: 3 60 | 61 | ## Homography adaptation configuration 62 | 
homography_adaptation:
63 |   num_iter: 100
64 |   valid_border_margin: 3
65 |   min_counts: 30
66 |   homographies:
67 |     translation: true
68 |     rotation: true
69 |     scaling: true
70 |     perspective: true
71 |     scaling_amplitude: 0.2
72 |     perspective_amplitude_x: 0.2
73 |     perspective_amplitude_y: 0.2
74 |     allow_artifacts: true
75 |     patch_ratio: 0.85
-------------------------------------------------------------------------------- /hawp/ssl/config/exports/wireframe-10iters.yaml: --------------------------------------------------------------------------------
1 | ### General dataset parameters
2 | dataset_name: "wireframe"
3 | add_augmentation_to_all_splits: False
4 | gray_scale: True
5 | # Ground truth source ('official' or path to the exported h5 dataset.)
6 | # gt_source_train: "" # Fill with your own export file
7 | # gt_source_test: "" # Fill with your own export file
8 | # Return type: (1) single (to train the detector only)
9 | # or (2) paired_desc (to train the detector + descriptor)
10 | return_type: "single"
11 | random_seed: 0
12 | 
13 | ### Descriptor training parameters
14 | # Number of points extracted per line
15 | max_num_samples: 10
16 | # Max number of training line points extracted in the whole image
17 | max_pts: 1000
18 | # Min distance between two points on a line (in pixels)
19 | min_dist_pts: 10
20 | # Small jittering of the sampled points during training
21 | jittering: 0
22 | 
23 | ### Data preprocessing configuration
24 | preprocessing:
25 |   resize: [512, 512]
26 |   blur_size: 11
27 | augmentation:
28 |   random_scaling:
29 |     enable: True
30 |     range: [0.7, 1.5]
31 |   photometric:
32 |     enable: true
33 |     primitives: ['random_brightness', 'random_contrast',
34 |                  'additive_speckle_noise', 'additive_gaussian_noise',
35 |                  'additive_shade', 'motion_blur' ]
36 |     params:
37 |       random_brightness: {brightness: 0.2}
38 |       random_contrast: {contrast: [0.3, 1.5]}
39 |       additive_gaussian_noise: {stddev_range: [0, 10]}
40 |       additive_speckle_noise: {prob_range: [0, 0.0035]}
41 |       additive_shade:
42 |         transparency_range: [-0.5, 0.5]
43 |         kernel_size_range: [100, 150]
44 |       motion_blur: {max_kernel_size: 3}
45 |     random_order: True
46 |   homographic:
47 |     enable: true
48 |     params:
49 |       translation: true
50 |       rotation: true
51 |       scaling: true
52 |       perspective: true
53 |       scaling_amplitude: 0.2
54 |       perspective_amplitude_x: 0.2
55 |       perspective_amplitude_y: 0.2
56 |       patch_ratio: 0.85
57 |       max_angle: 1.57
58 |       allow_artifacts: true
59 |     valid_border_margin: 3
60 | 
61 | ## Homography adaptation configuration
62 | homography_adaptation:
63 |   num_iter: 10
64 |   valid_border_margin: 3
65 |   min_counts: 3
66 |   homographies:
67 |     translation: true
68 |     rotation: true
69 |     scaling: true
70 |     perspective: true
71 |     scaling_amplitude: 0.2
72 |     perspective_amplitude_x: 0.2
73 |     perspective_amplitude_y: 0.2
74 |     allow_artifacts: true
75 |     patch_ratio: 0.85
-------------------------------------------------------------------------------- /hawp/ssl/config/hawpv3-hrheat.yaml: --------------------------------------------------------------------------------
1 | ENCODER:
2 |   ANG_TH: 0.0
3 |   BACKGROUND_WEIGHT: 0.0
4 |   DIS_TH: 2
5 |   NUM_STATIC_NEG_LINES: 0
6 |   NUM_STATIC_POS_LINES: 300
7 | MODEL:
8 |   DEVICE: cuda
9 |   USE_LINE_HEATMAP: True
10 |   USE_HR_JUNCTION: True
11 |   HEAD_SIZE:
12 |   - - 3
13 |   - - 1
14 |   - - 1
15 |   - - 2
16 |   - - 2
17 |   HGNETS:
18 |     DEPTH: 4
19 |     INPLANES: 64
20 |     NUM_BLOCKS: 1
21 |     NUM_FEATS: 128
22 |     NUM_STACKS: 2
23 |   LOI_POOLING:
24 |     ACTIVATION: relu
25 |     DIM_EDGE_FEATURE: 4
26 |     DIM_FC: 1024
27 |
DIM_JUNCTION_FEATURE: 128 28 | LAYER_NORM: false 29 | NUM_POINTS: 32 30 | TYPE: softmax 31 | LOSS_WEIGHTS: 32 | loss_aux: 1.0 33 | loss_dis: 1.0 34 | loss_jloc: 8.0 35 | loss_joff: 0.25 36 | # loss_joff: 0.0 37 | loss_johr: 0.25 38 | loss_lineness: 1.0 39 | loss_md: 1.0 40 | loss_neg: 1.0 41 | loss_pos: 1.0 42 | loss_res: 1.0 43 | loss_heatmap: 1.0 44 | loss_jmap: 1.0 45 | NAME: Hourglass 46 | OUT_FEATURE_CHANNELS: 256 47 | PARSING_HEAD: 48 | DIM_FC: 1024 49 | DIM_LOI: 128 50 | J2L_THRESHOLD: 10.0 51 | JMATCH_THRESHOLD: 1.5 52 | MATCHING_STRATEGY: junction 53 | MAX_DISTANCE: 5.0 54 | N_DYN_JUNC: 300 55 | N_DYN_NEGL: 40 56 | N_DYN_OTHR: 0 57 | N_DYN_OTHR2: 300 58 | N_DYN_POSL: 300 59 | N_OUT_JUNC: 250 60 | N_OUT_LINE: 2500 61 | N_PTS0: 32 62 | N_PTS1: 8 63 | N_STC_NEGL: 40 64 | N_STC_POSL: 300 65 | USE_RESIDUAL: 1 66 | RESNETS: 67 | BASENET: resnet50 68 | PRETRAIN: true 69 | SCALE: 1.0 70 | WEIGHTS: '' 71 | MODELING_PATH: ihawp-v2 72 | OUTPUT_DIR: output/ihawp 73 | SOLVER: 74 | AMSGRAD: true 75 | BACKBONE_LR_FACTOR: 1.0 76 | BASE_LR: 0.00004 77 | BIAS_LR_FACTOR: 1 78 | CHECKPOINT_PERIOD: 1 79 | GAMMA: 0.1 80 | IMS_PER_BATCH: 6 81 | MAX_EPOCH: 30 82 | MOMENTUM: 0.9 83 | OPTIMIZER: ADAM 84 | STEPS: 85 | - 25 86 | WEIGHT_DECAY: 0.0001 87 | WEIGHT_DECAY_BIAS: 0 88 | -------------------------------------------------------------------------------- /hawp/ssl/config/hawpv3.yaml: -------------------------------------------------------------------------------- 1 | ENCODER: 2 | ANG_TH: 0.0 3 | BACKGROUND_WEIGHT: 0.0 4 | DIS_TH: 2 5 | NUM_STATIC_NEG_LINES: 0 6 | NUM_STATIC_POS_LINES: 300 7 | MODEL: 8 | DEVICE: cuda 9 | USE_LINE_HEATMAP: True 10 | HEAD_SIZE: 11 | - - 3 12 | - - 1 13 | - - 1 14 | - - 2 15 | - - 2 16 | HGNETS: 17 | DEPTH: 4 18 | INPLANES: 64 19 | NUM_BLOCKS: 1 20 | NUM_FEATS: 128 21 | NUM_STACKS: 2 22 | LOI_POOLING: 23 | ACTIVATION: relu 24 | DIM_EDGE_FEATURE: 4 25 | DIM_FC: 1024 26 | DIM_JUNCTION_FEATURE: 128 27 | LAYER_NORM: false 28 | NUM_POINTS: 32 29 | TYPE: softmax 30 | LOSS_WEIGHTS: 31 | loss_aux: 1.0 32 | loss_dis: 1.0 33 | loss_jloc: 8.0 34 | loss_joff: 0.25 35 | # loss_joff: 0.0 36 | loss_lineness: 1.0 37 | loss_md: 1.0 38 | loss_neg: 1.0 39 | loss_pos: 1.0 40 | loss_res: 1.0 41 | loss_heatmap: 1.0 42 | NAME: Hourglass 43 | OUT_FEATURE_CHANNELS: 256 44 | PARSING_HEAD: 45 | DIM_FC: 1024 46 | DIM_LOI: 128 47 | J2L_THRESHOLD: 10.0 48 | JMATCH_THRESHOLD: 1.5 49 | MATCHING_STRATEGY: junction 50 | MAX_DISTANCE: 5.0 51 | N_DYN_JUNC: 300 52 | N_DYN_NEGL: 40 53 | N_DYN_OTHR: 0 54 | N_DYN_OTHR2: 300 55 | N_DYN_POSL: 300 56 | N_OUT_JUNC: 250 57 | N_OUT_LINE: 2500 58 | N_PTS0: 32 59 | N_PTS1: 8 60 | N_STC_NEGL: 40 61 | N_STC_POSL: 300 62 | USE_RESIDUAL: 1 63 | RESNETS: 64 | BASENET: resnet50 65 | PRETRAIN: true 66 | SCALE: 1.0 67 | WEIGHTS: '' 68 | MODELING_PATH: ihawp-v2 69 | OUTPUT_DIR: output/ihawp 70 | SOLVER: 71 | AMSGRAD: true 72 | BACKBONE_LR_FACTOR: 1.0 73 | BASE_LR: 0.00004 74 | BIAS_LR_FACTOR: 1 75 | CHECKPOINT_PERIOD: 1 76 | GAMMA: 0.1 77 | IMS_PER_BATCH: 6 78 | MAX_EPOCH: 30 79 | MOMENTUM: 0.9 80 | OPTIMIZER: ADAM 81 | STEPS: 82 | - 25 83 | WEIGHT_DECAY: 0.0001 84 | WEIGHT_DECAY_BIAS: 0 85 | -------------------------------------------------------------------------------- /hawp/ssl/config/project_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Project configurations. 3 | """ 4 | import os 5 | 6 | 7 | class Config(object): 8 | """ Datasets and experiments folders for the whole project. 
""" 9 | ##################### 10 | ## Dataset setting ## 11 | ##################### 12 | default_dataroot = os.path.join( 13 | os.path.dirname(__file__), 14 | '..','..','..','data-ssl' 15 | ) 16 | default_dataroot = os.path.abspath(default_dataroot) 17 | default_exproot = os.path.join( 18 | os.path.dirname(__file__), 19 | '..','..','..','exp-ssl' 20 | ) 21 | default_exproot = os.path.abspath(default_exproot) 22 | 23 | DATASET_ROOT = os.getenv("DATASET_ROOT", default_dataroot) # TODO: path to your datasets folder 24 | if not os.path.exists(DATASET_ROOT): 25 | os.makedirs(DATASET_ROOT) 26 | 27 | # Synthetic shape dataset 28 | synthetic_dataroot = os.path.join(DATASET_ROOT, "synthetic_shapes") 29 | synthetic_cache_path = os.path.join(DATASET_ROOT, "synthetic_shapes") 30 | if not os.path.exists(synthetic_dataroot): 31 | os.makedirs(synthetic_dataroot) 32 | 33 | # Exported predictions dataset 34 | export_dataroot = os.path.join(DATASET_ROOT, "export_datasets") 35 | export_cache_path = os.path.join(DATASET_ROOT, "export_datasets") 36 | if not os.path.exists(export_dataroot): 37 | os.makedirs(export_dataroot) 38 | 39 | # York Urban dataset 40 | yorkurban_dataroot = os.path.join(DATASET_ROOT, "YorkUrbanDB") 41 | yorkurban_cache_path = os.path.join(DATASET_ROOT, "YorkUrbanDB") 42 | 43 | # Wireframe dataset 44 | wireframe_dataroot = os.path.join(DATASET_ROOT, "wireframe") 45 | wireframe_cache_path = os.path.join(DATASET_ROOT, "wireframe") 46 | 47 | # Holicity dataset 48 | holicity_dataroot = os.path.join(DATASET_ROOT, "Holicity") 49 | holicity_cache_path = os.path.join(DATASET_ROOT, "Holicity") 50 | 51 | ######################## 52 | ## Experiment Setting ## 53 | ######################## 54 | EXP_PATH = os.getenv("EXP_PATH", default_exproot) # TODO: path to your experiments folder 55 | 56 | if not os.path.exists(EXP_PATH): 57 | os.makedirs(EXP_PATH) 58 | -------------------------------------------------------------------------------- /hawp/ssl/config/synthetic_dataset-4k.yaml: -------------------------------------------------------------------------------- 1 | ### General dataset parameters 2 | dataset_name: "synthetic_shape" 3 | primitives: "all" 4 | add_augmentation_to_all_splits: True 5 | test_augmentation_seed: 200 6 | alias: 4k 7 | # Shape generation configuration 8 | generation: 9 | # split_sizes: {'train': 20000, 'val': 2000, 'test': 400} 10 | split_sizes: {'train': 4000, 'val': 2000, 'test': 400} 11 | random_seed: 10 12 | image_size: [960, 1280] 13 | min_len: 0.0985 14 | min_label_len: 0.099 15 | params: 16 | generate_background: 17 | min_kernel_size: 150 18 | max_kernel_size: 500 19 | min_rad_ratio: 0.02 20 | max_rad_ratio: 0.031 21 | draw_stripes: 22 | transform_params: [0.1, 0.1] 23 | draw_multiple_polygons: 24 | kernel_boundaries: [50, 100] 25 | 26 | ### Data preprocessing configuration. 
27 | preprocessing: 28 | resize: [512, 512] 29 | blur_size: 11 30 | augmentation: 31 | photometric: 32 | enable: True 33 | primitives: 'all' 34 | params: {} 35 | random_order: True 36 | homographic: 37 | enable: True 38 | params: 39 | translation: true 40 | rotation: true 41 | scaling: true 42 | perspective: true 43 | scaling_amplitude: 0.2 44 | perspective_amplitude_x: 0.2 45 | perspective_amplitude_y: 0.2 46 | patch_ratio: 0.8 47 | max_angle: 1.57 48 | allow_artifacts: true 49 | translation_overflow: 0.05 50 | valid_border_margin: 0 51 | -------------------------------------------------------------------------------- /hawp/ssl/config/synthetic_dataset.yaml: -------------------------------------------------------------------------------- 1 | ### General dataset parameters 2 | dataset_name: "synthetic_shape" 3 | primitives: "all" 4 | add_augmentation_to_all_splits: True 5 | test_augmentation_seed: 200 6 | # Shape generation configuration 7 | generation: 8 | # split_sizes: {'train': 20000, 'val': 2000, 'test': 400} 9 | split_sizes: {'train': 2000, 'val': 2000, 'test': 400} 10 | random_seed: 10 11 | image_size: [960, 1280] 12 | min_len: 0.0985 13 | min_label_len: 0.099 14 | params: 15 | generate_background: 16 | min_kernel_size: 150 17 | max_kernel_size: 500 18 | min_rad_ratio: 0.02 19 | max_rad_ratio: 0.031 20 | draw_stripes: 21 | transform_params: [0.1, 0.1] 22 | draw_multiple_polygons: 23 | kernel_boundaries: [50, 100] 24 | 25 | ### Data preprocessing configuration. 26 | preprocessing: 27 | resize: [512, 512] 28 | blur_size: 11 29 | augmentation: 30 | photometric: 31 | enable: True 32 | primitives: 'all' 33 | params: {} 34 | random_order: True 35 | homographic: 36 | enable: True 37 | params: 38 | translation: true 39 | rotation: true 40 | scaling: true 41 | perspective: true 42 | scaling_amplitude: 0.2 43 | perspective_amplitude_x: 0.2 44 | perspective_amplitude_y: 0.2 45 | patch_ratio: 0.8 46 | max_angle: 1.57 47 | allow_artifacts: true 48 | translation_overflow: 0.05 49 | valid_border_margin: 0 50 | -------------------------------------------------------------------------------- /hawp/ssl/config/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | def load_config(config_path): 5 | """ Load configurations from a given yaml file. """ 6 | # Check file exists 7 | if not os.path.exists(config_path): 8 | raise ValueError("[Error] The provided config path is not valid.") 9 | 10 | # Load the configuration 11 | with open(config_path, "r") as f: 12 | config = yaml.safe_load(f) 13 | 14 | return config 15 | 16 | def update_config(path, model_cfg=None, dataset_cfg=None): 17 | """ Update configuration file from the resume path. """ 18 | # Check we need to update or completely override. 
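# (Despite the name, the merge below favors the *saved* files: entries loaded
#  from model_cfg.yaml / dataset_cfg.yaml override the dicts passed in via the
#  .update() calls, and the merged result is written back only when it differs
#  from the saved copy.)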
19 |     model_cfg = {} if model_cfg is None else model_cfg
20 |     dataset_cfg = {} if dataset_cfg is None else dataset_cfg
21 | 
22 |     # Load saved configs
23 |     with open(os.path.join(path, "model_cfg.yaml"), "r") as f:
24 |         model_cfg_saved = yaml.safe_load(f)
25 |     model_cfg.update(model_cfg_saved)
26 |     with open(os.path.join(path, "dataset_cfg.yaml"), "r") as f:
27 |         dataset_cfg_saved = yaml.safe_load(f)
28 |     dataset_cfg.update(dataset_cfg_saved)
29 | 
30 |     # Update the saved yaml file
31 |     if not model_cfg == model_cfg_saved:
32 |         with open(os.path.join(path, "model_cfg.yaml"), "w") as f:
33 |             yaml.dump(model_cfg, f)
34 |     if not dataset_cfg == dataset_cfg_saved:
35 |         with open(os.path.join(path, "dataset_cfg.yaml"), "w") as f:
36 |             yaml.dump(dataset_cfg, f)
37 | 
38 |     return model_cfg, dataset_cfg
39 | 
40 | def record_config(model_cfg, dataset_cfg, output_path):
41 |     """ Record dataset config to the log path. """
42 |     # Record model config
43 |     with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f:
44 |         yaml.safe_dump(model_cfg, f)
45 | 
46 |     # Record dataset config
47 |     with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f:
48 |         yaml.safe_dump(dataset_cfg, f)
-------------------------------------------------------------------------------- /hawp/ssl/config/wireframe_official_gt.yaml: --------------------------------------------------------------------------------
1 | dataset_name: "wireframe"
2 | add_augmentation_to_all_splits: False
3 | gray_scale: True
4 | # Ground truth source ('official' or path to the exported h5 dataset.)
5 | gt_source_train: "official"
6 | gt_source_test: "official"
7 | # Data preprocessing configuration.
8 | preprocessing:
9 |   resize: [512, 512]
10 |   blur_size: 11
11 | augmentation:
12 |   photometric:
13 |     enable: True
14 |   homographic:
15 |     enable: True
16 | # The homography adaptation configuration
17 | homography_adaptation:
18 |   num_iter: 100
19 |   aggregation: 'sum'
20 |   mode: 'ver1'
21 |   valid_border_margin: 3
22 |   min_counts: 30
23 |   homographies:
24 |     translation: true
25 |     rotation: true
26 |     scaling: true
27 |     perspective: true
28 |     scaling_amplitude: 0.2
29 |     perspective_amplitude_x: 0.2
30 |     perspective_amplitude_y: 0.2
31 |     allow_artifacts: true
32 |     patch_ratio: 0.85
33 | # Evaluation related config
34 | evaluation:
35 |   repeatability:
36 |     # Initial random seed used to sample homographic augmentation
37 |     seed: 200
38 |     # Parameter used to sample illumination change evaluation set.
39 |     photometric:
40 |       enable: False
41 |     # Parameter used to sample viewpoint change evaluation set.
42 |     homographic:
43 |       enable: True
44 |       num_samples: 2
45 |       params:
46 |         translation: true
47 |         rotation: true
48 |         scaling: true
49 |         perspective: true
50 |         scaling_amplitude: 0.2
51 |         perspective_amplitude_x: 0.2
52 |         perspective_amplitude_y: 0.2
53 |         patch_ratio: 0.85
54 |         max_angle: 1.57
55 |         allow_artifacts: true
56 |         valid_border_margin: 3
-------------------------------------------------------------------------------- /hawp/ssl/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/hawp/ssl/datasets/__init__.py --------------------------------------------------------------------------------
/hawp/ssl/datasets/dataset_util.py: --------------------------------------------------------------------------------
1 | """
2 | The interface for initializing different datasets.
3 | """ 4 | from .synthetic_dataset import SyntheticShapes,synthetic_collate_fn 5 | from .wireframe_dataset import WireframeDataset,wireframe_collate_fn 6 | from .yorkurban_dataset import YorkUrbanDataset,yorkurban_collate_fn 7 | from .images_dataset import ImageCollections, images_collate_fn 8 | # from .holicity_dataset import HolicityDataset 9 | # from .merge_dataset import MergeDataset 10 | 11 | 12 | def get_dataset(mode="train", dataset_cfg=None): 13 | """ Initialize different dataset based on a configuration. """ 14 | # Check dataset config is given 15 | if dataset_cfg is None: 16 | raise ValueError("[Error] The dataset config is required!") 17 | 18 | # Synthetic dataset 19 | if dataset_cfg["dataset_name"] == "synthetic_shape": 20 | dataset = SyntheticShapes( 21 | mode, dataset_cfg 22 | ) 23 | # Get the collate_fn 24 | # from sold2.dataset.synthetic_dataset import synthetic_collate_fn 25 | collate_fn = synthetic_collate_fn 26 | 27 | # Wireframe dataset 28 | elif dataset_cfg["dataset_name"] == "wireframe": 29 | dataset = WireframeDataset( 30 | mode, dataset_cfg 31 | ) 32 | 33 | # Get the collate_fn 34 | collate_fn = wireframe_collate_fn 35 | elif dataset_cfg["dataset_name"] == "yorkurban": 36 | dataset = YorkUrbanDataset( 37 | mode, dataset_cfg 38 | ) 39 | 40 | # Get the collate_fn 41 | collate_fn = yorkurban_collate_fn 42 | # Holicity dataset 43 | elif dataset_cfg["dataset_name"] == "holicity": 44 | dataset = HolicityDataset( 45 | mode, dataset_cfg 46 | ) 47 | 48 | # Get the collate_fn 49 | from sold2.dataset.holicity_dataset import holicity_collate_fn 50 | collate_fn = holicity_collate_fn 51 | 52 | # Dataset merging several datasets in one 53 | elif dataset_cfg["dataset_name"] == "merge": 54 | dataset = MergeDataset( 55 | mode, dataset_cfg 56 | ) 57 | 58 | # Get the collate_fn 59 | from sold2.dataset.holicity_dataset import holicity_collate_fn 60 | collate_fn = holicity_collate_fn 61 | elif dataset_cfg["dataset_name"] == "general": 62 | dataset = ImageCollections(mode, dataset_cfg) 63 | collate_fn = images_collate_fn 64 | else: 65 | raise ValueError( 66 | "[Error] The dataset '%s' is not supported" % dataset_cfg["dataset_name"]) 67 | 68 | return dataset, collate_fn 69 | -------------------------------------------------------------------------------- /hawp/ssl/datasets/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/hawp/ssl/datasets/transforms/__init__.py -------------------------------------------------------------------------------- /hawp/ssl/datasets/wireframe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | import os.path as osp 5 | import json 6 | import cv2 7 | from skimage import io 8 | from PIL import Image 9 | import numpy as np 10 | import random 11 | from torch.utils.data.dataloader import default_collate 12 | from torch.utils.data.dataloader import DataLoader 13 | import matplotlib.pyplot as plt 14 | from torchvision.transforms import functional as F 15 | import copy 16 | 17 | 18 | # class WireframeDataset(Dataset): 19 | # def __init__(self, ) -------------------------------------------------------------------------------- /hawp/ssl/misc/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/hawp/ssl/misc/__init__.py -------------------------------------------------------------------------------- /hawp/ssl/misc/geometry_utils.old.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | ### Point-related utils 6 | 7 | # Warp a list of points using a homography 8 | def warp_points(points, homography): 9 | # Convert to homogeneous and in xy format 10 | new_points = np.concatenate([points[..., [1, 0]], 11 | np.ones_like(points[..., :1])], axis=-1) 12 | # Warp 13 | new_points = (homography @ new_points.T).T 14 | # Convert back to inhomogeneous and hw format 15 | new_points = new_points[..., [1, 0]] / new_points[..., 2:] 16 | return new_points 17 | 18 | 19 | # Mask out the points that are outside of img_size 20 | def mask_points(points, img_size): 21 | mask = ((points[..., 0] >= 0) 22 | & (points[..., 0] < img_size[0]) 23 | & (points[..., 1] >= 0) 24 | & (points[..., 1] < img_size[1])) 25 | return mask 26 | 27 | 28 | # Convert a tensor [N, 2] or batched tensor [B, N, 2] of N keypoints into 29 | # a grid in [-1, 1]² that can be used in torch.nn.functional.interpolate 30 | def keypoints_to_grid(keypoints, img_size): 31 | n_points = keypoints.size()[-2] 32 | device = keypoints.device 33 | grid_points = keypoints.float() * 2. / torch.tensor( 34 | img_size, dtype=torch.float, device=device) - 1. 35 | grid_points = grid_points[..., [1, 0]].view(-1, n_points, 1, 2) 36 | return grid_points 37 | 38 | 39 | # Return a 2D matrix indicating the local neighborhood of each point 40 | # for a given threshold and two lists of corresponding keypoints 41 | def get_dist_mask(kp0, kp1, valid_mask, dist_thresh): 42 | b_size, n_points, _ = kp0.size() 43 | dist_mask0 = torch.norm(kp0.unsqueeze(2) - kp0.unsqueeze(1), dim=-1) 44 | dist_mask1 = torch.norm(kp1.unsqueeze(2) - kp1.unsqueeze(1), dim=-1) 45 | dist_mask = torch.min(dist_mask0, dist_mask1) 46 | dist_mask = dist_mask <= dist_thresh 47 | dist_mask = dist_mask.repeat(1, 1, b_size).reshape(b_size * n_points, 48 | b_size * n_points) 49 | dist_mask = dist_mask[valid_mask, :][:, valid_mask] 50 | return dist_mask 51 | 52 | 53 | ### Line-related utils 54 | 55 | # Sample n points along lines of shape (num_lines, 2, 2) 56 | def sample_line_points(lines, n): 57 | line_points_x = np.linspace(lines[:, 0, 0], lines[:, 1, 0], n, axis=-1) 58 | line_points_y = np.linspace(lines[:, 0, 1], lines[:, 1, 1], n, axis=-1) 59 | line_points = np.stack([line_points_x, line_points_y], axis=2) 60 | return line_points 61 | 62 | 63 | # Return a mask of the valid lines that are within a valid mask of an image 64 | def mask_lines(lines, valid_mask): 65 | h, w = valid_mask.shape 66 | int_lines = np.clip(np.round(lines).astype(int), 0, [h - 1, w - 1]) 67 | h_valid = valid_mask[int_lines[:, 0, 0], int_lines[:, 0, 1]] 68 | w_valid = valid_mask[int_lines[:, 1, 0], int_lines[:, 1, 1]] 69 | valid = h_valid & w_valid 70 | return valid 71 | 72 | 73 | # Return a 2D matrix indicating for each pair of points 74 | # if they are on the same line or not 75 | def get_common_line_mask(line_indices, valid_mask): 76 | b_size, n_points = line_indices.shape 77 | common_mask = line_indices[:, :, None] == line_indices[:, None, :] 78 | common_mask = common_mask.repeat(1, 1, b_size).reshape(b_size * n_points, 79 | b_size * n_points) 80 | common_mask = common_mask[valid_mask, :][:, valid_mask] 81 | return common_mask 82 | 
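# Usage sketch (hypothetical values, not part of the original file):
# warp_points expects points in (row, col) order and a 3x3 homography acting
# on homogeneous (x, y, 1) coordinates, e.g.
#
#   H = np.array([[1., 0., 5.],
#                 [0., 1., 2.],
#                 [0., 0., 1.]])    # translate x by 5 px, y by 2 px
#   pts = np.array([[10., 20.]])    # a single point at row 10, col 20
#   warp_points(pts, H)             # -> array([[12., 25.]])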
-------------------------------------------------------------------------------- /hawp/ssl/misc/train_utils.py: --------------------------------------------------------------------------------
1 | """
2 | This file contains some useful functions for train / val.
3 | """
4 | import os
5 | import numpy as np
6 | import torch
7 | 
8 | 
9 | #################
10 | ## image utils ##
11 | #################
12 | def convert_image(input_tensor, axis):
13 |     """ Convert single-channel images to 3-channel images. """
14 |     image_lst = [input_tensor for _ in range(3)]
15 |     outputs = np.concatenate(image_lst, axis)
16 |     return outputs
17 | 
18 | 
19 | ######################
20 | ## checkpoint utils ##
21 | ######################
22 | def get_latest_checkpoint(checkpoint_root, checkpoint_name,
23 |                           device=torch.device("cuda")):
24 |     """ Get a checkpoint by filename, or the latest one if no name is given. """
25 |     # Load a specific checkpoint
26 |     if checkpoint_name is not None:
27 |         checkpoint = torch.load(
28 |             os.path.join(checkpoint_root, checkpoint_name),
29 |             map_location=device)
30 |     # Load the latest checkpoint
31 |     else:
32 |         # os.listdir() does not expand glob patterns, so filter by suffix
33 |         latest_checkpoint = sorted(
34 |             [_ for _ in os.listdir(checkpoint_root)
35 |              if _.endswith(".tar")])[-1]
36 |         checkpoint = torch.load(os.path.join(
37 |             checkpoint_root, latest_checkpoint), map_location=device)
38 |     return checkpoint
39 | 
40 | 
41 | def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
42 |     """ Remove the outdated checkpoints. """
43 |     # Get sorted list of checkpoints
44 |     checkpoint_list = sorted(
45 |         [_ for _ in os.listdir(checkpoint_root)
46 |          if _.endswith(".tar")])
47 | 
48 |     # Get the checkpoints to be removed
49 |     if len(checkpoint_list) > max_ckpt:
50 |         remove_list = checkpoint_list[:-max_ckpt]
51 |         for _ in remove_list:
52 |             full_name = os.path.join(checkpoint_root, _)
53 |             os.remove(full_name)
54 |             print("[Debug] Remove outdated checkpoint %s" % (full_name))
55 | 
56 | 
57 | ################
58 | ## HDF5 utils ##
59 | ################
60 | def parse_h5_data(h5_data):
61 |     """ Parse h5 dataset. """
62 |     output_data = {}
63 |     for key in h5_data.keys():
64 |         output_data[key] = np.array(h5_data[key])
65 | 
66 |     return output_data
67 | 
-------------------------------------------------------------------------------- /hawp/ssl/models/__init__.py: --------------------------------------------------------------------------------
1 | 
2 | # from . import detector, detect_with_heatmaps
3 | # from . import detect_with_heatmaps
4 | from .detector import HAWP
5 | from . import detector
6 | from . import detector_with_heatmap
7 | from . import detector_with_hrheat
8 | from .registry import MODELS
-------------------------------------------------------------------------------- /hawp/ssl/models/heatmap_decoder.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | # The pixel shuffle decoder
6 | class PixelShuffleDecoder(nn.Module):
7 |     def __init__(self, input_feat_dim=128, num_upsample=2, output_channel=2):
8 |         super(PixelShuffleDecoder, self).__init__()
9 |         # Get channel parameters
10 |         self.channel_conf = self.get_channel_conf(num_upsample)
11 | 
12 |         # Define the pixel shuffle
13 |         self.pixshuffle = nn.PixelShuffle(2)
14 | 
15 |         # Process the feature
16 |         self.conv_block_lst = []
17 |         # The input block
18 |         self.conv_block_lst.append(
19 |             nn.Sequential(
20 |                 nn.Conv2d(input_feat_dim, self.channel_conf[0], kernel_size=3, stride=1, padding=1),
21 |                 nn.BatchNorm2d(self.channel_conf[0]),
22 |                 nn.ReLU(inplace=True)
23 |             ))
24 | 
25 |         # Intermediate blocks
26 |         for channel in self.channel_conf[1:-1]:
27 |             self.conv_block_lst.append(
28 |                 nn.Sequential(
29 |                     nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
30 |                     nn.BatchNorm2d(channel),
31 |                     nn.ReLU(inplace=True)
32 |                 ))
33 | 
34 |         # Output block
35 |         self.conv_block_lst.append(
36 |             nn.Conv2d(self.channel_conf[-1], output_channel, kernel_size=1, stride=1, padding=0)
37 |         )
38 |         self.conv_block_lst = nn.ModuleList(self.conv_block_lst)
39 | 
40 |     # Get the number of channels based on the number of upsampling steps.
41 |     def get_channel_conf(self, num_upsample):
42 |         if num_upsample == 2:
43 |             return [256, 64, 16]
44 |         elif num_upsample == 3:
45 |             return [256, 64, 16, 4]
46 |         else:
47 |             # Fail loudly instead of silently returning None
48 |             raise ValueError("num_upsample must be 2 or 3, got %s" % num_upsample)
49 | 
50 |     def forward(self, input_features):
51 |         # Iterate until the output block
52 |         out = input_features
53 |         for block in self.conv_block_lst[:-1]:
54 |             out = block(out)
55 |             out = self.pixshuffle(out)
56 | 
57 |         # Output layer
58 |         out = self.conv_block_lst[-1](out)
59 | 
60 |         return out
61 | 
-------------------------------------------------------------------------------- /hawp/ssl/models/losses.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | def cross_entropy_loss_for_junction(logits, positive):
4 |     nlogp = -F.log_softmax(logits, dim=1)
5 | 
6 |     loss = (positive * nlogp[:, None, 1] + (1 - positive) * nlogp[:, None, 0])
7 | 
8 |     return loss.mean()
9 | 
10 | def focal_loss_for_junction(logits, positive, gamma=2.0):
11 |     prob = F.softmax(logits, 1)
12 |     ce_loss = F.cross_entropy(logits, positive, reduction='none')
13 |     p_t = prob[:,1:]*positive + prob[:,:1]*(1-positive)
14 |     loss = ce_loss * ((1-p_t)**gamma)
15 | 
16 |     return loss.mean()
17 | 
18 | def sigmoid_l1_loss(logits, targets, offset = 0.0, mask=None):
19 |     logp = torch.sigmoid(logits) + offset
20 |     loss = torch.abs(logp-targets)
21 | 
22 |     if mask is not None:
23 |         w = mask.mean(3, True).mean(2,True)
24 |         w[w==0] = 1
25 |         loss = loss*(mask/w)
26 | 
27 |     return loss.mean()
28 | 
29 | def sigmoid_focal_loss(
30 |     inputs: torch.Tensor,
31 |     targets: torch.Tensor,
32 |     alpha: float = -1,
33 |     gamma: float = 2,
34 |     reduction: str = "none",
35 | ) -> torch.Tensor:
36 |     """
37 |     Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
38 |     Args:
39 |         inputs: A float tensor of arbitrary shape.
40 |                 The predictions for each example.
41 |         targets: A float tensor with the same shape as inputs.
Stores the binary 42 | classification label for each element in inputs 43 | (0 for the negative class and 1 for the positive class). 44 | alpha: (optional) Weighting factor in range (0,1) to balance 45 | positive vs negative examples. Default = -1 (no weighting). 46 | gamma: Exponent of the modulating factor (1 - p_t) to 47 | balance easy vs hard examples. 48 | reduction: 'none' | 'mean' | 'sum' 49 | 'none': No reduction will be applied to the output. 50 | 'mean': The output will be averaged. 51 | 'sum': The output will be summed. 52 | Returns: 53 | Loss tensor with the reduction option applied. 54 | """ 55 | inputs = inputs.float() 56 | targets = targets.float() 57 | p = torch.sigmoid(inputs) 58 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 59 | p_t = p * targets + (1 - p) * (1 - targets) 60 | loss = ce_loss * ((1 - p_t) ** gamma) 61 | 62 | if alpha >= 0: 63 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 64 | loss = alpha_t * loss 65 | 66 | if reduction == "mean": 67 | loss = loss.mean() 68 | elif reduction == "sum": 69 | loss = loss.sum() 70 | 71 | return loss -------------------------------------------------------------------------------- /hawp/ssl/models/registry.py: -------------------------------------------------------------------------------- 1 | from hawp.base.utils.registry import Registry 2 | 3 | MODELS = Registry() 4 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | cython 3 | matplotlib 4 | yacs 5 | scikit-image 6 | tqdm 7 | python-json-logger 8 | h5py 9 | shapely 10 | pycolmap 11 | seaborn 12 | kornia 13 | easydict 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | from setuptools import find_packages 5 | from setuptools import setup 6 | 7 | setup( 8 | name="hawp", 9 | version="1.0", 10 | author="nxue", 11 | description="Holistically-Attracted Wireframe Parsing", 12 | packages=find_packages(), 13 | install_requires=[ 14 | "torch", 15 | "torchvision", 16 | "opencv-python", 17 | "cython", 18 | "matplotlib", 19 | "yacs", 20 | "scikit-image", 21 | "tqdm", 22 | "python-json-logger", 23 | "h5py", 24 | "shapely", 25 | "seaborn", 26 | "easydict", 27 | ], 28 | extras_require={ 29 | "dev": [ 30 | "pycolmap", 31 | ] 32 | } 33 | ) 34 | --------------------------------------------------------------------------------
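To close, a small sketch (not from the repository; the tensor shapes are arbitrary placeholders) showing how the focal loss defined in hawp/ssl/models/losses.py can be called once the package is installed, e.g. via pip install -e .:

import torch
from hawp.ssl.models.losses import sigmoid_focal_loss

logits = torch.randn(4, 1, 128, 128)                     # raw model outputs
targets = torch.randint(0, 2, (4, 1, 128, 128)).float()  # binary labels
loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean")
print(loss.item())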