├── .gitignore
├── LEGAL.md
├── LICENSE
├── configs
├── hawpv2.yaml
└── plnet.yaml
├── docker
└── Dockerfile
├── docs
├── HAWPv2.md
├── HAWPv3.md
├── HAWPv3.train.md
└── figures
│ ├── dtu-24
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   ├── 000002.png
│   │   ├── 000003.png
│   │   ├── 000004.png
│   │   ├── 000005.png
│   │   ├── 000009.png
│   │   ├── 000015.png
│   │   └── 000045.png
│   ├── plnet
│   │   ├── plnet.png
│   │   └── uma_feature.png
│   ├── v3-BSDS
│   │   ├── 37073.png
│   │   ├── 42049.png
│   │   └── 85048.png
│   ├── v3-CrowdAI
│   │   ├── 000000000190.png
│   │   ├── 000000000210.png
│   │   └── 000000000230.png
│   └── v3-wireframe
│   │   ├── 00037187.png
│   │   ├── 00051510.png
│   │   └── 00074259.png
├── downloads.sh
├── evaluation
├── EdgeEval
│ ├── CSA++
│ │ ├── GNUmakefile
│ │ ├── csa.cc
│ │ ├── csa.hh
│ │ ├── csaAssign.m
│ │ ├── csaAssign_c
│ │ ├── csa_defs.h
│ │ ├── csa_types.h
│ │ ├── sparsify.m
│ │ ├── test.cc
│ │ ├── test.m
│ │ └── test.txt
│ ├── EdgeMapEval.cc
│ ├── Util
│ │ ├── Array.hh
│ │ ├── Exception.cc
│ │ ├── Exception.hh
│ │ ├── GNUmakefile
│ │ ├── GNUmakefile-library
│ │ ├── Lab2RGB.m
│ │ ├── Matrix.cc
│ │ ├── Matrix.hh
│ │ ├── Point.hh
│ │ ├── RGB2Lab.m
│ │ ├── Random.cc
│ │ ├── Random.hh
│ │ ├── Sort.hh
│ │ ├── String.cc
│ │ ├── String.hh
│ │ ├── Timer.cc
│ │ ├── Timer.hh
│ │ ├── distSqr.m
│ │ ├── fftconv2.m
│ │ ├── gethosttype
│ │ ├── isum.c
│ │ ├── isum.m
│ │ ├── kmeansML.m
│ │ ├── kofn.cc
│ │ ├── kofn.hh
│ │ ├── logist2.m
│ │ ├── padReflect.m
│ │ ├── progbar.m
│ │ ├── test
│ │ └── test_cc
│ ├── __init__.py
│ ├── build
│ │ ├── temp.linux-x86_64-3.5
│ │ │ ├── CSA++
│ │ │ │ └── csa.o
│ │ │ ├── EdgeMapEval.o
│ │ │ ├── Util
│ │ │ │ ├── Exception.o
│ │ │ │ ├── Matrix.o
│ │ │ │ ├── Random.o
│ │ │ │ ├── String.o
│ │ │ │ ├── Timer.o
│ │ │ │ └── kofn.o
│ │ │ ├── correspondPixels.o
│ │ │ └── match.o
│ │ └── temp.linux-x86_64-3.8
│   │   │   ├── CSA++
│   │   │   │   └── csa.o
│   │   │   ├── EdgeMapEval.o
│   │   │   ├── Util
│   │   │   │   ├── Exception.o
│   │   │   │   ├── Matrix.o
│   │   │   │   ├── Random.o
│   │   │   │   ├── String.o
│   │   │   │   ├── Timer.o
│   │   │   │   └── kofn.o
│   │   │   ├── correspondPixels.o
│   │   │   └── match.o
│ ├── correspondPixels.cpp
│ ├── correspondPixels.h
│ ├── correspondPixels.pyx
│ ├── include
│ │ ├── Array.hh
│ │ ├── Exception.hh
│ │ ├── Point.hh
│ │ ├── Random.hh
│ │ ├── Sort.hh
│ │ ├── String.hh
│ │ ├── Timer.hh
│ │ ├── csa.hh
│ │ ├── csa_defs.h
│ │ ├── csa_types.h
│ │ ├── kofn.hh
│ │ └── match.hh
│ ├── match.cc
│ └── setup.py
├── Makefile
├── RasterizeLine
│ ├── __init__.py
│ ├── build
│ │ ├── temp.linux-x86_64-3.5
│ │ │ ├── draw.o
│ │ │ └── kernel.o
│ │ └── temp.linux-x86_64-3.8
│ │ │ ├── draw.o
│ │ │ └── kernel.o
│ ├── draw.cpp
│ ├── draw.hpp
│ ├── draw.pyx
│ ├── kernel.cpp
│ └── setup.py
├── __init__.py
├── compute_prec_recall.py
├── draw-hap.py
├── draw-json.py
├── draw-sap.py
├── eval-json.py
├── eval-junctions.py
├── eval-sap.py
├── evaluation.py
├── example_evaluation.py
├── example_rasterline.py
├── prmeter.py
├── runs
│ ├── draw-hap.sh
│ ├── draw-sap.sh
│ ├── draw-vis-im-york.sh
│ ├── draw-vis-im.sh
│ ├── draw-vis-york.sh
│ ├── draw-vis.sh
│ ├── eval-aph.sh
│ └── sap.sh
└── sAPEval
│ ├── __init__.py
│ └── metric.py
├── hawp
├── __init__.py
├── base
│ ├── __init__.py
│ ├── csrc
│ │ ├── __init__.py
│ │ ├── binding.cpp
│ │ ├── linesegment.cu
│ │ └── linesegment.h
│ ├── show
│ │ ├── __init__.py
│ │ ├── canvas.py
│ │ ├── cli.py
│ │ └── painters.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── c2_model_loading.py
│ │ ├── checkpoint.py
│ │ ├── comm.py
│ │ ├── imports.py
│ │ ├── logger.py
│ │ ├── metric_evaluation.py
│ │ ├── metric_logger.py
│ │ ├── miscellaneous.py
│ │ ├── model_serialization.py
│ │ ├── model_zoo.py
│ │ └── registry.py
│ └── wireframe.py
├── encoder
│ ├── __init__.py
│ └── hafm.py
├── fsl
│ ├── __init__.py
│ ├── backbones
│ │ ├── __init__.py
│ │ ├── build.py
│ │ ├── multi_task_head.py
│ │ ├── point_line.py
│ │ ├── registry.py
│ │ ├── resnets.py
│ │ ├── stacked_hg.py
│ │ └── stacked_point_line.py
│ ├── benchmark.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ ├── defaults.py
│ │ ├── detr.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── head.py
│ │ │ ├── models.py
│ │ │ ├── proposal_head.py
│ │ │ ├── resnets.py
│ │ │ └── shg.py
│ │ ├── paths_catalog.py
│ │ └── solver.py
│ ├── dataset
│ │ ├── __init__.py
│ │ ├── build.py
│ │ ├── imagelist.py
│ │ ├── iteration_based_batch_sampler.py
│ │ ├── stream.py
│ │ ├── test_dataset.py
│ │ ├── train_dataset.bck.py
│ │ ├── train_dataset.py
│ │ └── transforms.py
│ ├── model
│ │ ├── __init__.py
│ │ ├── build.py
│ │ ├── hafm.py
│ │ ├── losses.py
│ │ ├── misc.py
│ │ └── models.py
│ ├── point_model
│ │ └── point_model.pth
│ ├── predict.py
│ ├── solver.py
│ └── train.py
└── ssl
│ ├── __init__.py
│ ├── config
│   │   ├── __init__.py
│   │   ├── exports
│   │   │   ├── wireframe-100iters.yaml
│   │   │   └── wireframe-10iters.yaml
│   │   ├── hawpv3-hrheat.yaml
│   │   ├── hawpv3.yaml
│   │   ├── project_config.py
│   │   ├── synthetic_dataset-4k.yaml
│   │   ├── synthetic_dataset.yaml
│   │   ├── utils.py
│   │   └── wireframe_official_gt.yaml
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── dataset_util.py
│   │   ├── images_dataset.py
│   │   ├── synthetic_dataset.py
│   │   ├── synthetic_util.py
│   │   ├── transforms
│   │   │   ├── __init__.py
│   │   │   ├── homographic_transforms.py
│   │   │   ├── photometric_transforms.py
│   │   │   └── utils.py
│   │   ├── wireframe.py
│   │   ├── wireframe_dataset.py
│   │   └── yorkurban_dataset.py
│   ├── evaluate.py
│   ├── homoadp-bm.py
│   ├── homoadp.py
│   ├── misc
│   │   ├── __init__.py
│   │   ├── geometry_utils.old.py
│   │   ├── geometry_utils.py
│   │   ├── train_utils.py
│   │   └── visualize_util.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── detector.py
│   │   ├── detector_with_heatmap.py
│   │   ├── detector_with_hrheat.py
│   │   ├── hafm.py
│   │   ├── heatmap_decoder.py
│   │   ├── losses.py
│   │   └── registry.py
│ ├── predict.py
│ └── train.py
├── readme.md
├── requirement.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs/*
2 | **/*.DS_Store
3 | parsing/data/wireframe/images/*
4 | .vscode/*
5 | data/*
6 | data
7 |
8 | **/*.pyc
9 | **/*.so
10 | data-ssl/**/*
11 | exp-ssl/**/*
--------------------------------------------------------------------------------
/LEGAL.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/LEGAL.md
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Nan Xue
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/configs/hawpv2.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | NUM_WORKERS: 0
3 | DATASETS:
4 | DISTANCE_TH: 0.02
5 | IMAGE:
6 | HEIGHT: 512
7 | PIXEL_MEAN:
8 | - 109.73
9 | - 103.832
10 | - 98.681
11 | PIXEL_STD:
12 | - 22.275
13 | - 22.124
14 | - 23.229
15 | TO_255: true
16 | WIDTH: 512
17 | NUM_STATIC_NEGATIVE_LINES: 40
18 | NUM_STATIC_POSITIVE_LINES: 300
19 | AUGMENTATION: 5
20 | TARGET:
21 | HEIGHT: 128
22 | WIDTH: 128
23 | TEST:
24 | - wireframe_test
25 | TRAIN:
26 | - wireframe_train
27 | VAL:
28 | - wireframe_test
29 | ENCODER:
30 | ANG_TH: 0.0
31 | BACKGROUND_WEIGHT: 0.0
32 | DIS_TH: 2
33 | NUM_STATIC_NEG_LINES: 0
34 | NUM_STATIC_POS_LINES: 300
35 | MODEL:
36 | DEVICE: cuda
37 | FOCAL_LOSS:
38 | ALPHA: -1.0
39 | GAMMA: 0.0
40 | HEAD_SIZE:
41 | - - 3
42 | - - 1
43 | - - 1
44 | - - 2
45 | - - 2
46 | HGNETS:
47 | DEPTH: 4
48 | INPLANES: 64
49 | NUM_BLOCKS: 1
50 | NUM_FEATS: 128
51 | NUM_STACKS: 2
52 | LOI_POOLING:
53 | ACTIVATION: relu
54 | DIM_EDGE_FEATURE: 4
55 | DIM_FC: 1024
56 | DIM_JUNCTION_FEATURE: 128
57 | LAYER_NORM: false
58 | NUM_POINTS: 32
59 | TYPE: softmax
60 | LOSS_WEIGHTS:
61 | loss_aux: 1.0
62 | loss_dis: 1.0
63 | loss_jloc: 8.0
64 | loss_joff: 0.25
65 | loss_lineness: 1.0
66 | loss_md: 1.0
67 | loss_neg: 1.0
68 | loss_pos: 1.0
69 | loss_res: 1.0
70 | NAME: Hourglass
71 | OUT_FEATURE_CHANNELS: 256
72 | PARSING_HEAD:
73 | DIM_FC: 1024
74 | DIM_LOI: 128
75 | J2L_THRESHOLD: 10.0
76 | JMATCH_THRESHOLD: 1.5
77 | JUNCTION_HM_THRESHOLD: 0.008 #magic number
78 | MATCHING_STRATEGY: junction
79 | MAX_DISTANCE: 5.0
80 | N_DYN_JUNC: 300
81 | N_DYN_NEGL: 40
82 | N_DYN_OTHR: 0
83 | N_DYN_OTHR2: 300
84 | N_DYN_POSL: 300
85 | N_OUT_JUNC: 250
86 | N_OUT_LINE: 2500
87 | N_PTS0: 32
88 | N_PTS1: 8
89 | N_STC_NEGL: 40
90 | N_STC_POSL: 300
91 | USE_RESIDUAL: 1
92 | RESNETS:
93 | BASENET: resnet50
94 | PRETRAIN: true
95 | SCALE: 1.0
96 | WEIGHTS: ''
97 | MODELING_PATH: ihawp-v2
98 | OUTPUT_DIR: output/ihawp
99 | SOLVER:
100 | AMSGRAD: true
101 | BACKBONE_LR_FACTOR: 1.0
102 | BASE_LR: 0.0004
103 | BIAS_LR_FACTOR: 1
104 | CHECKPOINT_PERIOD: 1
105 | GAMMA: 0.1
106 | IMS_PER_BATCH: 6
107 | MAX_EPOCH: 30
108 | MOMENTUM: 0.9
109 | OPTIMIZER: ADAM
110 | STEPS:
111 | - 25
112 | WEIGHT_DECAY: 0.0001
113 | WEIGHT_DECAY_BIAS: 0
114 |
--------------------------------------------------------------------------------
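
The config above is plain nested YAML, so its values can be inspected without the repo's own config machinery (the hawp.fsl.config package presumably materializes it into a CfgNode). A minimal sketch of reading the `DATASETS.IMAGE` normalization settings with PyYAML; the `normalize` helper is illustrative, not a function from this repo:

```python
# Sketch: read configs/hawpv2.yaml and apply the normalization it
# describes (TO_255 rescaling, then PIXEL_MEAN/PIXEL_STD whitening).
import numpy as np
import yaml

with open("configs/hawpv2.yaml") as f:
    cfg = yaml.safe_load(f)

def normalize(img):
    """img: HxWx3 float array in [0, 1]."""
    image_cfg = cfg["DATASETS"]["IMAGE"]
    if image_cfg["TO_255"]:                      # config feeds the net 0-255 inputs
        img = img * 255.0
    mean = np.asarray(image_cfg["PIXEL_MEAN"])   # per-channel mean
    std = np.asarray(image_cfg["PIXEL_STD"])     # per-channel std
    return (img - mean) / std
```
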
/configs/plnet.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | NUM_WORKERS: 0
3 | DATASETS:
4 | DISTANCE_TH: 0.02
5 | IMAGE:
6 | HEIGHT: 512
7 | PIXEL_MEAN:
8 | - 109.73
9 | - 103.832
10 | - 98.681
11 | PIXEL_STD:
12 | - 22.275
13 | - 22.124
14 | - 23.229
15 | TO_255: true
16 | WIDTH: 512
17 | NUM_STATIC_NEGATIVE_LINES: 40
18 | NUM_STATIC_POSITIVE_LINES: 300
19 | AUGMENTATION: 5
20 | TARGET:
21 | HEIGHT: 128
22 | WIDTH: 128
23 | TEST:
24 | - wireframe_test
25 | TRAIN:
26 | - wireframe_train
27 | VAL:
28 | - wireframe_test
29 | ENCODER:
30 | ANG_TH: 0.0
31 | BACKGROUND_WEIGHT: 0.0
32 | DIS_TH: 2
33 | NUM_STATIC_NEG_LINES: 0
34 | NUM_STATIC_POS_LINES: 300
35 | MODEL:
36 | DEVICE: cuda
37 | FOCAL_LOSS:
38 | ALPHA: -1.0
39 | GAMMA: 0.0
40 | HEAD_SIZE:
41 | - - 3
42 | - - 1
43 | - - 1
44 | - - 2
45 | - - 2
46 | HGNETS:
47 | DEPTH: 4
48 | INPLANES: 64
49 | NUM_BLOCKS: 2
50 | NUM_FEATS: 128
51 | NUM_STACKS: 2
52 | LOI_POOLING:
53 | ACTIVATION: relu
54 | DIM_EDGE_FEATURE: 4
55 | # DIM_FC: 1024
56 | DIM_FC: 128
57 | DIM_JUNCTION_FEATURE: 128
58 | LAYER_NORM: false
59 | NUM_POINTS: 32
60 | TYPE: softmax
61 | LOSS_WEIGHTS:
62 | loss_aux: 1.0 # 2
63 | loss_dis: 1.0 # 2
64 | loss_jloc: 8.0 # 2
65 | loss_joff: 0.25 # 2
66 | loss_res: 1.0 # 2
67 | loss_md: 1.0 # 2
68 | loss_lineness: 1.0 # 1
69 | loss_neg: 1.0 # 1
70 | loss_pos: 1.0 # 1
71 | NAME: PointLine
72 | OUT_FEATURE_CHANNELS: 256
73 | PARSING_HEAD:
74 | DIM_FC: 1024
75 | DIM_LOI: 128
76 | J2L_THRESHOLD: 10.0
77 | JMATCH_THRESHOLD: 1.5
78 | JUNCTION_HM_THRESHOLD: 0.008 #magic number
79 | MATCHING_STRATEGY: junction
80 | MAX_DISTANCE: 5.0
81 | N_DYN_JUNC: 300
82 | N_DYN_NEGL: 40
83 | N_DYN_OTHR: 0
84 | N_DYN_OTHR2: 300
85 | N_DYN_POSL: 300
86 | N_OUT_JUNC: 250
87 | N_OUT_LINE: 2500
88 | N_PTS0: 32
89 | N_PTS1: 8
90 | N_STC_NEGL: 40
91 | N_STC_POSL: 300
92 | USE_RESIDUAL: 1
93 | RESNETS:
94 | BASENET: resnet50
95 | PRETRAIN: true
96 | SCALE: 1.0
97 | WEIGHTS: ''
98 | MODELING_PATH: ihawp-v2
99 | OUTPUT_DIR: output/ihawp
100 | SOLVER:
101 | AMSGRAD: true
102 | BACKBONE_LR_FACTOR: 1.0
103 | BASE_LR: 0.0004
104 | BIAS_LR_FACTOR: 1
105 | CHECKPOINT_PERIOD: 1
106 | GAMMA: 0.2
107 | IMS_PER_BATCH: 6
108 | MAX_EPOCH: 40
109 | MOMENTUM: 0.9
110 | OPTIMIZER: ADAM
111 | STEPS:
112 | - 35
113 | WEIGHT_DECAY: 0.0001
114 | WEIGHT_DECAY_BIAS: 0
115 |
--------------------------------------------------------------------------------
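
plnet.yaml is nearly identical to hawpv2.yaml; the substantive deltas are `MODEL.NAME` (PointLine vs. Hourglass), `HGNETS.NUM_BLOCKS`, `LOI_POOLING.DIM_FC`, and the solver schedule (`GAMMA`, `MAX_EPOCH`, `STEPS`). A small sketch that surfaces the differences mechanically, assuming both files are read from the repo root:

```python
# Sketch: flatten both configs to dotted keys and print every key
# whose value differs between the HAWPv2 and PLNet configs.
import yaml

def flatten(tree, prefix=""):
    for key, value in tree.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            yield from flatten(value, name + ".")
        else:
            yield name, value

hawpv2 = dict(flatten(yaml.safe_load(open("configs/hawpv2.yaml"))))
plnet = dict(flatten(yaml.safe_load(open("configs/plnet.yaml"))))
for key in sorted(hawpv2.keys() | plnet.keys()):
    if hawpv2.get(key) != plnet.get(key):
        print(f"{key}: {hawpv2.get(key)!r} -> {plnet.get(key)!r}")
```
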
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.6.0-devel-ubuntu20.04
2 |
3 | # Prevent the Ubuntu build from stopping at the time-zone selection prompt.
4 | ENV DEBIAN_FRONTEND=noninteractive
5 |
6 | RUN apt-get update && apt-get install -y \
7 | git \
8 | build-essential \
9 | vim \
10 | openssh-server \
11 | python3-opencv \
12 | ca-certificates \
13 | python3-dev \
14 | python3-pip \
15 | wget \
16 | ninja-build \
17 | mesa-common-dev \
18 | libgl1-mesa-dev \
19 | libglu1-mesa-dev \
20 | xauth
21 |
22 | RUN sed -i "s/^.*X11Forwarding.*$/X11Forwarding yes/" /etc/ssh/sshd_config \
23 | && sed -i "s/^.*X11UseLocalhost.*$/X11UseLocalhost no/" /etc/ssh/sshd_config \
24 | && grep "^X11UseLocalhost" /etc/ssh/sshd_config || echo "X11UseLocalhost no" >> /etc/ssh/sshd_config
25 |
26 | RUN ln -sv /usr/bin/python3 /usr/bin/python
27 |
28 | RUN pip3 install numpy==1.23.4
29 | RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
30 |
31 | RUN echo 'root:root' | chpasswd
32 |
33 | RUN echo 'PermitRootLogin yes' >> /etc/ssh/sshd_config
34 |
35 | RUN service ssh start
36 |
37 | CMD ["/usr/sbin/sshd","-D"]
38 |
--------------------------------------------------------------------------------
/docs/HAWPv2.md:
--------------------------------------------------------------------------------
1 | # HAWPv2: Learning Wireframes via Fully-Supervised Learning
2 |
3 | *The codes of HAWPv2 are placed in the directory of [hawp/fsl](../hawp/fsl).*
4 | ## Quickstart & Evaluation
5 | - Please download the dataset and checkpoints as in [readme.md](../readme.md).
6 | - Run the following command lines to evaluate the official model on the Wireframe and YorkUrban datasets.
7 |
8 |
9 | Evaluation on the Wireframe dataset.
10 |
11 | ```bash
12 | python -m hawp.fsl.benchmark configs/hawpv2.yaml \
13 | --ckpt checkpoints/hawpv2-edb9b23f.pth \
14 | --dataset wireframe
15 | ```
16 |
17 |
18 |
19 | Evaluation on the YorkUrban dataset.
20 |
21 | ```bash
22 | python -m hawp.fsl.benchmark configs/hawpv2.yaml \
23 | --ckpt checkpoints/hawpv2-edb9b23f.pth \
24 | --dataset york
25 | ```
26 |
27 |
28 | ## Evaluation Results
29 |
30 | |Dataset|sAP-5|sAP-10|sAP-15|command line|comment|
31 | |--|--|--|--|--|--|
32 | |Wireframe| 65.8 | 69.8 |71.4|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset wireframe --jhm=0.001``|jhm = 0.001|
33 | |Wireframe| 65.7 | 69.8 |71.4|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset wireframe --jhm=0.005``|jhm = 0.005|
34 | |Wireframe| 65.7 | 69.7 |71.3|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset wireframe --jhm=0.008``|jhm = 0.008 (default setting)|
35 | |YorkUrban|29.0|31.4|32.8|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset york --jhm=0.001``|jhm = 0.001|
36 | |YorkUrban|28.9|31.4|32.7|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset york --jhm=0.005``|jhm = 0.005|
37 | |YorkUrban|28.8|31.3|32.6|``python -m hawp.fsl.benchmark configs/hawpv2.yaml --ckpt checkpoints/hawpv2-edb9b23f.pth --dataset york --jhm=0.008``|jhm = 0.008 (default setting)|
38 |
39 | ## Training
40 | - Run the following command line to train the HAWPv2 on the Wireframe dataset.
41 | ```
42 | python -m hawp.fsl.train configs/hawpv2.yaml --logdir outputs
43 | ```
44 |
45 | - The usage of [hawp.fsl.train](../hawp/fsl/train.py) is as follows:
46 | ```
47 | HAWPv2 Training
48 |
49 | positional arguments:
50 | config path to config file
51 |
52 | optional arguments:
53 | -h, --help show this help message and exit
54 | --logdir LOGDIR
55 | --resume RESUME
56 | --clean
57 | --seed SEED
58 | --tf32 toggle on the TF32 of pytorch
59 | --dtm {True,False} toggle the deterministic option of CUDNN. This option will affect the replication of experiments
60 |
61 | ```
--------------------------------------------------------------------------------
/docs/HAWPv3.md:
--------------------------------------------------------------------------------
1 | # HAWPv3: Learning Wireframes via Self-Supervised Learning
2 |
3 | *The codes of HAWPv3 are placed in the directory of [hawp/ssl](../hawp/ssl).*
4 |
5 | |Model Name|Comments|MD5|
6 | |---|---|---|
7 | |[hawpv3-fdc5487a.pth](https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-fdc5487a.pth)| Trained on the images of the Wireframe dataset | fdc5487a43e3d42f6b2addf79d8b930d |
8 | |[hawpv3-imagenet-03a84.pth](https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-imagenet-03a84.pth)| Trained on 100k images of the ImageNet dataset | 03a8400e9474320f2b42973d1ba19487 |
9 |
10 | ### Inference on your own images
11 |
12 | - Run one of the following command lines to obtain wireframes from a HAWPv3 model.
13 |
14 | Using `hawpv3-fdc5487a.pth`:
15 | ```bash
16 | python -m hawp.ssl.predict --ckpt checkpoints/hawpv3-fdc5487a.pth \
17 |     --threshold 0.05 \
18 |     --img {filename.png}
19 | ```
20 | 
21 | Using `hawpv3-imagenet-03a84.pth`:
22 | ```bash
23 | python -m hawp.ssl.predict --ckpt checkpoints/hawpv3-imagenet-03a84.pth \
24 |     --threshold 0.05 \
25 |     --img {filename.png}
26 | ```
27 | 
28 |
29 | - A running example on the DTU-24 images
30 | ```bash
31 | python -m hawp.ssl.predict --ckpt checkpoints/hawpv3-imagenet-03a84.pth \
32 | --threshold 0.05 \
33 | --img ~/datasets/DTU/scan24/image/*.png \
34 |     --saveto docs/figures/dtu-24 --ext png
35 | ```
36 | 
37 | *(The rendered predictions for these images are collected in [docs/figures/dtu-24](figures/dtu-24); further examples on the Wireframe, BSDS, and CrowdAI images live under [docs/figures](figures).)*
38 | 
49 | ## Training
50 |
51 | ```bash
52 |
53 | python -m hawp.ssl.train --help
54 | usage: train.py [-h] --datacfg DATACFG --modelcfg MODELCFG --name NAME
55 | [--pretrained PRETRAINED] [--overwrite] [--tf32]
56 | [--dtm {True,False}] [--batch-size BATCH_SIZE]
57 | [--num-workers NUM_WORKERS] [--base-lr BASE_LR]
58 | [--steps STEPS [STEPS ...]] [--gamma GAMMA] [--epochs EPOCHS]
59 | [--seed SEED] [--iterations ITERATIONS]
60 |
61 | optional arguments:
62 | -h, --help show this help message and exit
63 | --datacfg DATACFG filepath of the data config
64 | --modelcfg MODELCFG filepath of the model config
65 | --name NAME the name of experiment
66 | --pretrained PRETRAINED
67 | the pretrained model
68 | --overwrite [Caution!] the option to overwrite an existed
69 | experiment
70 | --tf32 toggle on the TF32 of pytorch
71 | --dtm {True,False} toggle the deterministic option of CUDNN. This option
72 | will affect the replication of experiments
73 |
74 | training recipe:
75 | --batch-size BATCH_SIZE
76 | the batch size of training
77 | --num-workers NUM_WORKERS
78 | the number of workers for training
79 | --base-lr BASE_LR the initial learning rate
80 | --steps STEPS [STEPS ...]
81 | the steps of the scheduler
82 | --gamma GAMMA the lr decay factor
83 | --epochs EPOCHS the number of epochs for training
84 | --seed SEED the random seed for training
85 | --iterations ITERATIONS
86 | the number of training iterations
87 |
88 | ```
--------------------------------------------------------------------------------
/docs/HAWPv3.train.md:
--------------------------------------------------------------------------------
1 | # Training Recipes of HAWPv3
2 |
3 | *HAWPv3 is trained in multiple phases: a synthetic training phase followed by training on real data.*
4 |
5 | ## Step 0: Synthetic learning
6 |
7 | ```
8 | python -m hawp.ssl.train \
9 | --datacfg hawp/ssl/config/synthetic_dataset.yaml \
10 | --modelcfg hawp/ssl/config/hawpv3.yaml \
11 | --base-lr 0.0004 \
12 | --epochs 10 \
13 | --batch-size 6 \
14 | --name hawpv3-round0
15 | ```
16 |
17 | ## Step 1: Homographic Adaptation for Pseudo Wireframe Generation
18 |
19 | If you prefer the single-image mode to minimize the GPU memory footprint, please use the following command to obtain pseudo labels:
20 | ```
21 | python -m hawp.ssl.homoadp --metarch HAWP-heatmap \
22 |     --datacfg hawp/ssl/config/exports/wireframe-10iters.yaml \
23 | --workdir exp-ssl/hawpv3-round0 \
24 | --epoch 10 \
25 | --modelcfg exp-ssl/hawpv3-round0/model.yaml \
26 | --min_score 0.75
27 |
28 | ```
29 |
30 | For the batch mode, please use the following command
31 | ```
32 | python -m hawp.ssl.homoadp-bm --metarch HAWP-heatmap \
33 | --datacfg hawp/ssl/config/exports/wireframe-10iters.yaml \
34 | --workdir exp-ssl/hawpv3-round0 \
35 | --epoch 10 \
36 | --modelcfg exp-ssl/hawpv3-round0/model.yaml \
37 | --min-score 0.75 --batch-size=16
38 | ```
39 | *On my machine (an NVIDIA A6000), a batch size of 16 uses about 40 GB of GPU memory and takes about 43 minutes to generate wireframe labels for the 20k training images of the Wireframe dataset (Huang et al., CVPR 2018).*
40 |
41 | ### Remarks
42 | 1. After the homographic adaptation step has finished, three files are generated and stored in the directory ```data-ssl/{name}```, where ```{name}``` is the name used for the last round of training. For example, if we train HAWPv3 with the name ```hawpv3-round0```, the exported data will be saved at ```data-ssl/hawpv3-round0```.
43 |
44 | 2. For each generated wireframe and its auxiliary files, the names start with ```{hash}-model-{epoch:05d}```.
45 |
46 | 3. In summary, the generated datacfg ``YAML`` file is located at ``data-ssl/{name}/{hash}-model-{epoch:05d}.yaml``.
47 | ## Step 2: Learning from Real-World images
48 |
49 | - Once we have the pseudo wireframe labels, we can train HAWPv3 on real-world images, for example with the following command line:
50 | ```
51 | python -m hawp.ssl.train --datacfg data-ssl/export_datasets/{name}/{hash}-model-00010.yaml --modelcfg hawp/ssl/config/hawpv3.yaml --base-lr 0.0004 --epochs 30 --name hawpv3-round1 --batch-size 6
52 | ```
53 |
54 | - Then we can run the homographic adaptation again with the new model checkpoints trained on real-world images, and train or fine-tune a new model to further improve the repeatability.
55 |
--------------------------------------------------------------------------------
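
Remark 3 above pins down where the exported datacfg lands, but the `{hash}` prefix is only known at export time. A hedged sketch for locating the file, reusing the `hawpv3-round0` name from Step 0:

```python
# Sketch: find the datacfg YAML exported by homographic adaptation.
# The {hash} prefix is assigned at export time, so glob for it.
from pathlib import Path

name, epoch = "hawpv3-round0", 10
candidates = sorted(Path(f"data-ssl/{name}").glob(f"*-model-{epoch:05d}.yaml"))
if not candidates:
    raise FileNotFoundError(f"no exported datacfg under data-ssl/{name}")
print("pass this to --datacfg:", candidates[-1])
```
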
/docs/figures/dtu-24/000000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000000.png
--------------------------------------------------------------------------------
/docs/figures/dtu-24/000001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000001.png
--------------------------------------------------------------------------------
/docs/figures/dtu-24/000002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000002.png
--------------------------------------------------------------------------------
/docs/figures/dtu-24/000003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000003.png
--------------------------------------------------------------------------------
/docs/figures/dtu-24/000004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000004.png
--------------------------------------------------------------------------------
/docs/figures/dtu-24/000005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000005.png
--------------------------------------------------------------------------------
/docs/figures/dtu-24/000009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000009.png
--------------------------------------------------------------------------------
/docs/figures/dtu-24/000015.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000015.png
--------------------------------------------------------------------------------
/docs/figures/dtu-24/000045.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/dtu-24/000045.png
--------------------------------------------------------------------------------
/docs/figures/plnet/plnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/plnet/plnet.png
--------------------------------------------------------------------------------
/docs/figures/plnet/uma_feature.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/plnet/uma_feature.png
--------------------------------------------------------------------------------
/docs/figures/v3-BSDS/37073.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-BSDS/37073.png
--------------------------------------------------------------------------------
/docs/figures/v3-BSDS/42049.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-BSDS/42049.png
--------------------------------------------------------------------------------
/docs/figures/v3-BSDS/85048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-BSDS/85048.png
--------------------------------------------------------------------------------
/docs/figures/v3-CrowdAI/000000000190.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-CrowdAI/000000000190.png
--------------------------------------------------------------------------------
/docs/figures/v3-CrowdAI/000000000210.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-CrowdAI/000000000210.png
--------------------------------------------------------------------------------
/docs/figures/v3-CrowdAI/000000000230.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-CrowdAI/000000000230.png
--------------------------------------------------------------------------------
/docs/figures/v3-wireframe/00037187.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-wireframe/00037187.png
--------------------------------------------------------------------------------
/docs/figures/v3-wireframe/00051510.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-wireframe/00051510.png
--------------------------------------------------------------------------------
/docs/figures/v3-wireframe/00074259.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/docs/figures/v3-wireframe/00074259.png
--------------------------------------------------------------------------------
/downloads.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | wget https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv2/hawpv2-edb9b23f.pth -P checkpoints
4 | wget https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-ce3ae2cb.pth -P checkpoints
5 | wget https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-fdc5487a.pth -P checkpoints
6 | wget https://github.com/cherubicXN/hawp-torchhub/releases/download/HAWPv3/hawpv3-imagenet-03a84.pth -P checkpoints
7 |
--------------------------------------------------------------------------------
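
docs/HAWPv3.md publishes MD5 digests for two of these checkpoints; a small verification sketch (the digests below are copied from that table, and the two checkpoints without documented digests are skipped):

```python
# Sketch: verify downloaded checkpoints against the MD5 digests
# listed in docs/HAWPv3.md.
import hashlib
from pathlib import Path

EXPECTED_MD5 = {
    "hawpv3-fdc5487a.pth": "fdc5487a43e3d42f6b2addf79d8b930d",
    "hawpv3-imagenet-03a84.pth": "03a8400e9474320f2b42973d1ba19487",
}

for name, digest in EXPECTED_MD5.items():
    path = Path("checkpoints") / name
    got = hashlib.md5(path.read_bytes()).hexdigest()
    print(f"{name}: {'OK' if got == digest else f'MISMATCH (got {got})'}")
```
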
/evaluation/EdgeEval/CSA++/GNUmakefile:
--------------------------------------------------------------------------------
1 | # use gmake!
2 |
3 | srcs := csa.cc
4 | hdrs := csa.hh csa_types.h csa_defs.h
5 | matlab := csaAssign.m sparsify.m
6 | mex := csaAssign.cc
7 | lib := libcsa.a
8 |
9 | cxxFlags := -O3
10 | mexFlags :=
11 |
12 | include ../Util/GNUmakefile-library
13 |
14 | runtest:
15 | $(cxx) $(cxxFlags) -o test $(srcs) test.cc
16 | ./test
17 |
18 | clean::
19 | -rm -f test
20 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/CSA++/csa.cc:
--------------------------------------------------------------------------------
1 |
2 | #include "csa.hh"
3 |
4 | char* CSA::err_messages[] =
5 | {
6 | "Can't read from the input file.",
7 | "Not a correct assignment problem line.",
8 | "Error reading a node descriptor from the input.",
9 | "Error reading an arc descriptor from the input.",
10 | "Unknown line type in the input",
11 | "Inconsistent number of arcs in the input.",
12 | "Parsing noncontiguous node ID numbers not implemented.",
13 | "Can't obtain enough memory to solve this problem.",
14 | };
15 |
16 | char* CSA::nomem_msg = "Insufficient memory.\n";
17 |
18 | CSA::CSA (int n, int m, const int* graph)
19 | {
20 | assert(n>0);
21 | assert(m>0);
22 | assert(graph!=NULL);
23 | assert((n%2)==0);
24 | _init(n,m);
25 | main(graph);
26 | }
27 |
28 | CSA::~CSA ()
29 | {
30 | _delete();
31 | }
32 |
33 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/CSA++/csaAssign.m:
--------------------------------------------------------------------------------
1 | % [edges] = csaAssign(n,graph)
2 | %
3 | % Compute min-cost assignment with non-negative integral edge weights
4 | % using Andrew Goldberg's CSA package (precise costs version).
5 | %
6 | % INPUT
7 | % n Number of nodes in the bipartite graph (must be even).
8 | % graph 3xm matrix describing graph.
9 | %
10 | % OUTPUT
11 | % edges 3xn matrix of edges in assignment.
12 | %
13 | % You must ensure that an assignment involving all nodes exists, else
14 | % the code may hang. This is a feature of the CSA package. If your
15 | % problem does not necessarily provide such an assignment, then you
16 | % should overlay a high-cost perfect match as a safety net.
17 | %
18 | % Both graph and edges matrices have the same structure. Each column
19 | % gives a graph edge e. The two nodes are given by e(1) and e(2):
20 | %
21 | % e(1) < e(2)
22 | % 1 <= e(1) <= n/2
23 | % n/2 < e(2) <= n
24 | %
25 | % The edge weight is given by e(3).
26 | %
27 | % Since the output edge matrix should contain one reference to each
28 | % node, sum(sum(edges(1:2,:))) == n*(n+1)/2.
29 | %
30 | % David Martin
31 | % January, 2003
32 |
--------------------------------------------------------------------------------
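
The contract above — a perfect matching must exist, otherwise overlay a high-cost safety net — can be mimicked in Python with scipy's `linear_sum_assignment`. This is an analogue for experimentation, not the CSA solver itself; `assign` and `outlier_cost` are names invented for this sketch:

```python
# Sketch: the csaAssign contract via scipy instead of CSA. A dense
# cost matrix stands in for the 3xm sparse graph; missing edges get a
# high "outlier" cost, which is exactly the safety-net overlay the
# comment above recommends.
import numpy as np
from scipy.optimize import linear_sum_assignment

def assign(edges, n, outlier_cost=10**6):
    """edges: (m, 3) array of (tail, head, weight), 1-based as in csaAssign."""
    half = n // 2
    cost = np.full((half, half), outlier_cost, dtype=np.int64)
    for a, b, w in edges:
        cost[a - 1, b - half - 1] = w      # heads live in (n/2, n]
    rows, cols = linear_sum_assignment(cost)
    return np.stack([rows + 1, cols + half + 1, cost[rows, cols]])

# Toy graph from CSA++/test.cc; the optimal assignment has total cost 7.
g = np.array([[1, 4, 3], [2, 5, 3], [3, 6, 3], [1, 5, 1], [2, 6, 1], [3, 4, 5]])
print(assign(g, 6))
```
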
/evaluation/EdgeEval/CSA++/csaAssign_c:
--------------------------------------------------------------------------------
1 |
2 | #include <mex.h>
3 | #include <algorithm>
4 | #include "csa.hh"
5 |
6 | extern "C" {
7 |
8 | void
9 | mexFunction (
10 | int nlhs, mxArray* plhs[],
11 | int nrhs, const mxArray* prhs[])
12 | {
13 | // Check argument counts.
14 | if (nlhs < 1) {
15 | mexErrMsgTxt("Too few output arguments.");
16 | }
17 | if (nlhs > 1) {
18 | mexErrMsgTxt("Too many output arguments.");
19 | }
20 | if (nrhs < 2) {
21 | mexErrMsgTxt("Too few input arguments.");
22 | }
23 | if (nrhs > 2) {
24 | mexErrMsgTxt("Too many input arguments.");
25 | }
26 |
27 | // Get input arguments.
28 | const int n = (int) mxGetScalar (prhs[0]);
29 | const double* g = mxGetPr (prhs[1]);
30 | const int three = mxGetM (prhs[1]);
31 | const int m = mxGetN (prhs[1]);
32 |
33 | // Check input arguments.
34 | if (n < 1) {
35 | mexErrMsgTxt("n must be >0");
36 | }
37 | if ((n%2) != 0) {
38 | mexErrMsgTxt("n must be even");
39 | }
40 | if (m < 1) {
41 | mexErrMsgTxt("m must be >0");
42 | }
43 | if (three != 3) {
44 | mexErrMsgTxt("graph matrix must be 3xM");
45 | }
46 |
47 | // Build the input graph and check the data.
48 | int* graph = new int[m*3];
49 | int maxc = 0;
50 | for (int i = 0; i < m; i++) {
51 | int a = (int) g[3*i+0];
52 | int b = (int) g[3*i+1];
53 | int c = (int) g[3*i+2];
54 | graph[3*i+0] = a;
55 | graph[3*i+1] = b;
56 | graph[3*i+2] = c;
57 | if (a < 1 || a > n/2) {
58 | mexErrMsgTxt("edge tail not in [1,n/2]");
59 | }
60 | if (b <= n/2 || b > n) {
61 | mexErrMsgTxt("edge head not in (n/2,n]");
62 | }
63 | if (c < 0) {
64 | mexErrMsgTxt("edge weights must be non-negative");
65 | }
66 | maxc = std::max(c,maxc);
67 | }
68 |
69 | // The CSA package segfaults if all the edge weights are zero.
70 | // In this case, set all the weights to one, and then later
71 | // remember to set the returned graph weights back to zero.
72 | if (maxc == 0) {
73 | for (int i = 0; i < m; i++) {
74 | graph[3*i+2] = 1;
75 | }
76 | }
77 |
78 | // Run CSA. It will either run successfully or segfault or loop
79 | // forever or return garbage. But it claims to always return a
80 | // valid result if the input is valid. The checks above try to
81 | // ensure that the input is ok, but I don't check that a perfect
82 | // match is present (which CSA requires but does not check for,
83 | // grumble, grumble). In that case, you're on your own, since
84 | // I'm not sure how to quickly check for that condition.
85 | CSA csa (n, m, graph);
86 | int e = csa.edges();
87 |
88 | // Done with input graph.
89 | delete [] graph;
90 | graph = NULL;
91 |
92 | // Construct result.
93 | plhs[0] = mxCreateDoubleMatrix(3, e, mxREAL);
94 | double* points = mxGetPr(plhs[0]);
95 | for (int i = 0; i < e; i++) {
96 | int a, b, cost;
97 | csa.edge(i,a,b,cost);
98 | points[i*3+0] = a;
99 | points[i*3+1] = b;
100 | points[i*3+2] = (maxc==0) ? 0 : cost;
101 | }
102 | }
103 |
104 | }; // extern "C"
105 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/CSA++/sparsify.m:
--------------------------------------------------------------------------------
1 | function graph = sparsify(sm)
2 | % function graph = sparsify(sm)
3 | %
4 | % Given a similarity matrix sm, return a sparse graph representation
5 | % suitable for csaAssign. For example, if sm is an nxn similarity
6 | % matrix, then Hungarian-style matching can be accomplished by
7 | % executing csaAssign(2*n,sparsify(sm)).
8 | %
9 | % See also csaAssign.
10 | %
11 | % David Martin
12 | % March, 2003
13 |
14 | if ndims(sm)~=2 | size(sm,1)~=size(sm,2),
15 | error('sm must be a square matrix');
16 | end
17 |
18 | n = size(sm,1);
19 | graph = zeros(3,n*n);
20 | graph(1,:) = repmat(1:n,1,n);
21 | graph(2,:) = n + reshape(repmat(1:n,n,1),1,n*n);
22 | graph(3,:) = reshape(sm,1,n*n);
23 |
24 |
25 |
--------------------------------------------------------------------------------
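
A direct numpy transcription of `sparsify`, keeping the 1-based node convention of `csaAssign` and MATLAB's column-major flattening; a sketch rather than part of the evaluation pipeline:

```python
# Sketch: numpy port of sparsify.m. Given an n x n similarity matrix,
# emit the 3 x n^2 graph (tail, head, weight) that csaAssign expects,
# with heads offset by n.
import numpy as np

def sparsify(sm):
    assert sm.ndim == 2 and sm.shape[0] == sm.shape[1], "sm must be a square matrix"
    n = sm.shape[0]
    tails = np.tile(np.arange(1, n + 1), n)        # 1..n repeated n times
    heads = n + np.repeat(np.arange(1, n + 1), n)  # n+1..2n in blocks of n
    weights = sm.flatten(order="F")                # column-major, as MATLAB reshape
    return np.stack([tails, heads, weights])
```
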
/evaluation/EdgeEval/CSA++/test.cc:
--------------------------------------------------------------------------------
1 |
2 | #include "csa.hh"
3 |
4 | int n = 6;
5 | int m = 6;
6 | static int data[] = {
7 | 1, 4, 3,
8 | 2, 5, 3,
9 | 3, 6, 3,
10 | 1, 5, 1,
11 | 2, 6, 1,
12 | 3, 4, 5,
13 | };
14 |
15 | int
16 | main(int argc, char** argv)
17 | {
18 | for (int iter = 0; iter < 10; iter++) {
19 | CSA csa (n, m, data);
20 | for (int i = 0; i < csa.edges(); i++) {
21 | int a, b, cost;
22 | csa.edge(i,a,b,cost);
23 | fprintf (stderr, "%d %d %d\n", a, b, cost);
24 | }
25 | fprintf (stderr, "TOTAL %d\n", csa.cost());
26 | }
27 | return 0;
28 | }
29 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/CSA++/test.m:
--------------------------------------------------------------------------------
1 | % toy test
2 | g1 = [ 1 2 3 1 2 3 ;
3 | 4 5 6 5 6 4 ;
4 | 3 3 3 1 1 5 ];
5 | n = 6;
6 | e1 = csaAssign(n,g1)
7 |
8 | % big random test
9 | n = 1000;
10 | dg = (rand(n,n) > 0.5);
11 | m = sum(dg(:));
12 | i = find(dg==1)' - 1;
13 | g2 = [ 1 + floor(i/n) ;
14 | 1 + mod(i,n) + n ;
15 | 1 + floor(rand(1,m)*1000) ];
16 | tic;
17 | e2 = csaAssign(2*n,g2);
18 | toc;
19 | if sum(e2(1,:)) ~= n*(n+1)/2, error('bug'); end
20 | if sum(e2(2,:)) ~= n*(n+1)/2 + n*n, error('bug'); end
21 | if sum(sum(e2(1:2,:))) ~= 2*n*(2*n+1)/2, error('bug'); end
22 | disp('[n m cost] = ');
23 | [n m sum(e2(3,:))]
24 |
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/CSA++/test.txt:
--------------------------------------------------------------------------------
1 | p asn 6 6
2 | n 1
3 | n 2
4 | n 3
5 | a 1 4 3
6 | a 2 5 3
7 | a 3 6 3
8 | a 1 5 1
9 | a 2 6 1
10 | a 3 4 5
11 |
12 |
13 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/EdgeMapEval.cc:
--------------------------------------------------------------------------------
1 | #include "Matrix.hh"
2 | #include "match.hh"
3 | #include <string.h>
4 | #include <math.h>
5 | #include <stdio.h>
6 | #include "correspondPixels.h"
7 |
8 |
9 | void correspondPixels(double* bmap1,
10 | double* bmap2,
11 | double* match1,
12 | double* match2,
13 | double& cost,
14 | double maxDist,
15 | double outlierCost,
16 | int height,
17 | int width)
18 | {
19 | const int rows = width;
20 | const int cols = height;
21 |
22 | const double idiag = sqrt(float(rows*rows+cols*cols));
23 | const double oc = outlierCost*maxDist*idiag;
24 |
25 | Matrix m1,m2;
26 | cost = matchEdgeMaps(
27 | Matrix(rows,cols,bmap1), Matrix(rows,cols,bmap2),
28 | maxDist*idiag, oc, m1, m2);
29 |
30 | memcpy(match1, m1.data(), m1.numel()*sizeof(double));
31 | memcpy(match2, m2.data(), m2.numel()*sizeof(double));
32 | }
--------------------------------------------------------------------------------
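
`correspondPixels` fills `match1`/`match2` with nonzero entries wherever a pixel of one map was matched to the other within `maxDist`. A hedged sketch of turning such match maps into precision/recall in the spirit of the BSDS benchmark; `prec_recall` is illustrative, and the repo's evaluation/compute_prec_recall.py is the authoritative pipeline:

```python
# Sketch: precision/recall from correspondPixels-style outputs.
# match1/match2 are arrays the size of the two edge maps, nonzero
# where a pixel found a partner within maxDist.
import numpy as np

def prec_recall(detection, ground_truth, match1, match2):
    """detection, ground_truth: binary edge maps; match1/match2: their match maps."""
    sum_r = ground_truth.sum()        # total ground-truth edge pixels
    count_r = (match2 > 0).sum()      # ground-truth pixels that found a detection
    sum_p = detection.sum()           # total detected edge pixels
    count_p = (match1 > 0).sum()      # detections that found a ground-truth pixel
    recall = count_r / max(sum_r, 1)
    precision = count_p / max(sum_p, 1)
    return precision, recall
```
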
/evaluation/EdgeEval/Util/Exception.cc:
--------------------------------------------------------------------------------
1 |
2 | // Copyright (C) 2002 David R. Martin
3 | //
4 | // This program is free software; you can redistribute it and/or
5 | // modify it under the terms of the GNU General Public License as
6 | // published by the Free Software Foundation; either version 2 of the
7 | // License, or (at your option) any later version.
8 | //
9 | // This program is distributed in the hope that it will be useful, but
10 | // WITHOUT ANY WARRANTY; without even the implied warranty of
11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 | // General Public License for more details.
13 | //
14 | // You should have received a copy of the GNU General Public License
15 | // along with this program; if not, write to the Free Software
16 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
17 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
18 |
19 | #include <string.h>
20 | #include <stdlib.h>
21 | #include "Exception.hh"
22 |
23 | Exception::Exception (const char* msg)
24 | : _msg (strdup (msg))
25 | {
26 | }
27 |
28 | Exception::Exception (const Exception& that)
29 | : _msg (strdup (that._msg))
30 | {
31 | }
32 |
33 | Exception::~Exception ()
34 | {
35 | free (_msg);
36 | }
37 |
38 | const char*
39 | Exception::msg () const
40 | {
41 | return _msg;
42 | }
43 |
44 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/Exception.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __Exception_hh__
3 | #define __Exception_hh__
4 |
5 | // A simple exception class that contains an error message.
6 |
7 | // Copyright (C) 2002 David R. Martin
8 | //
9 | // This program is free software; you can redistribute it and/or
10 | // modify it under the terms of the GNU General Public License as
11 | // published by the Free Software Foundation; either version 2 of the
12 | // License, or (at your option) any later version.
13 | //
14 | // This program is distributed in the hope that it will be useful, but
15 | // WITHOUT ANY WARRANTY; without even the implied warranty of
16 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 | // General Public License for more details.
18 | //
19 | // You should have received a copy of the GNU General Public License
20 | // along with this program; if not, write to the Free Software
21 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
23 |
24 | #include <iostream>
25 |
26 | class Exception
27 | {
28 | public:
29 |
30 | // Always construct exception with a message, so we can print
31 | // a useful error/log message.
32 | Exception (const char* msg);
33 |
34 | // We need to implement the copy constructor so that rethrowing
35 | // works.
36 | Exception (const Exception& that);
37 |
38 | virtual ~Exception ();
39 |
40 | // Retrieve the message that this exception carries.
41 | virtual const char* msg () const;
42 |
43 | protected:
44 |
45 | char* _msg;
46 |
47 | };
48 |
49 | // write to output stream
50 | inline std::ostream& operator<< (std::ostream& out, const Exception& e) {
51 | out << e.msg();
52 | return out;
53 | }
54 |
55 | #endif // __Exception_hh__
56 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/GNUmakefile:
--------------------------------------------------------------------------------
1 | # use gmake!
2 |
3 | srcs := Exception.cc String.cc Random.cc Timer.cc Matrix.cc kofn.cc
4 | hdrs := Exception.hh String.hh Random.hh Timer.hh Matrix.hh \
5 | Sort.hh Point.hh Array.hh kofn.hh
6 | matlab := isum.m kmeansML.m distSqr.m fftconv2.m padReflect.m \
7 | Lab2RGB.m RGB2Lab.m logist2.m progbar.m
8 | mex := isum.c
9 | mexLibs := -lutil
10 |
11 | lib := libutil.a
12 | cxxFlags := -O3 -DNOBLAS
13 |
14 | include ./GNUmakefile-library
15 |
16 | # test of Matrix module
17 | test:
18 | # g++ -g -Wall -o test test.cc build/ix86_linux/libutil.a -lgsl -lgslcblas -lm
19 | # g++ -g -Wall -o test test.cc build/ix86_linux/libutil.a -L/usr/mill/lib -lblas -lpgf90 -lpgf90_rpm1 -lpgf902 -lpgf90rtl -lpgftnrtl -lpgc -lm
20 | # g++ -g -Wall -o test test.cc build/ix86_linux/libutil.a -lblas -lg2c -lm
21 | # g++ -g -Wall -o test test.cc -L./build/ix86_linux -L/home/cs/dmartin/lib/$(hostType) -lutil -lf77blas -latlas -lg2c -lm
22 | g++ -g -Wall -o test test.cc build/ix86_linux/libutil.a -lblas -lm
23 |
24 | # eof
25 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/Lab2RGB.m:
--------------------------------------------------------------------------------
1 | function [R, G, B] = Lab2RGB(L, a, b)
2 | % function [R, G, B] = Lab2RGB(L, a, b)
3 | % Lab2RGB takes matrices corresponding to L, a, and b in CIELab space
4 | % and transforms them into RGB. This transform is based on ITU-R
5 | % Recommendation BT.709 using the D65 white point reference.
6 | % The error in transforming RGB -> Lab -> RGB is approximately
7 | % 10^-5. By Mark Ruzon from C code by Yossi Rubner, 23 September 1997.
8 | % Updated for MATLAB 5 28 January 1998.
9 | % Fixed a bug in conversion back to uint8 9 September 1999.
10 |
11 | if (nargin == 1)
12 | b = L(:,:,3);
13 | a = L(:,:,2);
14 | L = L(:,:,1);
15 | end
16 |
17 | % Thresholds
18 | T1 = 0.008856;
19 | T2 = 0.206893;
20 |
21 | [M, N] = size(L);
22 | s = M * N;
23 | L = reshape(L, 1, s);
24 | a = reshape(a, 1, s);
25 | b = reshape(b, 1, s);
26 |
27 | % Compute Y
28 | fY = ((L + 16) / 116) .^ 3;
29 | YT = fY > T1;
30 | fY = (~YT) .* (L / 903.3) + YT .* fY;
31 | Y = fY;
32 |
33 | % Alter fY slightly for further calculations
34 | fY = YT .* (fY .^ (1/3)) + (~YT) .* (7.787 .* fY + 16/116);
35 |
36 | % Compute X
37 | fX = a / 500 + fY;
38 | XT = fX > T2;
39 | X = (XT .* (fX .^ 3) + (~XT) .* ((fX - 16/116) / 7.787));
40 |
41 | % Compute Z
42 | fZ = fY - b / 200;
43 | ZT = fZ > T2;
44 | Z = (ZT .* (fZ .^ 3) + (~ZT) .* ((fZ - 16/116) / 7.787));
45 |
46 | X = X * 0.950456;
47 | Z = Z * 1.088754;
48 |
49 | MAT = [ 3.240479 -1.537150 -0.498535;
50 | -0.969256 1.875992 0.041556;
51 | 0.055648 -0.204043 1.057311];
52 |
53 | RGB = max(min(MAT * [X; Y; Z], 1), 0);
54 |
55 | R = reshape(RGB(1,:), M, N) * 255;
56 | G = reshape(RGB(2,:), M, N) * 255;
57 | B = reshape(RGB(3,:), M, N) * 255;
58 |
59 | if ((nargout == 1) | (nargout == 0))
60 | R = uint8(round(cat(3,R,G,B)));
61 | end
62 |
63 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/Point.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __Point_hh__
3 | #define __Point_hh__
4 |
5 | // Simple point template classes.
6 | // Probably only make sense for intrinsic types.
7 |
8 | // 2D Points
9 |
10 | template <class T>
11 | class Point2D
12 | {
13 | public:
14 | Point2D () { x = 0; y = 0; }
15 | Point2D (T x, T y) { this->x = x; this->y = y; }
16 | T x,y;
17 | };
18 |
19 | template <class T>
20 | inline int operator== (const Point2D<T>& a, const Point2D<T>& b)
21 | { return (a.x == b.x) && (a.y == b.y); }
22 | 
23 | template <class T>
24 | inline int operator!= (const Point2D<T>& a, const Point2D<T>& b)
25 | { return (a.x != b.x) || (a.y != b.y); }
26 | 
27 | typedef Point2D<int> Pixel;
28 |
29 | // 3D Points
30 |
31 | template <class T>
32 | class Point3D
33 | {
34 | public:
35 |     Point3D () { x = 0; y = 0; z = 0; }
36 |     Point3D (T x, T y, T z) { this->x = x; this->y = y; this->z = z; }
37 |     T x,y,z;
38 | };
39 | 
40 | template <class T>
41 | inline int operator== (const Point3D<T>& a, const Point3D<T>& b)
42 | { return (a.x == b.x) && (a.y == b.y) && (a.z == b.z); }
43 | 
44 | template <class T>
45 | inline int operator!= (const Point3D<T>& a, const Point3D<T>& b)
46 | { return (a.x != b.x) || (a.y != b.y) || (a.z != b.z); }
47 | 
48 | typedef Point3D<int> Voxel;
49 |
50 | #endif // __Point_hh__
51 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/RGB2Lab.m:
--------------------------------------------------------------------------------
1 | function [L,a,b] = RGB2Lab(R,G,B)
2 | % function [L, a, b] = RGB2Lab(R, G, B)
3 | % RGB2Lab takes matrices corresponding to Red, Green, and Blue, and
4 | % transforms them into CIELab. This transform is based on ITU-R
5 | % Recommendation BT.709 using the D65 white point reference.
6 | % The error in transforming RGB -> Lab -> RGB is approximately
7 | % 10^-5. RGB values can be either between 0 and 1 or between 0 and 255.
8 | % By Mark Ruzon from C code by Yossi Rubner, 23 September 1997.
9 | % Updated for MATLAB 5 28 January 1998.
10 |
11 | if (nargin == 1)
12 | B = double(R(:,:,3));
13 | G = double(R(:,:,2));
14 | R = double(R(:,:,1));
15 | end
16 |
17 | if ((max(max(R)) > 1.0) | (max(max(G)) > 1.0) | (max(max(B)) > 1.0))
18 | R = R/255;
19 | G = G/255;
20 | B = B/255;
21 | end
22 |
23 | [M, N] = size(R);
24 | s = M*N;
25 |
26 | % Set a threshold
27 | T = 0.008856;
28 |
29 | RGB = [reshape(R,1,s); reshape(G,1,s); reshape(B,1,s)];
30 |
31 | % RGB to XYZ
32 | MAT = [0.412453 0.357580 0.180423;
33 | 0.212671 0.715160 0.072169;
34 | 0.019334 0.119193 0.950227];
35 | XYZ = MAT * RGB;
36 |
37 | X = XYZ(1,:) / 0.950456;
38 | Y = XYZ(2,:);
39 | Z = XYZ(3,:) / 1.088754;
40 |
41 | XT = X > T;
42 | YT = Y > T;
43 | ZT = Z > T;
44 |
45 | fX = XT .* X.^(1/3) + (~XT) .* (7.787 .* X + 16/116);
46 |
47 | % Compute L
48 | Y3 = Y.^(1/3);
49 | fY = YT .* Y3 + (~YT) .* (7.787 .* Y + 16/116);
50 | L = YT .* (116 * Y3 - 16.0) + (~YT) .* (903.3 * Y);
51 |
52 | fZ = ZT .* Z.^(1/3) + (~ZT) .* (7.787 .* Z + 16/116);
53 |
54 | % Compute a and b
55 | a = 500 * (fX - fY);
56 | b = 200 * (fY - fZ);
57 |
58 | L = reshape(L, M, N);
59 | a = reshape(a, M, N);
60 | b = reshape(b, M, N);
61 |
62 | if ((nargout == 1) | (nargout == 0))
63 | L = cat(3,L,a,b);
64 | end
--------------------------------------------------------------------------------
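
A numpy transcription of the forward transform above, reusing the BT.709 matrix, the D65 normalization constants, and the 0.008856 threshold from RGB2Lab.m; the MATLAB file is the reference, this is only a sketch:

```python
# Sketch: numpy port of RGB2Lab.m (same BT.709 matrix, D65 white
# point, and linearity threshold as the MATLAB code above).
import numpy as np

MAT = np.array([[0.412453, 0.357580, 0.180423],
                [0.212671, 0.715160, 0.072169],
                [0.019334, 0.119193, 0.950227]])
T = 0.008856

def rgb2lab(rgb):
    """rgb: HxWx3 float array in [0, 1]."""
    xyz = rgb @ MAT.T                    # per-pixel MAT * [R; G; B]
    x = xyz[..., 0] / 0.950456           # D65 normalization
    y = xyz[..., 1]
    z = xyz[..., 2] / 1.088754

    def f(t):                            # piecewise cube-root response
        return np.where(t > T, np.cbrt(t), 7.787 * t + 16.0 / 116.0)

    L = np.where(y > T, 116.0 * np.cbrt(y) - 16.0, 903.3 * y)
    a = 500.0 * (f(x) - f(y))
    b = 200.0 * (f(y) - f(z))
    return np.stack([L, a, b], axis=-1)
```
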
/evaluation/EdgeEval/Util/Random.cc:
--------------------------------------------------------------------------------
1 |
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 | #include <math.h>
5 | #include <sys/types.h>
6 | #include <sys/time.h>
7 | #include "Random.hh"
8 | #include "String.hh"
9 | #include "Exception.hh"
10 |
11 | // Copyright (C) 2002 David R. Martin
12 | //
13 | // This program is free software; you can redistribute it and/or
14 | // modify it under the terms of the GNU General Public License as
15 | // published by the Free Software Foundation; either version 2 of the
16 | // License, or (at your option) any later version.
17 | //
18 | // This program is distributed in the hope that it will be useful, but
19 | // WITHOUT ANY WARRANTY; without even the implied warranty of
20 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21 | // General Public License for more details.
22 | //
23 | // You should have received a copy of the GNU General Public License
24 | // along with this program; if not, write to the Free Software
25 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
27 |
28 | Random Random::rand;
29 |
30 | Random::Random ()
31 | {
32 | reseed (0);
33 | }
34 |
35 | Random::Random (u_int64_t seed)
36 | {
37 | reseed (seed);
38 | }
39 |
40 | Random::Random (Random& that)
41 | {
42 | u_int64_t a = that.ui32 ();
43 | u_int64_t b = that.ui32 ();
44 | u_int64_t seed = (a << 32) | b;
45 | _init (seed);
46 | }
47 |
48 | void
49 | Random::reset ()
50 | {
51 | _init (_seed);
52 | }
53 |
54 | void
55 | Random::reseed (u_int64_t seed)
56 | {
57 | if (seed == 0) {
58 | struct timeval t;
59 | gettimeofday (&t, NULL);
60 | u_int64_t a = (t.tv_usec >> 3) & 0xffff;
61 | u_int64_t b = t.tv_sec & 0xffff;
62 | u_int64_t c = (t.tv_sec >> 16) & 0xffff;
63 | seed = a | (b << 16) | (c << 32);
64 | }
65 | _init (seed);
66 | }
67 |
68 | void
69 | Random::_init (u_int64_t seed)
70 | {
71 | _seed = seed & 0xffffffffffffull;
72 | _xsubi[0] = (seed >> 0) & 0xffff;
73 | _xsubi[1] = (seed >> 16) & 0xffff;
74 | _xsubi[2] = (seed >> 32) & 0xffff;
75 | }
76 |
77 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/Random.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __Random_hh__
3 | #define __Random_hh__
4 |
5 | // Copyright (C) 2002 David R. Martin
6 | //
7 | // This program is free software; you can redistribute it and/or
8 | // modify it under the terms of the GNU General Public License as
9 | // published by the Free Software Foundation; either version 2 of the
10 | // License, or (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful, but
13 | // WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 | // General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program; if not, write to the Free Software
19 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
20 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
21 |
22 | #include <sys/types.h>
23 | #include <stdlib.h>
24 | #include <math.h>
25 | #include <assert.h>
26 |
27 | // All random numbers are generated from a single seed. This is true
28 | // even when private random streams (separate from the global
29 | // Random::rand stream) are spawned from existing streams, since the new
30 | // streams are seeded automatically from the parent's random stream.
31 | // Any random stream can be reset so that a sequence of random values
32 | // can be replayed.
33 |
34 | // If seed==0, then the seed is generated from the system clock.
35 |
36 | class Random
37 | {
38 | public:
39 |
40 | static Random rand;
41 |
42 |   // These are defined in <limits.h> as the limits of int, but
43 | // here we need the limits of int32_t.
44 | static const int32_t int32_max = 2147483647;
45 | static const int32_t int32_min = -int32_max-1;
46 | static const u_int32_t u_int32_max = 4294967295u;
47 |
48 | // Seed from the system clock.
49 | Random ();
50 |
51 | // Specify seed.
52 | // If zero, seed from the system clock.
53 | Random (u_int64_t seed);
54 |
55 | // Spawn off a new random stream seeded from the parent's stream.
56 | Random (Random& that);
57 |
58 | // Restore initial seed so we can replay a random sequence.
59 | void reset ();
60 |
61 | // Set the seed.
62 | // If zero, seed from the system clock.
63 | void reseed (u_int64_t seed);
64 |
65 | // double in [0..1) or [a..b)
66 | inline double fp ();
67 | inline double fp (double a, double b);
68 |
69 | // 32-bit signed integer in [-2^31,2^31) or [a..b]
70 | inline int32_t i32 ();
71 | inline int32_t i32 (int32_t a, int32_t b);
72 |
73 | // 32-bit unsigned integer in [0,2^32) or [a..b]
74 | inline u_int32_t ui32 ();
75 | inline u_int32_t ui32 (u_int32_t a, u_int32_t b);
76 |
77 | protected:
78 |
79 | void _init (u_int64_t seed);
80 |
81 | // The original seed for this random stream.
82 | u_int64_t _seed;
83 |
84 | // The current state for this random stream.
85 | u_int16_t _xsubi[3];
86 |
87 | };
88 |
89 | inline u_int32_t
90 | Random::ui32 ()
91 | {
92 | return ui32(0,u_int32_max);
93 | }
94 |
95 | inline u_int32_t
96 | Random::ui32 (u_int32_t a, u_int32_t b)
97 | {
98 | assert (a <= b);
99 | double x = fp ();
100 | return (u_int32_t) floor (x * ((double)b - (double)a + 1) + a);
101 | }
102 |
103 | inline int32_t
104 | Random::i32 ()
105 | {
106 | return i32(int32_min,int32_max);
107 | }
108 |
109 | inline int32_t
110 | Random::i32 (int32_t a, int32_t b)
111 | {
112 | assert (a <= b);
113 | double x = fp ();
114 | return (int32_t) floor (x * ((double)b - (double)a + 1) + a);
115 | }
116 |
117 | inline double
118 | Random::fp ()
119 | {
120 | return erand48 (_xsubi);
121 | }
122 |
123 | inline double
124 | Random::fp (double a, double b)
125 | {
126 | assert (a < b);
127 | return erand48 (_xsubi) * (b - a) + a;
128 | }
129 |
130 | #endif // __Random_hh__
131 |
132 |
--------------------------------------------------------------------------------
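
`Random::fp` defers to `erand48`, whose 48-bit linear congruential step is fixed by POSIX (the multiplier and increment below are POSIX's constants, not something defined in this repo). A sketch of reproducing the stream off-line, including `_init`'s truncation of the seed to 48 bits:

```python
# Sketch: reproduce Random::fp() outside C++. _init keeps the low
# 48 bits of the seed; erand48 then iterates the POSIX LCG
#   X' = (0x5DEECE66D * X + 0xB) mod 2^48
# and returns X' / 2^48.
A, C, M = 0x5DEECE66D, 0xB, 1 << 48

def make_stream(seed):
    state = seed & (M - 1)               # low 48 bits, as _init does
    def fp():
        nonlocal state
        state = (A * state + C) % M
        return state / M                 # double in [0, 1)
    return fp

rand = make_stream(12345)
print([round(rand(), 6) for _ in range(3)])
```
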
/evaluation/EdgeEval/Util/Timer.cc:
--------------------------------------------------------------------------------
1 |
2 | // Copyright (C) 2002 David R. Martin
3 | //
4 | // This program is free software; you can redistribute it and/or
5 | // modify it under the terms of the GNU General Public License as
6 | // published by the Free Software Foundation; either version 2 of the
7 | // License, or (at your option) any later version.
8 | //
9 | // This program is distributed in the hope that it will be useful, but
10 | // WITHOUT ANY WARRANTY; without even the implied warranty of
11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 | // General Public License for more details.
13 | //
14 | // You should have received a copy of the GNU General Public License
15 | // along with this program; if not, write to the Free Software
16 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
17 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
18 |
19 | #include <stdio.h>
20 | #include <string.h>
21 | #include <math.h>
22 | #include "Timer.hh"
23 |
24 | typedef unsigned long long uint64;
25 |
26 | void
27 | Timer::_compute ()
28 | {
29 | // Compute elapsed time.
30 | long sec = _elapsed_stop.tv_sec - _elapsed_start.tv_sec;
31 | long usec = _elapsed_stop.tv_usec - _elapsed_start.tv_usec;
32 | if (usec < 0) {
33 | sec -= 1;
34 | usec += 1000000;
35 | }
36 | _elapsed += (double) sec + usec / 1e6;
37 |
 38 |     // Compute CPU user and system times.
39 | _user += (double) (_cpu_stop.tms_utime - _cpu_start.tms_utime)
40 | / sysconf(_SC_CLK_TCK);
41 | _system += (double) (_cpu_stop.tms_stime - _cpu_start.tms_stime)
42 | / sysconf(_SC_CLK_TCK);
43 | }
44 |
45 | // Convert time in seconds into a nice human-friendly format: h:mm:ss.ss
46 | // Return a pointer to a static buffer.
47 | const char*
48 | Timer::formatTime (double sec, int precision)
49 | {
50 | static char buf[128];
51 |
52 | // Limit range of precision for safety and sanity.
53 | precision = (precision < 0) ? 0 : precision;
54 | precision = (precision > 9) ? 9 : precision;
55 | uint64 base = 1;
56 | for (int digit = 0; digit < precision; digit++) { base *= 10;}
57 |
58 | bool neg = (sec < 0);
59 | uint64 ticks = (uint64) rint (fabs (sec) * base);
60 | uint64 rsec = ticks / base; // Rounded seconds.
61 | uint64 frac = ticks % base;
62 |
63 | uint64 h = rsec / 3600;
64 | uint64 m = (rsec / 60) % 60;
65 | uint64 s = rsec % 60;
66 |
67 | sprintf (buf, "%s%llu:%02llu:%02llu",
68 | neg ? "-" : "", h, m, s);
69 |
70 | if (precision > 0) {
71 | static char fmt[10];
72 | sprintf (fmt, ".%%0%dlld", precision);
73 | sprintf (buf + strlen (buf), fmt, frac);
74 | }
75 |
76 | return buf;
77 | }
78 |
79 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/Timer.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __Timer_hh__
3 | #define __Timer_hh__
4 |
5 | // Copyright (C) 2002 David R. Martin
6 | //
7 | // This program is free software; you can redistribute it and/or
8 | // modify it under the terms of the GNU General Public License as
9 | // published by the Free Software Foundation; either version 2 of the
10 | // License, or (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful, but
13 | // WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 | // General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program; if not, write to the Free Software
19 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
20 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
21 |
 22 | #include <assert.h>
 23 | #include <time.h>
 24 | #include <sys/time.h>
 25 | #include <sys/times.h>
 26 | #include <unistd.h>
27 |
28 | class Timer
29 | {
30 | public:
31 |
32 | inline Timer ();
33 | inline ~Timer ();
34 |
35 | inline void start ();
36 | inline void stop ();
37 | inline void reset ();
38 |
39 | // All times are in seconds.
40 | inline double cpu ();
41 | inline double user ();
42 | inline double system ();
43 | inline double elapsed ();
44 |
45 | // Convert time in seconds into a nice human-friendly format: h:mm:ss.ss
46 | // Precision is the number of digits after the decimal.
47 | // Return a pointer to a static buffer.
48 | static const char* formatTime (double sec, int precision = 2);
49 |
50 | private:
51 |
52 | void _compute ();
53 |
54 | enum State { stopped, running };
55 |
56 | State _state;
57 |
58 | struct timeval _elapsed_start;
59 | struct timeval _elapsed_stop;
60 | double _elapsed;
61 |
62 | struct tms _cpu_start;
63 | struct tms _cpu_stop;
64 | double _user;
65 | double _system;
66 | };
67 |
68 | Timer::Timer ()
69 | {
70 | reset ();
71 | }
72 |
73 | Timer::~Timer ()
74 | {
75 | }
76 |
77 | void
78 | Timer::reset ()
79 | {
80 | _state = stopped;
81 | _elapsed = _user = _system = 0;
82 | }
83 |
84 | void
85 | Timer::start ()
86 | {
87 | assert (_state == stopped);
88 | _state = running;
89 | gettimeofday (&_elapsed_start, NULL);
90 | times (&_cpu_start);
91 | }
92 |
93 | void
94 | Timer::stop ()
95 | {
96 | assert (_state == running);
97 | gettimeofday (&_elapsed_stop, NULL);
98 | times (&_cpu_stop);
99 | _compute ();
100 | _state = stopped;
101 | }
102 |
103 | double
104 | Timer::cpu ()
105 | {
106 | assert (_state == stopped);
107 | return _user + _system;
108 | }
109 |
110 | double
111 | Timer::user ()
112 | {
113 | assert (_state == stopped);
114 | return _user;
115 | }
116 |
117 | double
118 | Timer::system ()
119 | {
120 | assert (_state == stopped);
121 | return _system;
122 | }
123 |
124 | double
125 | Timer::elapsed ()
126 | {
127 | assert (_state == stopped);
128 | return _elapsed;
129 | }
130 |
131 | #endif // __Timer_hh__
132 |
133 |
134 |
135 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/distSqr.m:
--------------------------------------------------------------------------------
1 | function z = distSqr(x,y)
2 | % function z = distSqr(x,y)
3 | %
4 | % Return matrix of all-pairs squared distances between the vectors
5 | % in the columns of x and y.
6 | %
7 | % INPUTS
8 | % x dxn matrix of vectors
9 | % y dxm matrix of vectors
10 | %
11 | % OUTPUTS
12 | % z nxm matrix of squared distances
13 | %
 14 | % This routine is faster when m<n than when m>n.
15 | %
16 | % David Martin
17 | % March 2003
18 |
19 | % Based on dist2.m code,
20 | % Copyright (c) Christopher M Bishop, Ian T Nabney (1996, 1997)
21 |
22 | if size(x,1)~=size(y,1),
23 | error('size(x,1)~=size(y,1)');
24 | end
25 |
26 | [d,n] = size(x);
27 | [d,m] = size(y);
28 |
29 | % z = repmat(sum(x.^2)',1,m) ...
30 | % + repmat(sum(y.^2),n,1) ...
31 | % - 2*x'*y;
32 |
33 | z = x'*y;
34 | x2 = sum(x.^2)';
35 | y2 = sum(y.^2);
36 | for i = 1:m,
37 | z(:,i) = x2 + y2(i) - 2*z(:,i);
38 | end
39 |
40 |
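For reference, the same all-pairs squared-distance computation in NumPy (a sketch, not part of the repository); it mirrors the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i.y_j that the commented-out repmat version above spells out:

import numpy as np

def dist_sqr(x, y):
    # All-pairs squared distances between the columns of x (d x n) and y (d x m).
    x2 = np.sum(x**2, axis=0)[:, None]   # (n, 1)
    y2 = np.sum(y**2, axis=0)[None, :]   # (1, m)
    return x2 + y2 - 2.0 * (x.T @ y)     # (n, m)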
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/fftconv2.m:
--------------------------------------------------------------------------------
1 | function fim = fftconv2(im,f)
2 | % function fim = fftconv2(im,f)
3 | %
4 | % Convolution using FFT.
5 | %
6 | % David R. Martin
7 | % March 2003
8 |
9 | % wrap the filter around the origin and pad with zeros
10 | padf = zeros(size(im));
11 | r = floor(size(f,1)/2);
12 | padf(1:r+1,1:r+1) = f(r+1:end,r+1:end);
13 | padf(1:r,end-r+1:end) = f(r+2:end,1:r);
14 | padf(end-r+1:end,1:r) = f(1:r,r+2:end);
15 | padf(end-r+1:end,end-r+1:end) = f(1:r,1:r);
16 |
17 | % magic
18 | fftim = fft2(im);
19 | fftf = fft2(padf);
20 | fim = real(ifft2(fftim.*fftf));
21 |
22 |
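An equivalent in Python, assuming SciPy is available, is a one-liner with scipy.signal.fftconvolve (sketch for comparison only). Note the MATLAB routine above wraps the filter circularly, so the two agree away from the image borders:

import numpy as np
from scipy.signal import fftconvolve

def fftconv2(im, f):
    # Linear FFT convolution cropped to the image size; matches the MATLAB
    # routine away from the borders, which it handles circularly.
    return fftconvolve(im, f, mode='same')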
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/gethosttype:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #
3 | # Create a canonical system name "machine_system" for use in
4 | # bin and lib paths.
5 | #
6 | # If a name cannot be found, then "unknown" is the output,
7 | # and the exit status is 1.
8 | #
9 |
10 | uname_machine=`(uname -m) 2> /dev/null` || machine=unknown
11 | uname_system=`(uname -s) 2> /dev/null` || system=unknown
12 |
13 | case $uname_machine in
14 | i*86)
15 | machine=ix86;;
16 | ia64)
17 | machine=ia64;;
18 | IP27)
19 | machine=mips;;
20 | sun4*)
21 | machine=sparc;;
22 | x86_64)
23 | machine=x86_64;;
24 | *)
25 | machine=unknown;;
26 | esac
27 |
28 | if [ $machine = "unknown" ]; then
29 | echo unknown
30 | exit 1
31 | fi
32 |
33 | case $uname_system in
34 | Linux)
35 | system=linux;;
36 | IRIX*)
37 | system=irix;;
38 | SunOS)
39 | system=solaris;;
40 | NetBSD)
41 | system=netbsd;;
42 | *)
43 | system=unknown;;
44 | esac
45 |
46 | if [ $system = "unknown" ]; then
47 | echo unknown
48 | exit 1
49 | fi
50 |
51 | echo ${machine}_${system}
52 | exit 0
53 |
54 |
55 |
56 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/isum.c:
--------------------------------------------------------------------------------
1 |
  2 | #include <math.h>
  3 | #include <string.h>
  4 | #include <mex.h>
5 |
6 | // If you have trouble mexifying this file, then consider using
7 | // the MATLAB code (isum.m) instead. So long as your MATLAB
8 | // version is at least 6.5, you won't suffer too much of a
9 | // performance penalty.
10 |
11 | void
12 | mexFunction (
13 | int nlhs, mxArray* plhs[],
14 | int nrhs, const mxArray* prhs[])
15 | {
16 | // Check number of arguments.
17 | if (nlhs < 1) {
18 | mexErrMsgTxt("Too few output arguments.");
19 | }
20 | if (nlhs > 1) {
21 | mexErrMsgTxt("Too many output arguments.");
22 | }
23 | if (nrhs < 3) {
24 | mexErrMsgTxt("Too few input arguments.");
25 | }
26 | if (nrhs > 3) {
27 | mexErrMsgTxt("Too many input arguments.");
28 | }
29 |
30 | const double* x = mxGetPr(prhs[0]);
31 | const double* idx = mxGetPr(prhs[1]);
32 | int nbins = (int)mxGetScalar(prhs[2]);
33 | if (nbins < 0) { nbins = 0; }
34 |
35 | // Check arguments.
36 | const int n = mxGetNumberOfElements(prhs[0]);
37 | if (n != mxGetNumberOfElements(prhs[1])) {
38 | mexErrMsgTxt("x and idx must be the same size");
39 | }
40 |
41 | // Do the reduction.
42 | plhs[0] = mxCreateDoubleMatrix(nbins,1,mxREAL);
43 | double* acc = mxGetPr(plhs[0]);
44 | memset(acc,0,nbins*sizeof(*acc));
45 | for (int i = 0; i < n; i++) {
46 | int v = (int)idx[i];
47 | if (v < 1) { continue; }
48 | if (v > nbins) { continue; }
49 | acc[v-1] += x[i];
50 | }
51 | }
52 |
53 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/isum.m:
--------------------------------------------------------------------------------
1 | function acc = isum(x,idx,nbins)
2 | % function acc = isum(x,idx,nbins)
3 | %
4 | % Indexed sum reduction, where acc(i) contains the sum of
5 | % x(find(idx==i)).
6 | %
7 | % The mex version is 300x faster in R12, and 4x faster in R13. As far
8 | % as I can tell, there is no way to do this efficiently in matlab R12.
9 | %
10 | % David R. Martin
11 | % March 2003
12 |
13 | acc = zeros(nbins,1);
14 | for i = 1:numel(x),
15 | if idx(i)<1, continue; end
16 | if idx(i)>nbins, continue; end
17 | acc(idx(i)) = acc(idx(i)) + x(i);
18 | end
19 |
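The same indexed-sum reduction in NumPy (a sketch, not part of the repository) is essentially np.bincount; note the shift from MATLAB's 1-based bins:

import numpy as np

def isum(x, idx, nbins):
    # acc[i] = sum of x where idx == i+1 (1-based bins, as in the MATLAB code).
    x = np.asarray(x, dtype=float).ravel()
    idx = np.asarray(idx, dtype=int).ravel()
    keep = (idx >= 1) & (idx <= nbins)            # drop out-of-range bins
    return np.bincount(idx[keep] - 1, weights=x[keep], minlength=nbins)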
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/kofn.cc:
--------------------------------------------------------------------------------
1 |
2 | #include "Random.hh"
3 | #include "kofn.hh"
4 |
5 | // O(n) implementation.
6 | static void
7 | _kOfN_largeK (int k, int n, int* values)
8 | {
9 | assert (k > 0);
10 | assert (k <= n);
11 | int j = 0;
12 | for (int i = 0; i < n; i++) {
13 | double prob = (double) (k - j) / (n - i);
14 | assert (prob <= 1);
15 | double x = Random::rand.fp ();
16 | if (x < prob) {
17 | values[j++] = i;
18 | }
19 | }
20 | assert (j == k);
21 | }
22 |
23 | // O(k*lg(k)) implementation; constant factor is about 2x the constant
24 | // factor for the O(n) implementation.
25 | static void
26 | _kOfN_smallK (int k, int n, int* values)
27 | {
28 | assert (k > 0);
29 | assert (k <= n);
30 | if (k == 1) {
31 | values[0] = Random::rand.i32 (0, n - 1);
32 | return;
33 | }
34 | int leftN = n / 2;
35 | int rightN = n - leftN;
36 | int leftK = 0;
37 | int rightK = 0;
38 | for (int i = 0; i < k; i++) {
39 | int x = Random::rand.i32 (0, n - i - 1);
40 | if (x < leftN - leftK) {
41 | leftK++;
42 | } else {
43 | rightK++;
44 | }
45 | }
46 | if (leftK > 0) { _kOfN_smallK (leftK, leftN, values); }
47 | if (rightK > 0) { _kOfN_smallK (rightK, rightN, values + leftK); }
48 | for (int i = leftK; i < k; i++) {
49 | values[i] += leftN;
50 | }
51 | }
52 |
53 | // Return k randomly selected integers from the interval [0,n), in
54 | // increasing sorted order.
55 | void
56 | kOfN (int k, int n, int* values)
57 | {
58 | assert (k >= 0);
59 | assert (n >= 0);
60 | if (k == 0) { return; }
61 | static double log2 = log (2);
62 | double klogk = k * log (k) / log2;
63 | if (klogk < n / 2) {
64 | _kOfN_smallK (k, n, values);
65 | } else {
66 | _kOfN_largeK (k, n, values);
67 | }
68 | }
69 |
70 |
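For intuition, the contract of kOfN (k distinct integers from [0,n) in increasing order, drawn from the toolbox's global Random::rand stream) can be sketched in Python with NumPy sampling without replacement; the recursive C++ version above exists to stay O(k log k) for small k instead of touching all n candidates:

import numpy as np

def k_of_n(k, n, rng=None):
    # Return k distinct integers from [0, n), sorted increasingly.
    rng = np.random.default_rng() if rng is None else rng
    return np.sort(rng.choice(n, size=k, replace=False))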
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/kofn.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __kofn_hh__
3 | #define __kofn_hh__
4 |
5 | void kOfN (int k, int n, int* values);
6 |
7 | #endif // __kofn_hh__
8 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/logist2.m:
--------------------------------------------------------------------------------
1 | function [beta,p,lli] = logist2(y,x,w)
2 | % [beta,p,lli] = logist2(y,x)
3 | %
4 | % 2-class logistic regression.
5 | %
6 | % INPUT
  7 | %    y     Nx1 column vector of 0|1 class assignments
8 | % x NxK matrix of input vectors as rows
9 | % [w] Nx1 vector of sample weights
10 | %
11 | % OUTPUT
12 | % beta Kx1 column vector of model coefficients
13 | % p Nx1 column vector of fitted class 1 posteriors
14 | % lli log likelihood
15 | %
16 | % Class 1 posterior is 1 / (1 + exp(-x*beta))
17 | %
18 | % David Martin
19 | % April 16, 2002
20 |
21 | % Copyright (C) 2002 David R. Martin
22 | %
23 | % This program is free software; you can redistribute it and/or
24 | % modify it under the terms of the GNU General Public License as
25 | % published by the Free Software Foundation; either version 2 of the
26 | % License, or (at your option) any later version.
27 | %
28 | % This program is distributed in the hope that it will be useful, but
29 | % WITHOUT ANY WARRANTY; without even the implied warranty of
30 | % MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
31 | % General Public License for more details.
32 | %
33 | % You should have received a copy of the GNU General Public License
34 | % along with this program; if not, write to the Free Software
35 | % Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
36 | % 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
37 |
38 | error(nargchk(2,3,nargin));
39 |
40 | % check inputs
41 | if size(y,2) ~= 1,
42 | error('Input y not a column vector.');
43 | end
44 | if size(y,1) ~= size(x,1),
45 | error('Input x,y sizes mismatched.');
46 | end
47 |
48 | % get sizes
49 | [N,k] = size(x);
50 |
51 | % if sample weights weren't specified, set them to 1
52 | if nargin < 3,
53 | w = 1;
54 | end
55 |
56 | % normalize sample weights so max is 1
57 | w = w / max(w);
58 |
59 | % initial guess for beta: all zeros
60 | beta = zeros(k,1);
61 |
62 | % Newton-Raphson via IRLS,
63 | % taken from Hastie/Tibshirani/Friedman Section 4.4.
64 | iter = 0;
65 | lli = 0;
66 | while 1==1,
67 | iter = iter + 1;
68 |
69 | % fitted probabilities
70 | p = 1 ./ (1 + exp(-x*beta));
71 |
72 | % log likelihood
73 | lli_prev = lli;
74 | lli = sum( w .* (y.*log(p+eps) + (1-y).*log(1-p+eps)) );
75 |
76 | % least-squares weights
77 | wt = w .* p .* (1-p);
78 |
79 | % derivatives of likelihood w.r.t. beta
80 | deriv = x'*(w.*(y-p));
81 |
82 | % Hessian of likelihood w.r.t. beta
83 | % hessian = x'Wx, where W=diag(w)
84 | % Do it this way to be memory efficient and fast.
85 | hess = zeros(k,k);
86 | for i = 1:k,
87 | wxi = wt .* x(:,i);
88 | for j = i:k,
89 | hij = wxi' * x(:,j);
90 | hess(i,j) = -hij;
91 | hess(j,i) = -hij;
92 | end
93 | end
94 |
95 | % make sure Hessian is well conditioned
96 | if (rcond(hess) < eps),
97 | error(['Stopped at iteration ' num2str(iter) ...
98 | ' because Hessian is poorly conditioned.']);
99 | break;
100 | end;
101 |
102 | % Newton-Raphson update step
103 | step = hess\deriv;
104 | beta = beta - step;
105 |
106 | % termination criterion based on derivatives
107 | tol = 1e-6;
108 | if abs(deriv'*step/k) < tol, break; end;
109 |
110 | % termination criterion based on log likelihood
111 | % tol = 1e-4;
112 | % if abs((lli-lli_prev)/(lli+lli_prev)) < 0.5*tol, break; end;
113 | end;
114 |
115 |
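The same weighted two-class fit sketched in NumPy (a hypothetical helper, not part of the repository); it follows the identical Newton-Raphson/IRLS update, beta <- beta + (X'WX)^-1 X'w(y - p), with the same derivative-based stopping rule:

import numpy as np

def logist2(y, x, w=None, tol=1e-6, max_iter=100):
    # y: (N,) 0/1 labels, x: (N, K) inputs, w: (N,) sample weights.
    # Returns (beta, p, lli) like the MATLAB routine above.
    N, K = x.shape
    w = np.ones(N) if w is None else w / np.max(w)
    beta = np.zeros(K)
    for _ in range(max_iter):
        p = 1.0 / (1.0 + np.exp(-x @ beta))          # fitted posteriors
        lli = np.sum(w * (y * np.log(p + 1e-12) + (1 - y) * np.log(1 - p + 1e-12)))
        wt = w * p * (1.0 - p)                       # IRLS weights
        deriv = x.T @ (w * (y - p))                  # gradient of the log likelihood
        hess = -(x.T * wt) @ x                       # Hessian, i.e. -X'WX
        step = np.linalg.solve(hess, deriv)          # Newton step (note minus below)
        beta = beta - step
        if abs(deriv @ step / K) < tol:              # same termination criterion
            break
    return beta, p, lli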
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/padReflect.m:
--------------------------------------------------------------------------------
1 | function [impad] = padReflect(im,r)
2 | % function [impad] = padReflect(im,r)
3 | %
4 | % Pad an image with a border of size r, and reflect the image into
5 | % the border.
6 | %
7 | % David R. Martin
8 | % March 2003
9 |
10 | impad = zeros(size(im)+2*r);
11 | impad(r+1:end-r,r+1:end-r) = im; % middle
12 | impad(1:r,r+1:end-r) = flipud(im(1:r,:)); % top
13 | impad(end-r+1:end,r+1:end-r) = flipud(im(end-r+1:end,:)); % bottom
14 | impad(r+1:end-r,1:r) = fliplr(im(:,1:r)); % left
15 | impad(r+1:end-r,end-r+1:end) = fliplr(im(:,end-r+1:end)); % right
16 | impad(1:r,1:r) = flipud(fliplr(im(1:r,1:r))); % top-left
17 | impad(1:r,end-r+1:end) = flipud(fliplr(im(1:r,end-r+1:end))); % top-right
18 | impad(end-r+1:end,1:r) = flipud(fliplr(im(end-r+1:end,1:r))); % bottom-left
19 | impad(end-r+1:end,end-r+1:end) = flipud(fliplr(im(end-r+1:end,end-r+1:end))); % bottom-right
20 |
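In NumPy the same reflective padding is a single call (sketch for reference). The MATLAB code mirrors the border pixel itself into the pad, which corresponds to NumPy's 'symmetric' mode, not 'reflect':

import numpy as np

def pad_reflect(im, r):
    # Border of width r, image reflected into it with the edge pixel repeated.
    return np.pad(im, r, mode='symmetric')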
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/progbar.m:
--------------------------------------------------------------------------------
1 | function progbar(i,n,w)
2 | % function progbar(i,n,w)
3 | %
4 | % Display a textual progress bar.
5 | %
6 | % INPUTS
7 | % i Iteration number.
8 | % n Number of iterations.
9 | % [w=50] Width of bar.
10 | %
11 | % EXAMPLE
12 | %
13 | % progbar(0,n);
14 | % for i = 1:n,
15 | % compute();
16 | % progbar(i,n);
17 | % end
18 | %
19 | % David R. Martin
20 | % April 2002
21 |
22 | if nargin<3, w=50; end
23 | w = min(w,n);
24 |
25 | if i==0,
26 | fwrite(2,'[');
27 | for c = 1:w, fwrite(2,'.'); end
28 | fwrite(2,']');
29 | for c = 1:w+1, fwrite(2,sprintf('\b')); end
30 | return
31 | end
32 |
33 | if mod(i,n/w) <= mod(i-1,n/w),
34 | fwrite(2,'=');
35 | end
36 |
37 | if i==n,
38 | fprintf(2,'\n');
39 | end
40 |
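An equivalent-in-spirit fixed-width progress bar in Python (a sketch; the Python side of this repository simply uses tqdm instead):

import sys

def progbar(i, n, w=50):
    # Print the frame once, then one '=' each time an n/w boundary is crossed.
    w = min(w, n)
    if i == 0:
        sys.stderr.write('[' + '.' * w + ']' + '\b' * (w + 1))
        return
    if (i * w) // n > ((i - 1) * w) // n:
        sys.stderr.write('=')
    if i == n:
        sys.stderr.write('\n')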
--------------------------------------------------------------------------------
/evaluation/EdgeEval/Util/test:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/Util/test
--------------------------------------------------------------------------------
/evaluation/EdgeEval/__init__.py:
--------------------------------------------------------------------------------
1 | from .EdgeMapEval import correspond
2 |
3 | __all__ = ['correspond']
4 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/CSA++/csa.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/CSA++/csa.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/EdgeMapEval.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/EdgeMapEval.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Exception.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Exception.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Matrix.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Matrix.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Random.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Random.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/String.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/String.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Timer.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/Timer.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/kofn.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/Util/kofn.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/correspondPixels.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/correspondPixels.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/match.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.5/match.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/CSA++/csa.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/CSA++/csa.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/EdgeMapEval.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/EdgeMapEval.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Exception.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Exception.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Matrix.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Matrix.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Random.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Random.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/String.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/String.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Timer.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/Timer.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/kofn.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/Util/kofn.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/correspondPixels.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/correspondPixels.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/match.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/EdgeEval/build/temp.linux-x86_64-3.8/match.o
--------------------------------------------------------------------------------
/evaluation/EdgeEval/correspondPixels.h:
--------------------------------------------------------------------------------
1 | void correspondPixels(double* bmap1, /*input*/
2 | double* bmap2, /*input*/
3 | double* match1, /*output*/
4 | double* match2, /*output*/
5 | double& cost, /*output*/
6 | double maxDist, /*parameters*/
7 | double outlierCost, /*parameters*/
8 | int height, /*image size*/
9 | int width /*image size*/);
--------------------------------------------------------------------------------
/evaluation/EdgeEval/correspondPixels.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | cimport numpy as np
3 |
4 | assert sizeof(int) == sizeof(np.int32_t)
5 |
6 | cdef extern from "correspondPixels.h":
7 | void correspondPixels(double* bmap1,
8 | double* bmap2,
9 | double* match1,
10 | double* match2,
11 | double& cost,
12 | double maxDist,
13 | double outlierCost,
14 | int height,
15 | int width)
16 |
17 | def correspond(np.ndarray[double,ndim=2] bmap1,
18 | np.ndarray[double,ndim=2] bmap2,
19 | double maxDist = 0.001,
20 | double outlierCost = 100):
21 |
22 | assert bmap1.shape[0] == bmap2.shape[0] and bmap1.shape[1] == bmap2.shape[1]
23 |
24 | cdef int height = bmap1.shape[0]
25 | cdef int width = bmap1.shape[1]
26 |
27 | cdef np.ndarray[np.double_t, ndim=2] \
28 | match1 = np.zeros((height,width), dtype=np.double)
29 |
30 | cdef np.ndarray[np.double_t, ndim=2] \
31 | match2 = np.zeros((height,width), dtype=np.double)
32 | cdef double cost = 0
 33 | correspondPixels(<double*> bmap1.data,
 34 | <double*> bmap2.data,
 35 | <double*> match1.data,
 36 | <double*> match2.data,
37 | cost,
38 | maxDist,
39 | outlierCost,
40 | height,
41 | width)
42 |
43 | return match1, match2, cost
44 |
45 |
46 |
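A minimal usage sketch of the compiled extension (assuming it has been built in place, see the evaluation Makefile later in this tree). The two inputs are binary edge maps as float64 arrays; maxDist is the match tolerance, which the BSDS matching code interprets relative to the image diagonal:

import numpy as np
from evaluation.EdgeEval import correspond

# Two toy 8x8 "edge maps"; nonzero pixels are edges.
pred = np.zeros((8, 8), dtype=np.float64)
gt = np.zeros((8, 8), dtype=np.float64)
pred[2, 2:6] = 1.0
gt[3, 2:6] = 1.0  # the same edge, shifted one pixel down

match_pred, match_gt, cost = correspond(pred, gt, maxDist=0.1)
precision = (match_pred > 0).sum() / max((pred > 0).sum(), 1)
recall = (match_gt > 0).sum() / max((gt > 0).sum(), 1)
print(precision, recall, cost)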
--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/Exception.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __Exception_hh__
3 | #define __Exception_hh__
4 |
5 | // A simple exception class that contains an error message.
6 |
7 | // Copyright (C) 2002 David R. Martin
8 | //
9 | // This program is free software; you can redistribute it and/or
10 | // modify it under the terms of the GNU General Public License as
11 | // published by the Free Software Foundation; either version 2 of the
12 | // License, or (at your option) any later version.
13 | //
14 | // This program is distributed in the hope that it will be useful, but
15 | // WITHOUT ANY WARRANTY; without even the implied warranty of
16 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 | // General Public License for more details.
18 | //
19 | // You should have received a copy of the GNU General Public License
20 | // along with this program; if not, write to the Free Software
21 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
23 |
 24 | #include <iostream>
25 |
26 | class Exception
27 | {
28 | public:
29 |
30 | // Always construct exception with a message, so we can print
31 | // a useful error/log message.
32 | Exception (const char* msg);
33 |
34 | // We need to implement the copy constructor so that rethrowing
35 | // works.
36 | Exception (const Exception& that);
37 |
38 | virtual ~Exception ();
39 |
40 | // Retrieve the message that this exception carries.
41 | virtual const char* msg () const;
42 |
43 | protected:
44 |
45 | char* _msg;
46 |
47 | };
48 |
49 | // write to output stream
50 | inline std::ostream& operator<< (std::ostream& out, const Exception& e) {
51 | out << e.msg();
52 | return out;
53 | }
54 |
55 | #endif // __Exception_hh__
56 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/Point.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __Point_hh__
3 | #define __Point_hh__
4 |
5 | // Simple point template classes.
6 | // Probably only make sense for intrinsic types.
7 |
8 | // 2D Points
9 |
 10 | template <class T>
11 | class Point2D
12 | {
13 | public:
14 | Point2D () { x = 0; y = 0; }
15 | Point2D (T x, T y) { this->x = x; this->y = y; }
16 | T x,y;
17 | };
18 |
 19 | template <class T>
 20 | inline int operator== (const Point2D<T>& a, const Point2D<T>& b)
 21 | { return (a.x == b.x) && (a.y == b.y); }
 22 |
 23 | template <class T>
 24 | inline int operator!= (const Point2D<T>& a, const Point2D<T>& b)
 25 | { return (a.x != b.x) || (a.y != b.y); }
 26 |
 27 | typedef Point2D<int> Pixel;
28 |
29 | // 3D Points
30 |
 31 | template <class T>
 32 | class Point3D
 33 | {
 34 | public:
 35 | Point3D () { x = 0; y = 0; z = 0; }
 36 | Point3D (T x, T y, T z) { this->x = x; this->y = y; this->z = z; }
37 | T x,y,z;
38 | };
39 |
 40 | template <class T>
 41 | inline int operator== (const Point3D<T>& a, const Point3D<T>& b)
 42 | { return (a.x == b.x) && (a.y == b.y) && (a.z == b.z); }
 43 |
 44 | template <class T>
 45 | inline int operator!= (const Point3D<T>& a, const Point3D<T>& b)
 46 | { return (a.x != b.x) || (a.y != b.y) || (a.z != b.z); }
 47 |
 48 | typedef Point3D<int> Voxel;
49 |
50 | #endif // __Point_hh__
51 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/Timer.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __Timer_hh__
3 | #define __Timer_hh__
4 |
5 | // Copyright (C) 2002 David R. Martin
6 | //
7 | // This program is free software; you can redistribute it and/or
8 | // modify it under the terms of the GNU General Public License as
9 | // published by the Free Software Foundation; either version 2 of the
10 | // License, or (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful, but
13 | // WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 | // General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program; if not, write to the Free Software
19 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
20 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
21 |
 22 | #include <assert.h>
 23 | #include <time.h>
 24 | #include <sys/time.h>
 25 | #include <sys/times.h>
 26 | #include <unistd.h>
27 |
28 | class Timer
29 | {
30 | public:
31 |
32 | inline Timer ();
33 | inline ~Timer ();
34 |
35 | inline void start ();
36 | inline void stop ();
37 | inline void reset ();
38 |
39 | // All times are in seconds.
40 | inline double cpu ();
41 | inline double user ();
42 | inline double system ();
43 | inline double elapsed ();
44 |
45 | // Convert time in seconds into a nice human-friendly format: h:mm:ss.ss
46 | // Precision is the number of digits after the decimal.
47 | // Return a pointer to a static buffer.
48 | static const char* formatTime (double sec, int precision = 2);
49 |
50 | private:
51 |
52 | void _compute ();
53 |
54 | enum State { stopped, running };
55 |
56 | State _state;
57 |
58 | struct timeval _elapsed_start;
59 | struct timeval _elapsed_stop;
60 | double _elapsed;
61 |
62 | struct tms _cpu_start;
63 | struct tms _cpu_stop;
64 | double _user;
65 | double _system;
66 | };
67 |
68 | Timer::Timer ()
69 | {
70 | reset ();
71 | }
72 |
73 | Timer::~Timer ()
74 | {
75 | }
76 |
77 | void
78 | Timer::reset ()
79 | {
80 | _state = stopped;
81 | _elapsed = _user = _system = 0;
82 | }
83 |
84 | void
85 | Timer::start ()
86 | {
87 | assert (_state == stopped);
88 | _state = running;
89 | gettimeofday (&_elapsed_start, NULL);
90 | times (&_cpu_start);
91 | }
92 |
93 | void
94 | Timer::stop ()
95 | {
96 | assert (_state == running);
97 | gettimeofday (&_elapsed_stop, NULL);
98 | times (&_cpu_stop);
99 | _compute ();
100 | _state = stopped;
101 | }
102 |
103 | double
104 | Timer::cpu ()
105 | {
106 | assert (_state == stopped);
107 | return _user + _system;
108 | }
109 |
110 | double
111 | Timer::user ()
112 | {
113 | assert (_state == stopped);
114 | return _user;
115 | }
116 |
117 | double
118 | Timer::system ()
119 | {
120 | assert (_state == stopped);
121 | return _system;
122 | }
123 |
124 | double
125 | Timer::elapsed ()
126 | {
127 | assert (_state == stopped);
128 | return _elapsed;
129 | }
130 |
131 | #endif // __Timer_hh__
132 |
133 |
134 |
135 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/kofn.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __kofn_hh__
3 | #define __kofn_hh__
4 |
5 | void kOfN (int k, int n, int* values);
6 |
7 | #endif // __kofn_hh__
8 |
--------------------------------------------------------------------------------
/evaluation/EdgeEval/include/match.hh:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __match_hh__
3 | #define __match_hh__
4 |
5 | class Matrix;
6 |
7 | // returns the cost of the assignment
8 | double matchEdgeMaps (
9 | const Matrix& bmap1, const Matrix& bmap2,
10 | double maxDist, double outlierCost,
11 | Matrix& match1, Matrix& match2);
12 |
13 | #endif // __match_hh__
--------------------------------------------------------------------------------
/evaluation/EdgeEval/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from Cython.Distutils import build_ext
4 | from distutils.extension import Extension
5 | import subprocess
6 | import numpy as np
7 | import glob
8 |
9 | try:
10 | NP_INCLUDE = np.get_include()
11 | except AttributeError:
12 | NP_INCLUDE = np.get_numpy_include()
13 |
14 | CSAPP_INCLUDE = os.path.abspath('CSA++')
15 |
16 | INCLUDE = [os.path.abspath('include'),
17 | CSAPP_INCLUDE,
18 | NP_INCLUDE,
19 | os.path.abspath('Util')]
20 | # SOURCES = ['CSA++/csa.cc','match.cc','correspondPixels.cc']
21 | SOURCES = ['EdgeMapEval.cc','CSA++/csa.cc','match.cc','correspondPixels.pyx'] + glob.glob('Util/*.cc')
22 | # import pdb
23 | # pdb.set_trace()
24 |
25 | class custom_build_ext(build_ext):
26 | def build_extensions(self):
27 | build_ext.build_extensions(self)
28 |
29 | ext_modules = [
30 | Extension("EdgeMapEval",
31 | sources=SOURCES,
32 | include_dirs=INCLUDE,
33 | language='c++',
34 | extra_compile_args=["-std=c++11"],
35 | )
36 | ]
37 |
38 | setup(
39 | ext_modules = ext_modules,
40 | cmdclass = {
41 | 'build_ext': custom_build_ext
42 | }
43 | )
--------------------------------------------------------------------------------
/evaluation/Makefile:
--------------------------------------------------------------------------------
1 | all:
2 | cd EdgeEval; python setup.py build_ext --inplace; cd ..
3 |
4 |
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/__init__.py:
--------------------------------------------------------------------------------
1 | from .draw import drawfn
2 | __all__ = ['drawfn']
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/build/temp.linux-x86_64-3.5/draw.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/RasterizeLine/build/temp.linux-x86_64-3.5/draw.o
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/build/temp.linux-x86_64-3.5/kernel.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/RasterizeLine/build/temp.linux-x86_64-3.5/kernel.o
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/build/temp.linux-x86_64-3.8/draw.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/RasterizeLine/build/temp.linux-x86_64-3.8/draw.o
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/build/temp.linux-x86_64-3.8/kernel.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/RasterizeLine/build/temp.linux-x86_64-3.8/kernel.o
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/draw.hpp:
--------------------------------------------------------------------------------
1 | void _draw(const float* lines,
2 | double* map,
3 | int height, int width,
4 | int nlines);
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/draw.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | cimport numpy as np
3 |
4 | assert sizeof(int) == sizeof(np.int32_t)
5 |
6 | cdef extern from "draw.hpp":
7 | void _draw(const float* lines, double* map, int height, int width, int nlines);
8 |
9 |
10 | def drawfn(np.ndarray[float, ndim=2] lines,
11 | int height, int width):
12 |
13 | cdef np.ndarray[np.double_t, ndim=2] linemap = np.zeros((height,width),
14 | dtype=np.double)
15 | cdef int nlines = lines.shape[0]
 16 | _draw(<float*> lines.data, <double*> linemap.data,
17 | height, width, nlines)
18 |
19 |
20 | return linemap
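A small usage sketch (assuming the extension has been built in place): drawfn rasterizes an (N, 4) float32 array of x1,y1,x2,y2 segments into a height-by-width map of 0/1 values:

import numpy as np
from evaluation.RasterizeLine import drawfn

lines = np.array([[1.0, 1.0, 6.0, 4.0],
                  [0.0, 7.0, 7.0, 0.0]], dtype=np.float32)  # two segments
linemap = drawfn(np.ascontiguousarray(lines), 8, 8)
print(linemap.astype(int))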
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/kernel.cpp:
--------------------------------------------------------------------------------
1 | #include "draw.hpp"
2 | #include "math.h"
3 | void _draw(const float* lines,
4 | double* map,
5 | int height, int width,
6 | int nlines)
7 | {
8 | for(int k = 0; k < nlines; ++k){
9 | float x1 = lines[4*k];
10 | float y1 = lines[4*k+1];
11 | float x2 = lines[4*k+2];
12 | float y2 = lines[4*k+3];
13 |
14 | float vn = ceil(sqrt((x1-x2)*(x1-x2)+(y1-y2)*(y1-y2)));
15 |
16 | float dx = (x2-x1)/(vn-1.0);
17 | float dy = (y2-y1)/(vn-1.0);
18 |
19 | for(int j = 0; j < (int) vn; ++j){
20 | int xx = round(x1+(float)j*dx);
21 | int yy = round(y1+(float)j*dy);
22 | if (xx>=0 && xx<=width-1 && yy>=0 && yy<=height-1){
23 | int index = yy*width + xx;
24 | map[index] = 1.0;
25 | }
26 | }
27 | }
28 | }
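The same sampling scheme in NumPy for reference (a sketch, not the repository's code): about one sample per pixel of segment length, rounded to the grid and clipped to the image. The max(vn, 2) guard also sidesteps the division the C++ loop above would hit for coincident endpoints:

import numpy as np

def draw_lines(lines, height, width):
    # lines: (N, 4) array of x1, y1, x2, y2; returns a 0/1 map like _draw above.
    m = np.zeros((height, width))
    for x1, y1, x2, y2 in lines:
        vn = int(np.ceil(np.hypot(x2 - x1, y2 - y1)))
        t = np.linspace(0.0, 1.0, max(vn, 2))       # samples along the segment
        xx = np.rint(x1 + t * (x2 - x1)).astype(int)
        yy = np.rint(y1 + t * (y2 - y1)).astype(int)
        ok = (xx >= 0) & (xx < width) & (yy >= 0) & (yy < height)
        m[yy[ok], xx[ok]] = 1.0
    return m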
--------------------------------------------------------------------------------
/evaluation/RasterizeLine/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from Cython.Distutils import build_ext
4 | from distutils.extension import Extension
5 | import subprocess
6 | import numpy as np
7 | import glob
8 |
9 | try:
10 | NP_INCLUDE = np.get_include()
11 | except AttributeError:
12 | NP_INCLUDE = np.get_numpy_include()
13 |
14 | class custom_build_ext(build_ext):
15 | def build_extensions(self):
16 | build_ext.build_extensions(self)
17 |
18 | ext_modules = [
19 | Extension("draw",
20 | sources=['kernel.cpp','draw.pyx'],
21 | include_dirs=[NP_INCLUDE],
22 | language='c++',
23 | extra_compile_args=["-std=c++11"],
24 | )
25 | ]
26 |
27 | setup(
28 | ext_modules = ext_modules,
29 | cmdclass = {
30 | 'build_ext': custom_build_ext
31 | }
32 | )
--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluation import LSDEvaluator
2 |
3 | __all__ = ["LSDEvaluator"]
4 |
--------------------------------------------------------------------------------
/evaluation/compute_prec_recall.py:
--------------------------------------------------------------------------------
1 | from .EdgeEval import correspond
2 |
--------------------------------------------------------------------------------
/evaluation/draw-hap.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import argparse
3 | import numpy as np
4 | import json
5 | import matplotlib as mpl
6 | import scipy.io as sio
7 | from scipy import interpolate
8 |
9 | mpl.rcParams.update({"font.size": 18})
10 | # mpl.font_manager._rebuild()
11 |
12 | def main():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--path',nargs='+', type=str)
15 | parser.add_argument('--threshold', type=int, choices = [5,10,15], default=10)
16 | parser.add_argument('--dest', default = None, type=str, help='the destination of the saved figure')
17 |
18 | args = parser.parse_args()
19 |
20 | evaluation_results = []
21 | legends = []
22 | for path in args.path:
23 | result = sio.loadmat(path)
24 | evaluation_results.append(result)
25 | legends.append(result['label'])
26 | f_scores = np.linspace(0.2,0.9,num=8).tolist()
27 | for f_score in f_scores:
28 | x = np.linspace(0.01,1)
29 | y = f_score*x/(2*x-f_score)
30 | l, = plt.plot(x[y >= 0], y[y >= 0], color=[0,0.5,0], alpha=0.3)
31 | plt.annotate("f={0:0.1}".format(f_score), xy=(0.9, y[45] + 0.02), alpha=0.4,fontsize=10)
32 |
33 | plt.rc('legend',fontsize=14)
34 | plt.grid(True)
35 | plt.axis([0.0, 1.0, 0.0, 1.0])
36 | plt.xticks(np.arange(0, 1.0, step=0.1))
37 | plt.xlabel("Recall")
38 | plt.ylabel("Precision")
39 | plt.yticks(np.arange(0, 1.0, step=0.1))
40 |
41 | for label, result in zip(legends,evaluation_results):
42 | label = label.item()
43 | precision = result['precision'][0].flatten()
44 | recall = result['recall'][0].flatten()
45 | idx = np.isfinite(precision)*np.isfinite(recall)
46 | precision = precision[idx]
47 | recall = recall[idx]
48 |
49 | x = np.arange(0,1,0.01)*recall[-1]
50 | f = interpolate.interp1d(recall,precision,kind='cubic',bounds_error=False,fill_value=precision[0])
51 | y = f(x)
52 | T = 0.005
53 |
54 | print(result['f'].item())
55 |
56 |
57 | #sap_head = "sAP$^{%d}$"%(args.threshold)
58 | if 'HAWP' in label:
59 | label_ = '[F={:.1f}] {}(Ours)'.format(result['f'].item()*100,label)
60 | else:
61 | label_ = '[F={:.1f}] {}'.format(result['f'].item()*100,label)
62 | #plt.plot(recall,precision,'-',label=label_sap,linewidth=3)
63 | # plt.plot(recall,precision,'-')
64 | plt.plot(x,y,'-',label=label_,linewidth=3)
65 | plt.legend(loc='best')
66 | title = "PR Curve for AP$H$"
67 | plt.title(title)
68 | if args.dest:
69 | plt.savefig(args.dest, bbox_inches='tight')
70 | else:
71 | plt.show()
72 |
73 | if __name__ == "__main__":
74 | main()
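A note on the green iso-F curves drawn near the top of main() above: since the F-measure is F = 2PR/(P+R), fixing F and the recall x = R and solving for the precision y = P gives

    y = F * x / (2 * x - F)

which is exactly the y = f_score*x/(2*x-f_score) expression in the loop; the y >= 0 mask discards the spurious branch where 2x < F.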
--------------------------------------------------------------------------------
/evaluation/draw-json.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import cv2
3 | import numpy as np
4 | import argparse
5 | import os
6 | import os.path as osp
7 | import json
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--pred',type=str,required=True,help='the json file for the wireframe or line segment predictions')
11 | parser.add_argument('--benchmark', type=str, choices = ['wireframe','york'], required=True)
12 | parser.add_argument('--threshold', type=float, default=0.5)
13 | parser.add_argument('--cmp', default='g', choices = ['g','l'])
14 | parser.add_argument('--dest', type=str, required=True)
15 | parser.add_argument('--width', default=0, type=int,)
16 | parser.add_argument('--height', default=0, type=int,)
17 | parser.add_argument('--topk', default=-1, type=int)
18 | parser.add_argument('--fname', default=None, type=str)
19 |
20 | args = parser.parse_args()
21 | with open(args.pred,'r') as fin:
22 | results = json.load(fin)
23 |
24 | if args.benchmark == 'wireframe':
25 | args.images = 'data/wireframe/images'
26 | elif args.benchmark == 'york':
27 | args.images = 'data/york/images'
28 |
29 | os.makedirs(args.dest,exist_ok=True)
30 |
31 | if args.fname is not None:
32 | results = [r for r in results if r['filename'] == args.fname]
33 | for result in results:
34 | fname = result['filename']
35 |
36 | image = cv2.imread(osp.join(args.images,fname))[:,:,::-1]
37 | ori_shape = image.shape
38 |
39 | lines = np.array(result['lines_pred'])
40 | score = np.array(result['lines_score'])
41 | if result['width']!=ori_shape[1] or result['height']!=ori_shape[0]:
42 | sx = float(ori_shape[1]/result['width'])
43 | sy = float(ori_shape[0]/result['height'])
44 | sxy = np.array([sx,sy,sx,sy]).reshape(-1,4)
45 | lines = lines*sxy
46 |
47 | if args.cmp == 'g':
48 | sort_arg = np.argsort(score)[::-1]
49 | lines = lines[sort_arg]
50 | score = score[sort_arg]
51 | idx = score>args.threshold
52 | else:
53 | sort_arg = np.argsort(score)
54 | lines = lines[sort_arg]
55 | score = score[sort_arg]
 56 | idx = score<args.threshold
 57 |
 58 | if args.topk > 0:
 59 | idx = np.arange(idx.shape[0])<args.topk
--------------------------------------------------------------------------------
/evaluation/draw-sap.py:
--------------------------------------------------------------------------------
  1 | import matplotlib.pyplot as plt
  2 | import argparse
  3 | import numpy as np
  4 | import json
  5 | import matplotlib as mpl
  6 |
  7 | mpl.rcParams.update({"font.size": 18})
  8 | # mpl.font_manager._rebuild()
  9 |
 10 | def main():
 11 | parser = argparse.ArgumentParser()
 12 | parser.add_argument('--path',nargs='+', type=str)
 13 | parser.add_argument('--threshold', type=float, default=10)
 14 | parser.add_argument('--dest', default = None, type=str, help='the destination of the saved figure')
 15 |
 16 | args = parser.parse_args()
 17 |
 18 | evaluation_results = []
 19 | legends = []
 20 | for path in args.path:
 21 | with open(path,'r') as fin:
 22 | result = json.load(fin)
 23 | evaluation_results.append(result)
 24 | legends.append(result['label'])
 25 |
 26 | f_scores = np.linspace(0.2,0.9,num=8).tolist()
 27 | for f_score in f_scores:
 28 | x = np.linspace(0.01,1)
 29 | y = f_score*x/(2*x-f_score)
 30 | l, = plt.plot(x[y >= 0], y[y >= 0], color=[0,0.5,0], alpha=0.3)
31 | plt.annotate("f={0:0.1}".format(f_score), xy=(0.9, y[45] + 0.02), alpha=0.4,fontsize=10)
32 |
33 | plt.rc('legend',fontsize=14)
34 | plt.grid(True)
35 | plt.axis([0.0, 1.0, 0.0, 1.0])
36 | plt.xticks(np.arange(0, 1.0, step=0.1))
37 | plt.xlabel("Recall")
38 | plt.ylabel("Precision")
39 | plt.yticks(np.arange(0, 1.0, step=0.1))
40 |
41 | for label, result in zip(legends,evaluation_results):
42 | precision = result['precision']
43 | recall = result['recall']
44 | sap_head = "sAP$^{%d}$"%(args.threshold)
45 | if 'HAWP' in label:
46 | label_sap = '[{}={:.1f}] {}(Ours)'.format(sap_head,result['sAP'],label)
47 | else:
48 | label_sap = '[{}={:.1f}] {}'.format(sap_head,result['sAP'],label)
49 | plt.plot(recall,precision,'-',label=label_sap,linewidth=3)
50 | plt.legend(loc='best')
51 | title = "PR Curve for sAP$^{%d}$"%(int(args.threshold))
52 | plt.title(title)
53 | if args.dest:
54 | plt.savefig(args.dest, bbox_inches='tight')
55 | else:
56 | plt.show()
57 |
58 | if __name__ == "__main__":
59 | main()
--------------------------------------------------------------------------------
/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io as sio
3 | import glob
4 | import os
5 | import os.path as osp
6 | from .prmeter import PrecisionRecallMeter
7 | import multiprocessing
8 | from tqdm import tqdm
9 |
10 | class LSDEvaluator(object):
11 | def __init__(self, root, thresholds, cmp='g', height=0,width=0):
12 | self.root = root
13 | self.filenames = glob.glob(osp.join(root,'*.mat'))
14 | self.thresholds = thresholds
15 | self.meter = PrecisionRecallMeter(self.thresholds,cmp=cmp)
16 | self.height =height
17 | self.width = width
18 |
19 | def eval_for_image(self, index):
20 | mat = sio.loadmat(self.filenames[index])
21 | gt = mat['gt']
22 | pred = mat['pred']
23 | height = mat['height'].item()
24 | width = mat['width'].item()
25 | # import pdb
26 | # pdb.set_trace()
27 | if self.height>0 and self.width>0:
28 | sx = float(self.width/width)
29 | sy = float(self.height/height)
30 | scale = np.array([sx,
31 | sy,
32 | sx,
33 | sy],dtype=np.float32)
34 | scale = scale.reshape((1,4))
35 | gt*=scale
36 | pred[:,:4] = pred[:,:4]*scale
37 | return self.meter(pred,gt,self.height,self.width)
38 | else:
39 | return self.meter(pred, gt, height, width)
40 |
41 | def __call__(self, num_workers=16, per_image = True):
42 | # self.eval_for_image(0)
43 | with multiprocessing.Pool(num_workers) as p:
44 | self.results = results = list(tqdm(p.imap(self.eval_for_image,
45 | range(len(self.filenames))), total=len(self.filenames)))
46 | if per_image:
47 | self.precisions = np.concatenate([r['p'][:,None] for r in self.results],axis=1)
48 | self.recalls = np.concatenate([r['r'][:, None] for r in self.results], axis=1)
49 |
50 | self.average_precisions = np.mean(self.precisions,axis=1)
51 | self.average_recalls = np.mean(self.recalls, axis=1)
52 | self.fmeasure = 2*self.average_precisions*self.average_recalls/(self.average_recalls+self.average_precisions)
53 |
54 | return {'precisions':self.precisions, 'recalls': self.recalls,
55 | 'avg_precision':self.average_precisions,
56 | 'avg_recall': self.average_recalls,
57 | 'avg_fmeasure': self.fmeasure,
58 | 'filenames': self.filenames}
59 | else:
60 | sumtp = sum(res['tp'] for res in results)
61 | sumfp = sum(res['fp'] for res in results)
62 | sumgt = sum(res['gt'] for res in results)
63 | # import pdb
64 | # pdb.set_trace()
65 | # rcs = sorted(sumtp/sumgt)
66 | # prs = sorted(sumtp/np.maximum(sumtp+sumfp,1e-9))[::-1]
67 | rcs = sumtp/sumgt
68 | prs = sumtp/np.maximum(sumtp+sumfp,1e-9)
69 | # temp = np.concatenate(([0],prs))
70 | # idx = np.where((temp[1:]-temp[:-1])>0)[0]
71 | # rcs = rcs[idx]
72 | # prs = prs[idx]
73 |
74 | return {'avg_precision': prs, 'avg_recall':rcs}
75 |
76 |
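A minimal driver for this class (a sketch; the .mat directory below is hypothetical). Each .mat file must carry gt, pred, height and width entries as produced by the testing scripts:

from evaluation import LSDEvaluator

if __name__ == '__main__':
    # Hypothetical output directory with one .mat file per test image.
    evaluator = LSDEvaluator('outputs/some_model/wireframe_test',
                             thresholds=[0.5, 0.6, 0.7, 0.8, 0.9])
    summary = evaluator(num_workers=8, per_image=True)
    print(summary['avg_precision'], summary['avg_recall'], summary['avg_fmeasure'])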
--------------------------------------------------------------------------------
/evaluation/example_evaluation.py:
--------------------------------------------------------------------------------
1 | import scipy.io as sio
2 | import matplotlib.pyplot as plt
3 | from evaluation.RasterizeLine import drawfn
4 | import numpy as np
5 | from evaluation.prmeter import PrecisionRecallMeter
6 | import glob
7 | import os.path as osp
8 | from tqdm import tqdm
9 | import multiprocessing
10 | if __name__ == '__main__':
11 |
12 | files = glob.glob('../outputs/afm_box_b4/wireframe_afm/*.mat')
13 | meter = PrecisionRecallMeter([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
14 |
15 | # results = [None]*len(files)
16 | def eval_on_image(i):
17 | mat = sio.loadmat(files[i])
18 | gt = mat['gt']
19 | pred = mat['pred']
20 | height = mat['height']
21 | width = mat['width']
22 |
23 | return meter(pred,gt,height,width)
24 |
25 | # for i in range(len(files)):
26 | # eval_on_image(i)
27 | # import pdb
28 | # pdb.set_trace()
29 | with multiprocessing.Pool(32) as p:
30 | results = list(tqdm(p.imap(eval_on_image, range(len(files))), total=len(files)))
31 | precisions = np.concatenate([r['p'][:,None] for r in results],axis=1)
32 | recalls = np.concatenate([r['r'][:,None] for r in results],axis=1)
33 |
34 | import pdb
35 | pdb.set_trace()
36 |
37 | # mat = sio.loadmat('../outputs/afm_box_b4/wireframe_afm/00031546.mat')
38 |
39 |
--------------------------------------------------------------------------------
/evaluation/example_rasterline.py:
--------------------------------------------------------------------------------
1 | import scipy.io as sio
2 | import matplotlib.pyplot as plt
3 | from evaluation.RasterizeLine import drawfn
4 | import numpy as np
5 |
6 | if __name__ == '__main__':
7 | mat = sio.loadmat('../outputs/afm_box_b4/wireframe_afm/00037266.mat')
8 | gt = mat['gt']
9 | pred = mat['pred']
10 | height = mat['height']
11 | width = mat['width']
12 | G = drawfn(np.ascontiguousarray(gt),height,width)
13 | P = drawfn(np.ascontiguousarray(pred[:,:4]), height,width)
14 | import pdb
15 | pdb.set_trace()
16 |
--------------------------------------------------------------------------------
/evaluation/prmeter.py:
--------------------------------------------------------------------------------
1 | from .EdgeEval import correspond
2 | from .RasterizeLine import drawfn
3 | import numpy as np
4 |
5 | # cmp_dict = {
6 | # 'g': lambda a,b: a>b,
  7 | #     'l': lambda a,b: a<b,
  8 | #     'e': lambda a,b: a==b
  9 | # }
 10 |
 11 | class PrecisionRecallMeter(object):
 12 | def __init__(self, thresholds, maxDist=0.01, cmp='g'):
 13 | self.thresholds = thresholds
 14 | self.maxDist = maxDist
 15 |
 16 | if cmp == 'g':
 17 | self.cmp = self.cmp_g
 18 | elif cmp == 'e':
 19 | self.cmp = self.cmp_e
 20 | elif cmp == 'l':
 21 | self.cmp = self.cmp_l
 22 | else:
 23 | raise ValueError('unknown cmp: {}'.format(cmp))
 24 |
 25 |
 26 | @staticmethod
 27 | def cmp_g(a,b):
 28 | return a>b
29 |
30 | @staticmethod
31 | def cmp_e(a,b):
32 | return a==b
33 |
34 | @staticmethod
35 | def cmp_l(a,b):
 36 | return a<b
 37 |
 38 | def __call__(self, pred, gt, height, width):
 39 | gt_map = drawfn(np.ascontiguousarray(gt), height, width)
 40 | score = pred[:,-1]
 41 |
 42 | n = len(self.thresholds)
 43 |
 44 | cntR = np.zeros(n)
 45 | sumR = np.zeros(n)
 46 | cntP = np.zeros(n)
 47 | sumP = np.zeros(n)
 48 |
 49 | tp = np.zeros(n)
 50 | fp = np.zeros(n)
 51 | gt = np.zeros(n)
 52 |
 53 | recalls = np.zeros(n)
 54 | precisions = np.zeros(n)
 55 |
 56 | for i, t in enumerate(self.thresholds):
 57 | idx = np.where(score>t)[0]
58 | pred_t = np.ascontiguousarray(pred[idx,:4])
59 | pred_map = drawfn(pred_t,height,width)
60 |
61 | matchE, matchG, _ = correspond(pred_map, gt_map,self.maxDist)
62 | cntR[i] = np.sum(matchG>0)
63 | sumR[i] = np.sum(gt_map>0)
64 | cntP[i] = np.sum(matchE>0)
65 | sumP[i] = np.sum(pred_map>0)
66 |
67 | matchE = np.array(matchE>0,dtype=np.float32)
68 |
69 | tp[i] = matchE.sum()
70 | fp[i] = np.sum(pred_map) - matchE.sum()
71 | gt[i] = gt_map.sum()
72 | #fp: sumP-cntR
73 | #gt: sumR
74 | #tp: cntP
75 |
76 | recalls[i] = cntR[i] / (sumR[i]+1e-15)
77 | precisions[i] = cntP[i] / (sumP[i]+1e-15)
78 |
79 |
80 |
81 | fscore = 2*recalls*precisions/(recalls+precisions+1e-6)
82 |
83 | return {'p':precisions, 'r':recalls, 'f':fscore,
84 | 'tp': tp,
85 | 'fp': fp,
86 | 'gt': gt,
87 | 'sumR': sumR,
88 | 'cntR': cntR,
89 | 'sumP': sumP,
90 | 'cntP': cntP}
91 |
--------------------------------------------------------------------------------
/evaluation/runs/draw-hap.sh:
--------------------------------------------------------------------------------
1 | python -m evaluation.draw-hap --path \
2 | precomputed-results/benchmark/DWP-wireframe.mat \
3 | precomputed-results/benchmark/LCNN-wireframe.mat \
4 | precomputed-results/benchmark/LETR-R101-wireframe-aph.mat \
5 | precomputed-results/benchmark/LETR-R50-wireframe-aph.mat \
6 | precomputed-results/benchmark/FClip-HG2-LB-wireframe-aph.mat \
7 | precomputed-results/benchmark/afmpp-wireframe-aph.mat \
8 | precomputed-results/benchmark/HAWPv1-wireframe.mat \
9 | outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test-aph.mat \
10 | --dest figures/APH-wireframe.pdf
11 |
12 |
13 | python -m evaluation.draw-hap --path \
14 | precomputed-results/benchmark/DWP-york.mat \
15 | precomputed-results/benchmark/LCNN-york.mat \
16 | precomputed-results/benchmark/LETR-R101-york-aph.mat \
17 | precomputed-results/benchmark/LETR-R50-york-aph.mat \
18 | precomputed-results/benchmark/FClip-HG2-LB-york-aph.mat \
19 | precomputed-results/benchmark/afmpp-york-aph.mat \
20 | precomputed-results/benchmark/HAWPv1-york.mat \
21 | outputs/ihawp-train-rot-v2-full/220625-162909/york_test-aph.mat \
22 | --dest figures/APH-york.pdf
23 | # precomputed-results/benchmark/DWP-wireframe.mat \
--------------------------------------------------------------------------------
/evaluation/runs/draw-sap.sh:
--------------------------------------------------------------------------------
1 | # python -m evaluation.draw-sap --path \
2 | # outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json.sap \
3 | # precomputed-results/benchmark/HAWPv1-wireframe.json.sap \
4 | # precomputed-results/benchmark/FClip-HG2-LB-wireframe.json.sap \
5 | # precomputed-results/benchmark/LETR-R101-wireframe.json.sap \
6 | # precomputed-results/benchmark/LETR-R50-wireframe.json.sap \
7 | # precomputed-results/benchmark/LCNN-wireframe.json.sap \
8 | # --threshold=10 \
9 | # --dest figures/wireframe-sAP-10.pdf
10 |
11 | python -m evaluation.draw-sap --path \
12 | outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json.sap \
13 | precomputed-results/benchmark/HAWPv1-wireframe.json.sap \
14 | precomputed-results/benchmark/FClip-HG2-LB-wireframe.json.sap \
15 | precomputed-results/benchmark/LETR-R101-wireframe.json.sap \
16 | precomputed-results/benchmark/LETR-R50-wireframe.json.sap \
17 | precomputed-results/benchmark/LCNN-wireframe.json.sap \
18 | --threshold=5 \
19 | --dest figures/wireframe-sAP-05.pdf
20 | # outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json.sap \
21 | #precomputed-results/benchmark/HAWPv1-wireframe.json.sap \
22 |
23 | python -m evaluation.draw-sap --path \
24 | precomputed-results/benchmark/LETR-R101-york.json.sap \
25 | precomputed-results/benchmark/LETR-R50-york.json.sap \
26 | precomputed-results/benchmark/FClip-HG2-LB-york.json.sap \
27 | precomputed-results/benchmark/HAWPv1-york.json.sap \
28 | outputs/ihawp-train-rot-v2-full/220625-162909/york_test.json.sap \
29 | --threshold=5 \
30 | --dest figures/york-sAP-05.pdf
31 | # precomputed-results/benchmark/LCNN-york.json.sap \
--------------------------------------------------------------------------------
/evaluation/runs/draw-vis-im-york.sh:
--------------------------------------------------------------------------------
1 | # FNAME='00031811.png'
2 | # FNAME='00110785.png'
3 | FNAME='00255368.png'
4 | python -m evaluation.draw-json \
5 | --pred precomputed-results/benchmark/LETR-R101-wireframe.json \
6 | --benchmark wireframe \
7 | --topk 50 \
8 | --fname $FNAME \
9 | --dest precomputed-results/vis-fsl-sel/LETR-R101
10 |
11 | python -m evaluation.draw-json \
12 | --pred outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json \
13 | --benchmark wireframe \
14 | --topk 50 \
15 | --fname $FNAME \
16 | --dest precomputed-results/vis-fsl-sel/HAWPv2
17 |
18 |
19 | python -m evaluation.draw-json \
20 | --pred precomputed-results/benchmark/afmpp-wireframe.json \
21 | --benchmark wireframe \
22 | --threshold=0.2 \
23 | --fname $FNAME \
24 | --cmp=l \
25 | --dest precomputed-results/vis-fsl-sel/afmpp
26 |
27 | python -m evaluation.draw-json \
28 | --pred precomputed-results/benchmark/FClip-HG2-LB-wireframe.json \
29 | --benchmark wireframe \
30 | --fname $FNAME \
31 | --topk 50 \
32 | --dest precomputed-results/vis-fsl-sel/FClip
33 |
34 |
35 |
--------------------------------------------------------------------------------
/evaluation/runs/draw-vis-im.sh:
--------------------------------------------------------------------------------
1 | FNAME='00031811.png'
2 | FNAME='00110785.png'
3 | # FNAME='00255368.png'
4 | # FNAME='00031546.png'
5 | # FNAME='00031811.png'
6 | # FNAME='00034439.png'
7 | FNAME='00053549.png'
8 |
9 | python -m evaluation.draw-json \
10 | --pred precomputed-results/benchmark/HAWPv1-wireframe.json \
11 | --benchmark wireframe \
12 | --threshold 0.97 \
13 | --fname $FNAME \
14 | --dest precomputed-results/vis-fsl-sel/HAWPv1
15 |
16 | python -m evaluation.draw-json \
17 | --pred precomputed-results/benchmark/LETR-R101-wireframe.json \
18 | --benchmark wireframe \
19 | --topk 50 \
20 | --fname $FNAME \
21 | --dest precomputed-results/vis-fsl-sel/LETR-R101
22 |
23 | python -m evaluation.draw-json \
24 | --pred outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json \
25 | --benchmark wireframe \
26 | --threshold 0.9 \
27 | --fname $FNAME \
28 | --dest precomputed-results/vis-fsl-sel/HAWPv2
29 |
30 |
31 | python -m evaluation.draw-json \
32 | --pred precomputed-results/benchmark/afmpp-wireframe.json \
33 | --benchmark wireframe \
34 | --threshold=0.2 \
35 | --fname $FNAME \
36 | --cmp=l \
37 | --dest precomputed-results/vis-fsl-sel/afmpp
38 |
39 | python -m evaluation.draw-json \
40 | --pred precomputed-results/benchmark/FClip-HG2-LB-wireframe.json \
41 | --benchmark wireframe \
42 | --fname $FNAME \
43 | --topk 50 \
44 | --dest precomputed-results/vis-fsl-sel/FClip
45 |
46 |
47 |
--------------------------------------------------------------------------------
/evaluation/runs/draw-vis-york.sh:
--------------------------------------------------------------------------------
1 | python -m evaluation.draw-json \
2 | --pred precomputed-results/benchmark/LETR-R101-york.json \
3 | --benchmark york \
4 | --topk 50 \
5 | --dest precomputed-results/vis-fsl-york/LETR-R101
6 |
7 | python -m evaluation.draw-json \
8 | --pred outputs/ihawp-train-rot-v2-full/220625-162909/york_test.json \
9 | --benchmark york \
10 | --threshold 0.9 \
11 | --dest precomputed-results/vis-fsl-york/HAWPv2
12 |
13 |
14 | python -m evaluation.draw-json \
15 | --pred precomputed-results/benchmark/HAWPv1-york.json \
16 | --benchmark york \
17 | --threshold 0.97 \
18 | --dest precomputed-results/vis-fsl-york/HAWPv1
19 |
20 | python -m evaluation.draw-json \
21 | --pred precomputed-results/benchmark/afmpp-york.json \
22 | --benchmark york \
23 | --threshold=0.2 \
24 | --cmp=l \
25 | --dest precomputed-results/vis-fsl-york/afmpp
26 |
27 | python -m evaluation.draw-json \
28 | --pred precomputed-results/benchmark/FClip-HG2-LB-york.json \
29 | --benchmark york \
30 | --topk 100 \
31 | --dest precomputed-results/vis-fsl-york/FClip
32 |
33 |
34 |
--------------------------------------------------------------------------------
/evaluation/runs/draw-vis.sh:
--------------------------------------------------------------------------------
1 | python -m evaluation.draw-json \
2 | --pred precomputed-results/benchmark/LETR-R101-wireframe.json \
3 | --benchmark wireframe \
4 | --topk 50 \
5 | --dest precomputed-results/vis-fsl/LETR-R101
6 |
7 | python -m evaluation.draw-json \
8 | --pred outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json \
9 | --benchmark wireframe \
10 |     --threshold 0.97 \
11 |     --dest precomputed-results/vis-fsl/HAWPv2
12 |     # alternatively: --topk 50
13 |
14 | python -m evaluation.draw-json \
15 | --pred precomputed-results/benchmark/HAWPv1-wireframe.json \
16 | --benchmark wireframe \
17 | --threshold 0.97 \
18 |     --dest precomputed-results/vis-fsl/HAWPv1
19 |
20 |
21 | python -m evaluation.draw-json \
22 | --pred precomputed-results/benchmark/afmpp-wireframe.json \
23 | --benchmark wireframe \
24 | --threshold=0.2 \
25 | --cmp=l \
26 | --dest precomputed-results/vis-fsl/afmpp
27 |
28 | python -m evaluation.draw-json \
29 | --pred precomputed-results/benchmark/FClip-HG2-LB-wireframe.json \
30 | --benchmark wireframe \
31 | --topk 50 \
32 | --dest precomputed-results/vis-fsl/FClip
33 |
34 |
35 |
--------------------------------------------------------------------------------
/evaluation/runs/eval-aph.sh:
--------------------------------------------------------------------------------
1 | python -m evaluation.eval-json --pred precomputed-results/benchmark/FClip-HG2-LB-wireframe.json --benchmark wireframe --label "F-Clip-HG2-LB" --nthreads=8 --thresholds 0.1 0.2 0.25 0.27 0.3 0.315 0.33 0.345 0.36 0.38 0.4 0.42 0.45 0.47 0.49 0.5 0.52 0.54 0.56 0.58
2 | python -m evaluation.eval-json --pred precomputed-results/benchmark/FClip-HG2-LB-york.json --benchmark york --label "F-Clip-HG2-LB" --nthreads=8 --thresholds 0.1 0.2 0.25 0.27 0.3 0.315 0.33 0.345 0.36 0.38 0.4 0.42 0.45 0.47 0.49 0.5 0.52 0.54 0.56 0.58
3 |
--------------------------------------------------------------------------------
/evaluation/runs/sap.sh:
--------------------------------------------------------------------------------
1 | #LETR-R50
2 | python -m evaluation.eval-sap --pred precomputed-results/benchmark/LETR-R50-wireframe.json --benchmark wireframe --label LETR-R50
3 | python -m evaluation.eval-sap --pred precomputed-results/benchmark/LETR-R50-york.json --benchmark york --label LETR-R50
4 | #LETR-R101
5 | python -m evaluation.eval-sap --pred precomputed-results/benchmark/LETR-R101-wireframe.json --benchmark wireframe --label LETR-R101
6 | python -m evaluation.eval-sap --pred precomputed-results/benchmark/LETR-R101-york.json --benchmark york --label LETR-R101
7 |
8 |
9 | #HAWPv2
10 | python -m evaluation.eval-sap --pred outputs/ihawp-train-rot-v2-full/220625-162909/wireframe_test.json --benchmark wireframe --label HAWPv2
11 | python -m evaluation.eval-sap --pred outputs/ihawp-train-rot-v2-full/220625-162909/york_test.json --benchmark york --label HAWPv2
12 |
--------------------------------------------------------------------------------
/evaluation/sAPEval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/evaluation/sAPEval/__init__.py
--------------------------------------------------------------------------------
/evaluation/sAPEval/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io as sio
3 | import os.path as osp
4 | import glob
5 |
6 | def ap(tp, fp):
7 | recall = tp
8 | precision = tp / np.maximum(tp + fp, 1e-9)
9 |
10 | recall = np.concatenate(([0.0], recall, [1.0]))
11 | precision = np.concatenate(([0.0], precision, [0.0]))
12 |
13 | for i in range(precision.size - 1, 0, -1):
14 | precision[i - 1] = max(precision[i - 1], precision[i])
15 | i = np.where(recall[1:] != recall[:-1])[0]
16 | return np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
17 |
18 | def msTPFP(lines_pred, lines_gt, threshold = 5):
19 | x1_pred = lines_pred[:,:2]
20 | x2_pred = lines_pred[:,2:4]
21 | x1_gt = lines_gt[:,:2]
22 | x2_gt = lines_gt[:,2:]
23 | diff1_1 = ((x1_pred[:,None]-x1_gt)**2).sum(-1)
24 | diff1_2 = ((x1_pred[:,None]-x2_gt)**2).sum(-1)
25 |
26 | diff2_1 = ((x2_pred[:, None] - x1_gt) ** 2).sum(-1)
27 | diff2_2 = ((x2_pred[:, None] - x2_gt) ** 2).sum(-1)
28 |
29 | diff = np.minimum(diff1_1+diff2_2, diff1_2+diff2_1)
30 |
31 | choice = np.argmin(diff,1)
32 | dist = np.min(diff,1)
33 |     hit = np.zeros(len(lines_gt), dtype=bool)
34 |     tp = np.zeros(len(lines_pred), dtype=np.float64)
35 |     fp = np.zeros(len(lines_pred), dtype=np.float64)
36 | for i in range(len(lines_pred)):
37 | if dist[i] < threshold and not hit[choice[i]]:
38 | hit[choice[i]] = True
39 | tp[i] = 1
40 | else:
41 | fp[i] = 1
42 |
43 | return tp,fp
44 |
45 |
46 | if __name__ == '__main__':
47 |
48 | path = osp.join('outputs','afmbox_R50-FPN-AFM-512','wireframe','*.mat')
49 |
50 | files = glob.glob(path)
51 |
52 | tps, fps, scores = [],[],[]
53 | n_gt = 0
54 |
55 | aps = []
56 | for f in files:
57 | mat = sio.loadmat(f)
58 |
59 | height = mat['height'].item()
60 | width = mat['width'].item()
61 | lines_pred = mat['pred']
62 | lines_gt = mat['gt']
63 |
64 | lines_pred[:, 0] *= 128 / width
65 | lines_pred[:, 2] *= 128 / width
66 | lines_pred[:, 1] *= 128 / height
67 | lines_pred[:, 3] *= 128 / height
68 |
69 | lines_gt[:, 0] *= 128 / width
70 | lines_gt[:, 2] *= 128 / width
71 | lines_gt[:, 1] *= 128 / height
72 | lines_gt[:, 3] *= 128 / height
73 |
74 | pred_score = lines_pred[:,4]
75 |
76 |
77 | n_gt += len(lines_gt)
78 |
79 | tp,fp = msTPFP(lines_pred,lines_gt,10)
80 | # import pdb
81 | # pdb.set_trace()
82 | # tp = np.cumsum(tp)/len(lines_gt)
83 | # fp = np.cumsum(fp)/len(lines_gt)
84 |
85 | aps += [ap(tp,fp)]
86 | tps.append(tp)
87 | fps.append(fp)
88 | scores.append(pred_score)
89 |
90 | tps = np.concatenate(tps)
91 | fps = np.concatenate(fps)
92 | scores = np.concatenate(scores)
93 |     index = np.argsort(-scores)  # rank by descending confidence before accumulating
94 | tp = np.cumsum(tps[index])/n_gt
95 | fp = np.cumsum(fps[index])/n_gt
96 | import pdb
97 | pdb.set_trace()
98 |
99 |
100 |
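A minimal sketch of how `msTPFP` and `ap` compose into the structural-AP (sAP) protocol used in the `__main__` block above. The arrays are random stand-ins for detector outputs, so the resulting score is meaningless; real use feeds `(N,5)` predictions `[x1,y1,x2,y2,score]` and `(M,4)` ground-truth segments:

import numpy as np

lines_pred = np.random.rand(50, 5) * 128      # stand-in predictions
lines_gt = np.random.rand(20, 4) * 128        # stand-in ground truth

order = np.argsort(-lines_pred[:, 4])         # descending confidence
tp, fp = msTPFP(lines_pred[order], lines_gt, threshold=10)
tp_cum = np.cumsum(tp) / len(lines_gt)        # cumulative recall
fp_cum = np.cumsum(fp) / len(lines_gt)        # cumulative fp on the same scale
print('sAP10 =', ap(tp_cum, fp_cum))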
--------------------------------------------------------------------------------
/hawp/__init__.py:
--------------------------------------------------------------------------------
1 | # from .models import HAWP
2 | from . import base
3 | from . import fsl
4 | from . import ssl
--------------------------------------------------------------------------------
/hawp/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .csrc import _C
2 | from . import utils
3 | from .utils.comm import to_device
4 | from .utils.logger import setup_logger
5 | from .utils.metric_logger import MetricLogger
6 | from .utils.miscellaneous import save_config
7 | from .wireframe import WireframeGraph
8 |
9 | __all__ = [
10 | "_C",
11 | "utils",
12 | "to_device",
13 | "setup_logger",
14 | "MetricLogger",
15 | "save_config",
16 | "WireframeGraph",
17 | ]
--------------------------------------------------------------------------------
/hawp/base/csrc/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.utils.cpp_extension import load
2 | import glob
3 | import os.path as osp
4 |
5 | __this__ = osp.dirname(__file__)
6 |
7 | try:
8 | _C = load(name='_C',sources=[
9 | osp.join(__this__,'binding.cpp'),
10 | osp.join(__this__,'linesegment.cu'),
11 | ]
12 | )
13 | except Exception:  # fall back to None if the extension fails to JIT-compile
14 | _C = None
15 |
16 | __all__ = ["_C"]
17 |
18 | #_C = load(name='base._C', sources=['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'])
19 |
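A hedged sketch of invoking the JIT-built extension; the argument order follows the `encodels` declaration in `linesegment.h`, while the `(N,4)` endpoint layout of `lines` and the availability of a CUDA device are assumptions:

import torch
from hawp.base.csrc import _C

if _C is not None and torch.cuda.is_available():
    lines = torch.rand(10, 4, device='cuda') * 512        # assumed (N,4) endpoints
    # encode onto a 128x128 target grid from a 512x512 input frame
    maps = _C.encodels(lines, 512, 512, 128, 128, lines.size(0))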
--------------------------------------------------------------------------------
/hawp/base/csrc/binding.cpp:
--------------------------------------------------------------------------------
1 | #include "linesegment.h"
2 |
3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4 | m.def("encodels", &encodels, "Encoding line segments to maps");
5 | }
--------------------------------------------------------------------------------
/hawp/base/csrc/linesegment.h:
--------------------------------------------------------------------------------
1 | // #pragma once
2 | #include <torch/extension.h>
3 |
4 | std::tuple<at::Tensor, at::Tensor> lsencode_cuda(
5 | const at::Tensor& lines,
6 | const int input_height,
7 | const int input_width,
8 | const int height,
9 | const int width,
10 | const int num_lines);
11 |
12 | std::tuple<at::Tensor, at::Tensor> encodels(
13 | const at::Tensor& lines,
14 | const int input_height,
15 | const int input_width,
16 | const int height,
17 | const int width,
18 | const int num_lines)
19 | {
20 | return lsencode_cuda(lines,
21 | input_height,
22 | input_width,
23 | height,
24 | width,
25 | num_lines);
26 | }
--------------------------------------------------------------------------------
/hawp/base/show/__init__.py:
--------------------------------------------------------------------------------
1 | from .canvas import Canvas, image_canvas, canvas
2 | from .painters import HAWPainter
3 | from .cli import cli, configure
--------------------------------------------------------------------------------
/hawp/base/show/cli.py:
--------------------------------------------------------------------------------
1 | # from hawp.config import defaults
2 | import logging
3 |
4 | from .canvas import Canvas
5 | from .painters import HAWPainter
6 | import matplotlib
7 | LOG = logging.getLogger(__name__)
8 |
9 | def cli(parser):
10 | group = parser.add_argument_group('show')
11 |
12 | assert not Canvas.show
13 | group.add_argument('--show', default=False,action='store_true',
14 | help='show every plot, i.e., call matplotlib show()')
15 |
16 | group.add_argument('--edge-threshold', default=None, type=float,
17 | help='show the wireframe edges whose confidences are greater than [edge_threshold]')
18 | group.add_argument('--out-ext', default='png', type=str,
19 | help='save the plot in specific format')
20 | def configure(args):
21 | Canvas.show = args.show
22 | Canvas.out_file_extension = args.out_ext
23 | if args.edge_threshold is not None:
24 | HAWPainter.confidence_threshold = args.edge_threshold
--------------------------------------------------------------------------------
/hawp/base/show/painters.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | try:
8 | import matplotlib
9 | import matplotlib.animation
10 | import matplotlib.collections
11 | import matplotlib.patches
12 | except ImportError:
13 | matplotlib = None
14 |
15 |
16 | LOG = logging.getLogger(__name__)
17 |
18 |
19 | class HAWPainter:
20 | line_width = None
21 | marker_size = None
22 | confidence_threshold = 0.05
23 |
24 | def __init__(self):
25 |
26 | if self.line_width is None:
27 | self.line_width = 2
28 |
29 | if self.marker_size is None:
30 | self.marker_size = max(1, int(self.line_width * 0.5))
31 |
32 | def draw_wireframe(self, ax, wireframe, *,
33 | edge_color = None, vertex_color = None):
34 | if wireframe is None:
35 | return
36 |
37 | if edge_color is None:
38 | edge_color = 'b'
39 | if vertex_color is None:
40 | vertex_color = 'c'
41 |
42 | line_segments = wireframe['lines_pred'][wireframe['lines_score']>self.confidence_threshold]
43 |
44 | if isinstance(line_segments, torch.Tensor):
45 | line_segments = line_segments.cpu().numpy()
46 |
47 | # line_segments = wireframe.line_segments(threshold=self.confidence_threshold)
48 | # line_segments = line_segments.cpu().numpy()
49 | ax.plot([line_segments[:,0],line_segments[:,2]],[line_segments[:,1],line_segments[:,3]],'-',color=edge_color)
50 | ax.plot(line_segments[:,0],line_segments[:,1],'.',color=vertex_color)
51 | ax.plot(line_segments[:,2],line_segments[:,3],'.',
52 | color=vertex_color)
53 |
--------------------------------------------------------------------------------
/hawp/base/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/hawp/base/utils/imports.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import torch
3 |
4 | import importlib
5 | import importlib.util
6 | import sys
7 |
8 | def import_file(module_name, file_path, make_importable=False):
9 | spec = importlib.util.spec_from_file_location(module_name, file_path)
10 | module = importlib.util.module_from_spec(spec)
11 | spec.loader.exec_module(module)
12 | if make_importable:
13 | sys.modules[module_name] = module
14 | return module
15 |
--------------------------------------------------------------------------------
/hawp/base/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import logging
3 | import os
4 | import sys
5 | from pythonjsonlogger import jsonlogger
6 |
7 |
8 | def setup_logger(name, save_dir, out_file='log.txt', json_format=False):
9 | logger = logging.getLogger(name)
10 | logger.setLevel(logging.DEBUG)
11 | ch = logging.StreamHandler(stream=sys.stdout)
12 | ch.setLevel(logging.DEBUG)
13 | if json_format:
14 | formatter = jsonlogger.JsonFormatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
15 | else:
16 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
17 | ch.setFormatter(formatter)
18 | logger.addHandler(ch)
19 |
20 | if save_dir:
21 | fh = logging.FileHandler(os.path.join(save_dir, out_file))
22 | fh.setLevel(logging.DEBUG)
23 | fh.setFormatter(formatter)
24 | logger.addHandler(fh)
25 |
26 | return logger
27 |
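Usage sketch (logger names and the output directory are illustrative); with a falsy `save_dir` the logger writes to stdout only, otherwise a file handler is added as well:

logger = setup_logger('hawp.demo', save_dir=None)   # stdout only
logger.info('starting evaluation')

# assumes outputs/dev already exists; log lines also go to outputs/dev/log.txt as JSON
json_logger = setup_logger('hawp.run', save_dir='outputs/dev', json_format=True)
json_logger.debug('written to console and file')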
--------------------------------------------------------------------------------
/hawp/base/utils/metric_evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def msTPFP(line_pred, line_gt, threshold):
4 | line_pred = line_pred.reshape(-1, 2, 2)[:, :, ::-1]
5 | line_gt = line_gt.reshape(-1, 2, 2)[:, :, ::-1]
6 | diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1)
7 | diff = np.minimum(
8 | diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
9 | )
10 |
11 | choice = np.argmin(diff, 1)
12 | dist = np.min(diff, 1)
13 |     hit = np.zeros(len(line_gt), dtype=bool)
14 |     tp = np.zeros(len(line_pred), dtype=np.float64)
15 |     fp = np.zeros(len(line_pred), dtype=np.float64)
16 | for i in range(len(line_pred)):
17 | if dist[i] < threshold and not hit[choice[i]]:
18 | hit[choice[i]] = True
19 | tp[i] = 1
20 | else:
21 | fp[i] = 1
22 | return tp, fp
23 |
24 |
25 | def TPFP(lines_dt, lines_gt, threshold):
26 | lines_dt = lines_dt.reshape(-1,2,2)[:,:,::-1]
27 | lines_gt = lines_gt.reshape(-1,2,2)[:,:,::-1]
28 | diff = ((lines_dt[:, None, :, None] - lines_gt[:, None]) ** 2).sum(-1)
29 | diff = np.minimum(
30 | diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
31 | )
32 |
33 | # diff1 = ((lines_dt[:, None, :2] - lines_gt[:, :2]) ** 2).sum(-1)
34 | # diff2 = ((lines_dt[:, None, 2:] - lines_gt[:, 2:]) ** 2).sum(-1)
35 | # diff3 = ((lines_dt[:, None, :2] - lines_gt[:, 2:]) ** 2).sum(-1)
36 | # diff4 = ((lines_dt[:, None, 2:] - lines_gt[:, :2]) ** 2).sum(-1)
37 | # import pdb
38 | # pdb.set_trace()
39 | # diff = np.minimum(diff1+diff2, diff3+diff4)
40 | choice = np.argmin(diff,1)
41 | dist = np.min(diff,1)
42 |     hit = np.zeros(len(lines_gt), dtype=bool)
43 |     tp = np.zeros(len(lines_dt), dtype=np.float64)
44 |     fp = np.zeros(len(lines_dt), dtype=np.float64)
45 |
46 | for i in range(lines_dt.shape[0]):
47 | if dist[i] < threshold and not hit[choice[i]]:
48 | hit[choice[i]] = True
49 | tp[i] = 1
50 | else:
51 | fp[i] = 1
52 | return tp, fp
53 |
54 | def AP(tp, fp):
55 | recall = tp
56 | precision = tp/np.maximum(tp+fp, 1e-9)
57 |
58 | recall = np.concatenate(([0.0], recall, [1.0]))
59 | precision = np.concatenate(([0.0], precision, [0.0]))
60 |
61 |
62 |
63 | for i in range(precision.size - 1, 0, -1):
64 | precision[i - 1] = max(precision[i - 1], precision[i])
65 | i = np.where(recall[1:] != recall[:-1])[0]
66 |
67 | ap = np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
68 |
69 | return ap
70 |
--------------------------------------------------------------------------------
/hawp/base/utils/metric_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from collections import defaultdict
3 | from collections import deque
4 |
5 | import torch
6 |
7 |
8 | class SmoothedValue(object):
9 | """Track a series of values and provide access to smoothed values over a
10 | window or the global series average.
11 | """
12 |
13 | def __init__(self, window_size=20):
14 | self.deque = deque(maxlen=window_size)
15 | self.series = []
16 | self.total = 0.0
17 | self.count = 0
18 |
19 | def update(self, value):
20 | self.deque.append(value)
21 | self.series.append(value)
22 | self.count += 1
23 | self.total += value
24 |
25 | @property
26 | def median(self):
27 | d = torch.tensor(list(self.deque))
28 | return d.median().item()
29 |
30 | @property
31 | def avg(self):
32 | d = torch.tensor(list(self.deque))
33 | return d.mean().item()
34 |
35 | @property
36 | def global_avg(self):
37 | return self.total / self.count
38 |
39 |
40 | class MetricLogger(object):
41 | def __init__(self, delimiter="\t"):
42 | self.meters = defaultdict(SmoothedValue)
43 | self.delimiter = delimiter
44 |
45 | def update(self, **kwargs):
46 | for k, v in kwargs.items():
47 | if isinstance(v, torch.Tensor):
48 | v = v.item()
49 | assert isinstance(v, (float, int))
50 | self.meters[k].update(v)
51 |
52 | def __getattr__(self, attr):
53 | if attr in self.meters:
54 | return self.meters[attr]
55 | if attr in self.__dict__:
56 | return self.__dict__[attr]
57 | raise AttributeError("'{}' object has no attribute '{}'".format(
58 | type(self).__name__, attr))
59 |
60 | def __str__(self):
61 | loss_str = []
62 | keys = sorted(self.meters)
63 | # for name, meter in self.meters.items():
64 | for name in keys:
65 | meter = self.meters[name]
66 | loss_str.append(
67 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
68 | )
69 | return self.delimiter.join(loss_str)
70 |
71 |     def tensorboard(self, iteration, writer, phase='train'):
72 |         for name, meter in self.meters.items():
73 |             if 'loss' in name:
74 |                 # writer.add_scalar('average/{}'.format(name), meter.avg, iteration)
75 |                 writer.add_scalar('{}/global/{}'.format(phase,name), meter.global_avg, iteration)
76 |                 # writer.add_scalar('median/{}'.format(name), meter.median, iteration)
77 |
78 |
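A small usage sketch of `MetricLogger`: values (floats, ints, or 0-d tensors) accumulate per key, and `str()` reports each meter as `median (global_avg)`:

import torch

meters = MetricLogger(delimiter='  ')
for step in range(100):
    meters.update(loss=torch.rand(()), lr=1e-3)   # tensors are converted via .item()
print(meters)            # e.g. "loss: 0.4812 (0.5023)  lr: 0.0010 (0.0010)"
print(meters.loss.avg)   # windowed average, resolved through __getattr__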
--------------------------------------------------------------------------------
/hawp/base/utils/miscellaneous.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import json
3 | import logging
4 | import os
5 |
6 |
7 | def mkdir(path):
8 | try:
9 | os.makedirs(path)
10 | except OSError as e:
11 | if e.errno != errno.EEXIST:
12 | raise
13 |
14 |
15 | def save_config(cfg, path):
16 | with open(path, 'w') as f:
17 | f.write(cfg.dump())
18 |
--------------------------------------------------------------------------------
/hawp/base/utils/model_zoo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import os
3 | import sys
4 |
5 | try:
6 | from torch.hub import _download_url_to_file
7 | from torch.hub import urlparse
8 | from torch.hub import HASH_REGEX
9 | except ImportError:
10 | from torch.utils.model_zoo import _download_url_to_file
11 | from torch.utils.model_zoo import urlparse
12 | from torch.utils.model_zoo import HASH_REGEX
13 |
14 | from .comm import is_main_process,synchronize
15 |
16 |
17 | # very similar to https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py
18 | # but with a few improvements and modifications
19 | def cache_url(url, model_dir=None, progress=True):
20 | r"""Loads the Torch serialized object at the given URL.
21 | If the object is already present in `model_dir`, it's deserialized and
22 | returned. The filename part of the URL should follow the naming convention
23 |     ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
24 | digits of the SHA256 hash of the contents of the file. The hash is used to
25 | ensure unique names and to verify the contents of the file.
26 | The default value of `model_dir` is ``$TORCH_HOME/models`` where
27 | ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
28 | overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
29 | Args:
30 | url (string): URL of the object to download
31 | model_dir (string, optional): directory in which to save the object
32 | progress (bool, optional): whether or not to display a progress bar to stderr
33 | Example:
34 | >>> cached_file = maskrcnn_benchmark.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
35 | """
36 | if model_dir is None:
37 | torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
38 | model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
39 | if not os.path.exists(model_dir):
40 | os.makedirs(model_dir)
41 | parts = urlparse(url)
42 | filename = os.path.basename(parts.path)
43 | if filename == "model_final.pkl":
44 | # workaround as pre-trained Caffe2 models from Detectron have all the same filename
45 | # so make the full path the filename by replacing / with _
46 | filename = parts.path.replace("/", "_")
47 | cached_file = os.path.join(model_dir, filename)
48 | if not os.path.exists(cached_file) and is_main_process():
49 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
50 | hash_prefix = HASH_REGEX.search(filename)
51 | if hash_prefix is not None:
52 | hash_prefix = hash_prefix.group(1)
53 | # workaround: Caffe2 models don't have a hash, but follow the R-50 convention,
54 | # which matches the hash PyTorch uses. So we skip the hash matching
55 | # if the hash_prefix is less than 6 characters
56 | if len(hash_prefix) < 6:
57 | hash_prefix = None
58 | _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
59 | synchronize()
60 | return cached_file
61 |
--------------------------------------------------------------------------------
/hawp/base/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 |
4 | def _register_generic(module_dict, module_name, module):
5 | assert module_name not in module_dict
6 | module_dict[module_name] = module
7 |
8 |
9 | class Registry(dict):
10 |     '''
11 |     A helper class for managing module registration; it extends a dictionary
12 |     and provides register functions.
13 |
14 |     E.g. creating a registry:
15 |         some_registry = Registry({"default": default_module})
16 |
17 |     There are two ways of registering new modules:
18 |     1): the normal way is just calling the register function:
19 |         def foo():
20 |             ...
21 |         some_registry.register("foo_module", foo)
22 |     2): used as a decorator when declaring the module:
23 |         @some_registry.register("foo_module")
24 |         @some_registry.register("foo_module_nickname")
25 |         def foo():
26 |             ...
27 |
28 |     Access to a module is just like using a dictionary, e.g.:
29 |         f = some_registry["foo_module"]
30 |     '''
31 | def __init__(self, *args, **kwargs):
32 | super(Registry, self).__init__(*args, **kwargs)
33 |
34 | def register(self, module_name, module=None):
35 | # used as function call
36 | if module is not None:
37 | _register_generic(self, module_name, module)
38 | return
39 |
40 | # used as decorator
41 | def register_fn(fn):
42 | _register_generic(self, module_name, fn)
43 | return fn
44 |
45 | return register_fn
46 |
--------------------------------------------------------------------------------
/hawp/base/wireframe.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import torch
5 | import json
6 |
7 | class WireframeGraph:
8 | def __init__(self,
9 | vertices: torch.Tensor,
10 | v_confidences: torch.Tensor,
11 | edges: torch.Tensor,
12 | edge_weights: torch.Tensor,
13 | frame_width: int,
14 | frame_height: int):
15 | self.vertices = vertices
16 | self.v_confidences = v_confidences
17 | self.edges = edges
18 | self.weights = edge_weights
19 | self.frame_width = frame_width
20 | self.frame_height = frame_height
21 |
22 | @classmethod
23 | def xyxy2indices(cls,junctions, lines):
24 | # junctions: (N,2)
25 | # lines: (M,4)
26 | # return: (M,2)
27 | dist1 = torch.norm(junctions[None,:,:]-lines[:,None,:2],dim=-1)
28 | dist2 = torch.norm(junctions[None,:,:]-lines[:,None,2:],dim=-1)
29 | idx1 = torch.argmin(dist1,dim=-1)
30 | idx2 = torch.argmin(dist2,dim=-1)
31 | return torch.stack((idx1,idx2),dim=-1)
32 | @classmethod
33 | def load_json(cls, fname):
34 | with open(fname,'r') as f:
35 | data = json.load(f)
36 |
37 |
38 | vertices = torch.tensor(data['vertices'])
39 | v_confidences = torch.tensor(data['vertices-score'])
40 | edges = torch.tensor(data['edges'])
41 | edge_weights = torch.tensor(data['edges-weights'])
42 | height = data['height']
43 | width = data['width']
44 |
45 | return WireframeGraph(vertices,v_confidences,edges,edge_weights,width,height)
46 |
47 | @property
48 | def is_empty(self):
49 | for key, val in self.__dict__.items():
50 | if val is None:
51 | return True
52 | return False
53 |
54 | @property
55 | def num_vertices(self):
56 | if self.is_empty:
57 | return 0
58 | return self.vertices.shape[0]
59 |
60 | @property
61 | def num_edges(self):
62 | if self.is_empty:
63 | return 0
64 | return self.edges.shape[0]
65 |
66 |
67 | def line_segments(self, threshold = 0.05, device=None, to_np=False):
68 | is_valid = self.weights>threshold
69 | p1 = self.vertices[self.edges[is_valid,0]]
70 | p2 = self.vertices[self.edges[is_valid,1]]
71 | ps = self.weights[is_valid]
72 |
73 | lines = torch.cat((p1,p2,ps[:,None]),dim=-1)
74 | if device is not None:
75 | lines = lines.to(device)
76 | if to_np:
77 | lines = lines.cpu().numpy()
78 |
79 | return lines
80 | # if device != self.device:
81 |
82 | def rescale(self, image_width, image_height):
83 | scale_x = float(image_width)/float(self.frame_width)
84 | scale_y = float(image_height)/float(self.frame_height)
85 |
86 | self.vertices[:,0] *= scale_x
87 | self.vertices[:,1] *= scale_y
88 | self.frame_width = image_width
89 | self.frame_height = image_height
90 |
91 | def jsonize(self):
92 | return {
93 | 'vertices': self.vertices.cpu().tolist(),
94 | 'vertices-score': self.v_confidences.cpu().tolist(),
95 | 'edges': self.edges.cpu().tolist(),
96 | 'edges-weights': self.weights.cpu().tolist(),
97 | 'height': self.frame_height,
98 | 'width': self.frame_width,
99 | }
100 | def __repr__(self) -> str:
101 | return "WireframeGraph\n"+\
102 | "Vertices: {}\n".format(self.num_vertices)+\
103 | "Edges: {}\n".format(self.num_edges,) + \
104 | "Frame size (HxW): {}x{}".format(self.frame_height,self.frame_width)
105 |
106 | #graph = WireframeGraph()
107 | if __name__ == "__main__":
108 | graph = WireframeGraph.load_json('NeuS/public_data/bmvs_clock/hawp/000.json')
109 | print(graph)
110 |
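A round-trip sketch for `WireframeGraph` ('some_wireframe.json' is a placeholder path):

import json

graph = WireframeGraph.load_json('some_wireframe.json')  # placeholder path
lines = graph.line_segments(threshold=0.5, to_np=True)   # (K,5): x1, y1, x2, y2, score
graph.rescale(1024, 768)                                 # scale vertices to a new frame size
with open('rescaled.json', 'w') as f:
    json.dump(graph.jsonize(), f)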
--------------------------------------------------------------------------------
/hawp/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .hafm import HAFMencoder
--------------------------------------------------------------------------------
/hawp/fsl/__init__.py:
--------------------------------------------------------------------------------
1 | from . import config
2 | from . import backbones
3 | from . import dataset
4 | from . import model
--------------------------------------------------------------------------------
/hawp/fsl/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_backbone
--------------------------------------------------------------------------------
/hawp/fsl/backbones/build.py:
--------------------------------------------------------------------------------
1 | from .registry import MODELS
2 | from .stacked_hg import HourglassNet, Bottleneck2D
3 | from .multi_task_head import MultitaskHead
4 | from .resnets import ResNets
5 | from .point_line import PointLineNet
6 | # from .point_line_new import PointLineNet
7 | from .stacked_point_line import StackPointLine
8 |
9 | @MODELS.register("Hourglass")
10 | def build_hg(cfg, **kwargs):
11 | inplanes = cfg.MODEL.HGNETS.INPLANES
12 | num_feats = cfg.MODEL.OUT_FEATURE_CHANNELS//2
13 | depth = cfg.MODEL.HGNETS.DEPTH
14 | num_stacks = cfg.MODEL.HGNETS.NUM_STACKS
15 | num_blocks = cfg.MODEL.HGNETS.NUM_BLOCKS
16 | head_size = cfg.MODEL.HEAD_SIZE
17 |
18 | out_feature_channels = cfg.MODEL.OUT_FEATURE_CHANNELS
19 |
20 | if kwargs.get('gray_scale',False):
21 | input_channels = 1
22 | else:
23 | input_channels = 3
24 | num_class = sum(sum(head_size, []))
25 | model = HourglassNet(
26 | input_channels=input_channels,
27 | block=Bottleneck2D,
28 | inplanes = inplanes,
29 | num_feats= num_feats,
30 | depth=depth,
31 | head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),
32 | num_stacks = num_stacks,
33 | num_blocks = num_blocks,
34 | num_classes = num_class)
35 |
36 | model.out_feature_channels = out_feature_channels
37 |
38 | return model
39 |
40 |
41 |
42 | @MODELS.register("PointLine")
43 | def build_pointline(cfg, **kwargs):
44 | inplanes = cfg.MODEL.HGNETS.INPLANES
45 | num_feats = cfg.MODEL.OUT_FEATURE_CHANNELS//2
46 | depth = cfg.MODEL.HGNETS.DEPTH
47 | num_stacks = cfg.MODEL.HGNETS.NUM_STACKS
48 | num_blocks = cfg.MODEL.HGNETS.NUM_BLOCKS
49 | head_size = cfg.MODEL.HEAD_SIZE
50 |
51 | out_feature_channels = cfg.MODEL.OUT_FEATURE_CHANNELS
52 |
53 | if kwargs.get('gray_scale',False):
54 | input_channels = 1
55 | else:
56 | input_channels = 3
57 | num_class = sum(sum(head_size, []))
58 | model = PointLineNet(
59 | head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size))
60 |
61 | model.out_feature_channels = out_feature_channels
62 |
63 | return model
64 |
65 |
66 |
67 | @MODELS.register("StackPointLine")
68 | def build_stack_pointline(cfg, **kwargs):
69 | inplanes = cfg.MODEL.HGNETS.INPLANES
70 | num_feats = cfg.MODEL.OUT_FEATURE_CHANNELS//2
71 | depth = cfg.MODEL.HGNETS.DEPTH
72 | num_stacks = cfg.MODEL.HGNETS.NUM_STACKS
73 | num_blocks = cfg.MODEL.HGNETS.NUM_BLOCKS
74 | head_size = cfg.MODEL.HEAD_SIZE
75 |
76 | out_feature_channels = cfg.MODEL.OUT_FEATURE_CHANNELS
77 |
78 | if kwargs.get('gray_scale',False):
79 | input_channels = 1
80 | else:
81 | input_channels = 3
82 | num_class = sum(sum(head_size, []))
83 | model = StackPointLine(
84 | input_channels=input_channels,
85 | block=Bottleneck2D,
86 | inplanes = inplanes,
87 | num_feats= num_feats,
88 | depth=depth,
89 | head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),
90 | num_stacks = num_stacks,
91 | num_blocks = num_blocks,
92 | num_classes = num_class)
93 |
94 | model.out_feature_channels = out_feature_channels
95 |
96 | return model
97 |
98 |
99 | # @MODELS.register("ResNets")
100 | # def build_resnet(cfg):
101 | # head_size = cfg.MODEL.HEAD_SIZE
102 |
103 | # num_class = sum(sum(head_size,[]))
104 | # model = ResNets(cfg.MODEL.RESNETS.BASENET,head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),num_class=num_class,pretrain=cfg.MODEL.RESNETS.PRETRAIN)
105 |
106 | # model.out_feature_channels = 128
107 | # return model
108 |
109 | def build_backbone(cfg, **kwargs):
110 |     assert cfg.MODEL.NAME in MODELS, \
111 |         "cfg.MODEL.NAME: {} is not registered in registry".format(cfg.MODEL.NAME)
112 |
113 | return MODELS[cfg.MODEL.NAME](cfg, **kwargs)
114 |
--------------------------------------------------------------------------------
/hawp/fsl/backbones/multi_task_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | class MultitaskHead(nn.Module):
4 | def __init__(self, input_channels, num_class, head_size):
5 | super(MultitaskHead, self).__init__()
6 |
7 | m = int(input_channels / 4)
8 | heads = []
9 | for output_channels in sum(head_size, []):
10 | heads.append(
11 | nn.Sequential(
12 | nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
13 | nn.ReLU(inplace=True),
14 | nn.Conv2d(m, output_channels, kernel_size=1),
15 | )
16 | )
17 | self.heads = nn.ModuleList(heads)
18 | assert num_class == sum(sum(head_size, []))
19 |
20 | def forward(self, x):
21 | return torch.cat([head(x) for head in self.heads], dim=1)
22 |
23 |
24 | class AngleDistanceHead(nn.Module):
25 | def __init__(self, input_channels, num_class, head_size):
26 | super(AngleDistanceHead, self).__init__()
27 |
28 | m = int(input_channels/4)
29 |
30 | heads = []
31 | for output_channels in sum(head_size, []):
32 | if output_channels != 2:
33 | heads.append(
34 | nn.Sequential(
35 | nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(m, output_channels, kernel_size=1),
38 | )
39 | )
40 | else:
41 | heads.append(
42 | nn.Sequential(
43 | nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
44 | nn.ReLU(inplace=True),
45 | CosineSineLayer(m)
46 | )
47 | )
48 | self.heads = nn.ModuleList(heads)
49 | assert num_class == sum(sum(head_size, []))
50 | def forward(self, x):
51 | return torch.cat([head(x) for head in self.heads], dim=1)
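A shape sketch for `MultitaskHead` using the `MODELS.HEAD_SIZE` value from the configs in this repo ([[3], [1], [1], [2], [2]], i.e. nine output channels in total):

import torch

head_size = [[3], [1], [1], [2], [2]]
head = MultitaskHead(input_channels=256, num_class=9, head_size=head_size)
out = head(torch.randn(2, 256, 128, 128))
print(out.shape)   # torch.Size([2, 9, 128, 128]): per-task outputs concatenated on dim 1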
--------------------------------------------------------------------------------
/hawp/fsl/backbones/registry.py:
--------------------------------------------------------------------------------
1 | from hawp.base.utils.registry import Registry
2 |
3 | MODELS = Registry()
--------------------------------------------------------------------------------
/hawp/fsl/backbones/resnets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 | class ResNets(nn.Module):
6 | RESNET_TEMPLATES = {
7 | 'resnet18':torchvision.models.resnet18,
8 | 'resnet34':torchvision.models.resnet34,
9 | 'resnet50':torchvision.models.resnet50,
10 | 'resnet101':torchvision.models.resnet101,
11 | }
12 | def __init__(self,basenet, head, num_class, pretrain=True,):
13 | super(ResNets, self).__init__()
14 | assert basenet in ResNets.RESNET_TEMPLATES
15 |
16 | basenet_fn = ResNets.RESNET_TEMPLATES.get(basenet)
17 |
18 | model = basenet_fn(pretrain)
19 |
20 | self.conv1 = nn.Conv2d(3,64,7,2,3)
21 | self.bn1 = nn.BatchNorm2d(64)
22 | self.relu = nn.ReLU(True)
23 | self.maxpool = nn.MaxPool2d(3,2,1)
24 |
25 | self.layer1 = model.layer1
26 | self.layer2 = model.layer2
27 | self.layer3 = model.layer3
28 | self.layer4 = model.layer4
29 |
30 | self.pixel_shuffle = nn.PixelShuffle(4)
31 | self.hafm_predictor = head(128,num_class)
32 | # self.hafm_predictor = nn.Sequential(nn.Conv2d(2048,512,3,1,1),nn.ReLU(True),nn.Conv2d(512,5,1,1,0))
33 | def forward(self, images):
34 | x = self.conv1(images)
35 | x = self.relu(self.bn1(x))
36 | x = self.maxpool(x)
37 |
38 | x = self.layer1(x)
39 | x = self.layer2(x)
40 | x = self.layer3(x)
41 | x = self.layer4(x)
42 |
43 | x = self.pixel_shuffle(x)
44 |
45 | return [self.hafm_predictor(x)], x
46 |
47 | if __name__ == "__main__":
48 |     model = ResNets('resnet50', head=lambda c_in, c_out: nn.Conv2d(c_in, c_out, 1), num_class=5)  # stand-in head so the smoke test runs
49 |
50 | inp = torch.zeros((1,3,512,512))
51 |
52 | model(inp)
53 |
54 |
55 |
--------------------------------------------------------------------------------
/hawp/fsl/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .defaults import cfg
2 |
3 | __all__ = ['cfg']
--------------------------------------------------------------------------------
/hawp/fsl/config/dataset.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 | # ---------------------------------------------------------------------------- #
3 | # Dataset options
4 | # ---------------------------------------------------------------------------- #
5 | DATASETS = CN()
6 | DATASETS.TRAIN = ("wireframe_train",)
7 | DATASETS.VAL = ("wireframe_test",)
8 | DATASETS.TEST = ("wireframe_test",)
9 | DATASETS.IMAGE = CN()
10 | DATASETS.IMAGE.HEIGHT = 512
11 | DATASETS.IMAGE.WIDTH = 512
12 |
13 | DATASETS.IMAGE.PIXEL_MEAN = [109.730, 103.832, 98.681]
14 | DATASETS.IMAGE.PIXEL_STD = [22.275, 22.124, 23.229]
15 | DATASETS.IMAGE.TO_255 = True
16 | DATASETS.TARGET = CN()
17 | DATASETS.TARGET.HEIGHT= 128
18 | DATASETS.TARGET.WIDTH = 128
19 | DATASETS.AUGMENTATION = 4
20 | DATASETS.DISTANCE_TH = 0.02
21 | DATASETS.NUM_STATIC_POSITIVE_LINES = 300
22 | DATASETS.NUM_STATIC_NEGATIVE_LINES = 40
23 |
24 | #
25 |
--------------------------------------------------------------------------------
/hawp/fsl/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 | from .models import MODELS
3 | from .dataset import DATASETS
4 | from .solver import SOLVER
5 | from .detr import DETR
6 | cfg = CN()
7 |
8 | cfg.ENCODER = CN()
9 | cfg.ENCODER.DIS_TH = 5
10 | cfg.ENCODER.ANG_TH = 0.1
11 | cfg.ENCODER.NUM_STATIC_POS_LINES = 300
12 | cfg.ENCODER.NUM_STATIC_NEG_LINES = 40
13 | cfg.ENCODER.BACKGROUND_WEIGHT = 0.0
14 | cfg.MODELING_PATH = 'hawp'
15 | cfg.MODEL = MODELS
16 | cfg.DATASETS = DATASETS
17 | cfg.SOLVER = SOLVER
18 |
19 | cfg.DATALOADER = CN()
20 | cfg.DATALOADER.NUM_WORKERS = 8
21 | cfg.OUTPUT_DIR = "outputs/dev"
22 |
--------------------------------------------------------------------------------
/hawp/fsl/config/detr.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 |
4 | DETR = CN()
5 |
6 | DETR.backbone = 'resnet50'
7 | DETR.dilation = False #dilated conv, DC5 for DETR
8 | DETR.position_embedding = 'sine'
9 | DETR.lr_backbone = 1e-5
10 | DETR.enc_layers = 6
11 | DETR.dec_layers = 6
12 | DETR.dim_feedforward = 2048
13 | DETR.hidden_dim = 256
14 | DETR.dropout = 0.1
15 | DETR.nheads = 8
16 | DETR.num_queries = 1000
17 | DETR.pre_norm = False
18 | DETR.eos_coef = 0.1 #"Relative classification weight of the no-object class"
19 |
20 | DETR.no_aux_loss = False
21 | DETR.set_cost_class = 1.0
22 | DETR.set_cost_lines = 5.0
23 |
24 |
25 |
--------------------------------------------------------------------------------
/hawp/fsl/config/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import MODELS
2 |
3 | __all__ = ['MODELS']
--------------------------------------------------------------------------------
/hawp/fsl/config/models/head.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | PARSING_HEAD = CN()
4 |
5 | PARSING_HEAD.MAX_DISTANCE = 5.0
6 |
7 | PARSING_HEAD.N_STC_POSL = 300
8 | PARSING_HEAD.N_STC_NEGL = 40
9 |
10 | PARSING_HEAD.MATCHING_STRATEGY = 'junction' #junction or line_adjusted
11 | PARSING_HEAD.N_DYN_JUNC = 300
12 | PARSING_HEAD.N_DYN_POSL = 300
13 | PARSING_HEAD.N_DYN_NEGL = 300
14 | PARSING_HEAD.N_DYN_OTHR = 0
15 | PARSING_HEAD.N_DYN_OTHR2 = 300
16 |
17 | PARSING_HEAD.N_PTS0 = 32
18 | PARSING_HEAD.N_PTS1 = 8
19 |
20 | PARSING_HEAD.DIM_LOI = 128
21 | PARSING_HEAD.DIM_FC = 1024
22 | PARSING_HEAD.USE_RESIDUAL = 1
23 | PARSING_HEAD.N_OUT_JUNC = 250
24 | PARSING_HEAD.N_OUT_LINE = 2500
25 | PARSING_HEAD.JMATCH_THRESHOLD = 1.5
26 | PARSING_HEAD.J2L_THRESHOLD = 1000.0
27 | PARSING_HEAD.JUNCTION_HM_THRESHOLD = 0.008
28 | #INFERENCE FLAGS
29 | #0, only use junctions to yield line segments
30 | #1, only use learned angles to yield line segments
31 | #2, match line segment proposals with junctions
32 | # PARSING_HEAD.INFERENCE = 0
--------------------------------------------------------------------------------
/hawp/fsl/config/models/models.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 | from .shg import HGNETS
3 | from .resnets import RESNETS
4 | from .head import PARSING_HEAD
5 |
6 | MODELS = CN()
7 |
8 | MODELS.NAME = "Hourglass"
9 | MODELS.HGNETS = HGNETS
10 | MODELS.RESNETS = RESNETS
11 | MODELS.DEVICE = "cuda"
12 | MODELS.WEIGHTS = ""
13 | MODELS.HEAD_SIZE = [[3], [1], [1], [2], [2]]
14 | MODELS.OUT_FEATURE_CHANNELS = 256
15 |
16 | MODELS.LOSS_WEIGHTS = CN(new_allowed=True)
17 | MODELS.PARSING_HEAD = PARSING_HEAD
18 | MODELS.SCALE = 1.0
19 |
20 | MODELS.USE_LINE_HEATMAP = False
21 | MODELS.USE_HR_JUNCTION = False
22 |
23 | MODELS.LOI_POOLING = CN()
24 | MODELS.LOI_POOLING.USE_INIT_LINES = True
25 | MODELS.LOI_POOLING.NUM_POINTS = 32
26 | MODELS.LOI_POOLING.DIM_EDGE_FEATURE = 16
27 | MODELS.LOI_POOLING.DIM_JUNCTION_FEATURE = 128
28 | MODELS.LOI_POOLING.DIM_FC = 1024
29 | MODELS.LOI_POOLING.TYPE = 'softmax'
30 | MODELS.LOI_POOLING.LAYER_NORM = False
31 | MODELS.LOI_POOLING.ACTIVATION = 'relu'
32 |
33 | MODELS.FOCAL_LOSS = CN()
34 | MODELS.FOCAL_LOSS.ALPHA = -1.0
35 | MODELS.FOCAL_LOSS.GAMMA = 0.0
--------------------------------------------------------------------------------
/hawp/fsl/config/models/proposal_head.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | PROPOSAL_HEAD = CN()
4 |
5 | PROPOSAL_HEAD.NUM_PROPOSALS = 512
6 | PROPOSAL_HEAD.ANGULAR_THRESHOLD = 0.01 #TODO: REMOVE
7 | PROPOSAL_HEAD.NUM_SAMPLE_POINTS = 32
8 | PROPOSAL_HEAD.POSITIVE_DIS_TH = 10.0
9 | PROPOSAL_HEAD.NEGATIVE_DIS_TH = 15.0
10 | PROPOSAL_HEAD.LOWEST_SCORE_TH = 0.05
11 | PROPOSAL_HEAD.POST_NMS_TH = 10.0
12 | PROPOSAL_HEAD.USE_1DPOOLING = False
13 | PROPOSAL_HEAD.MIN_DISTANCE = 0.0
14 | PROPOSAL_HEAD.MAX_DISTANCE = 1.5
15 | PROPOSAL_HEAD.NUM_DIS_PROPOSAL = 9
16 | PROPOSAL_HEAD.USE_EDGE = False
17 | PROPOSAL_HEAD.SHARE_WITH_JUNCTION_FEATURE = False
18 |
19 |
20 | PROPOSAL_HEAD.NUM_DYNAMIC_JUNCTIONS = 300
21 | PROPOSAL_HEAD.NUM_DYNAMIC_POSITIVE_LINES = 300
22 | PROPOSAL_HEAD.NUM_DYNAMIC_NEGATIVE_LINES = 80
23 | PROPOSAL_HEAD.NUM_DYNAMIC_OTHER_LINES = 600
24 |
25 | # PROPOSAL_HEAD.LOSS_WEIGHTS = CN()
26 | # PROPOSAL_HEAD.LOSS_WEIGHTS.AFM = 1.0
27 | # PROPOSAL_HEAD.LOSS_WEIGHTS.PROPOSAL = 1.0
28 |
--------------------------------------------------------------------------------
/hawp/fsl/config/models/resnets.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | RESNETS = CN()
4 | RESNETS.BASENET = 'resnet50'
5 | RESNETS.PRETRAIN = True
6 |
7 | # RESNETS.NUM_GROUPS = 1
8 | # RESNETS.WIDTH_PER_GROUP = 64
9 | # #RESNETS.STRIDE_IN_1X1 = True
10 |
11 | # # RESNETS.TRANS_FUNC = "BottleneckWithFixedBatchNorm"
12 | # RESNETS.NORM_FUNC = 'BatchNorm'
13 | # RESNETS.STEM_FUNC = "Stem"
14 | # RESNETS.STEM_OUT_CHANNELS = 64
15 | # RESNETS.STEM_KERNEL_SIZE = 7
16 | # RESNETS.STEM_RETURN_FEATURE = False
17 | # RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4
18 | # RESNETS.RES2_OUT_CHANNELS = 256
19 |
20 |
21 | # RESNETS.STAGE_SPECS = (3,4,6,3) # ResNet 50
22 |
23 |
--------------------------------------------------------------------------------
/hawp/fsl/config/models/shg.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | HGNETS = CN()
4 |
5 | HGNETS.DEPTH = 4
6 | HGNETS.NUM_STACKS = 2
7 | HGNETS.NUM_BLOCKS = 1
8 |
9 | HGNETS.INPLANES = 64
10 | HGNETS.NUM_FEATS = 128
11 |
--------------------------------------------------------------------------------
/hawp/fsl/config/paths_catalog.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | class DatasetCatalog(object):
5 |
6 | DATA_DIR = osp.abspath(osp.join(osp.dirname(__file__),
7 | '..','..','..','data'))
8 |
9 | DATASETS = {
10 | 'wireframe_train': {
11 | 'img_dir': 'wireframe/images',
12 | 'ann_file': 'wireframe/train.json',
13 | },
14 | 'wireframe_train-pseudo': {
15 | 'img_dir': 'wireframe-pseudo/images',
16 | 'ann_file': 'wireframe-pseudo/train.json',
17 | },
18 | 'wireframe_train-syn-export': {
19 | 'img_dir': 'wireframe-syn-export/images',
20 | 'ann_file': 'wireframe-syn-export/train.json',
21 | },
22 | 'wireframe_train-syn-export-1': {
23 | 'img_dir': 'wireframe-syn-export-ep30-iter100-th075/images',
24 | 'ann_file': 'wireframe-syn-export-ep30-iter100-th075/train.json',
25 | },
26 | 'wireframe_test1': {
27 | 'img_dir': 'wireframe/images',
28 | 'ann_file': 'wireframe/overfit.json',
29 | },
30 | 'synthetic_train': {
31 | 'img_dir': 'synthetic-shapes/images',
32 | 'ann_file': 'synthetic-shapes/train.json',
33 | },
34 | 'synthetic_test': {
35 | 'img_dir': 'synthetic-shapes/images',
36 | 'ann_file': 'synthetic-shapes/test.json',
37 | },
38 | 'cities_train': {
39 | 'img_dir': 'cities/images',
40 | 'ann_file': 'cities/train.json',
41 | },
42 | 'cities_test': {
43 | 'img_dir': 'cities/images',
44 | 'ann_file': 'cities/test.json',
45 | },
46 | 'wireframe_test': {
47 | 'img_dir': 'wireframe/images',
48 | 'ann_file': 'wireframe/test.json',
49 | },
50 | 'york_test': {
51 | 'img_dir': 'york/images',
52 | 'ann_file': 'york/test.json',
53 | },
54 | 'coco_train-val2017': {
55 | 'img_dir': 'coco/val2017',
56 | 'ann_file': 'coco/coco-wf-val.json',
57 | },
58 | 'coco_test-val2017': {
59 | 'img_dir': 'coco/val2017',
60 | 'ann_file': 'coco/coco-wf-val.json',
61 | }
62 | }
63 |
64 | @staticmethod
65 | def get(name):
66 | assert name in DatasetCatalog.DATASETS
67 | data_dir = DatasetCatalog.DATA_DIR
68 | attrs = DatasetCatalog.DATASETS[name]
69 |
70 | args = dict(
71 | root = osp.join(data_dir,attrs['img_dir']),
72 | ann_file = osp.join(data_dir,attrs['ann_file'])
73 | )
74 |
75 | if 'train' in name:
76 | return dict(factory="TrainDataset",args=args)
77 | if 'test' in name and 'ann_file' in attrs:
78 | return dict(factory="TestDatasetWithAnnotations",
79 | args=args)
80 | raise NotImplementedError()
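A sketch of how the catalog is consumed (mirroring `build_train_dataset`/`build_test_dataset` in `hawp/fsl/dataset/build.py`): the returned dict names a factory class plus absolute paths:

spec = DatasetCatalog.get('wireframe_test')
print(spec['factory'])            # 'TestDatasetWithAnnotations'
print(spec['args']['root'])       # <DATA_DIR>/wireframe/images
print(spec['args']['ann_file'])   # <DATA_DIR>/wireframe/test.json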
--------------------------------------------------------------------------------
/hawp/fsl/config/solver.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 |
4 | SOLVER = CN()
5 | SOLVER.IMS_PER_BATCH = 6
6 | SOLVER.MAX_EPOCH = 30
7 | SOLVER.OPTIMIZER = "ADAM"
8 | SOLVER.BASE_LR = 0.01
9 | SOLVER.BACKBONE_LR_FACTOR=1.0
10 | SOLVER.BIAS_LR_FACTOR = 1
11 |
12 | SOLVER.MOMENTUM = 0.9
13 | SOLVER.WEIGHT_DECAY = 0.0002
14 | SOLVER.WEIGHT_DECAY_BIAS = 0
15 | SOLVER.GAMMA = 0.1
16 |
17 | SOLVER.STEPS = (25,)
18 | SOLVER.CHECKPOINT_PERIOD = 1
19 | SOLVER.AMSGRAD = False
20 |
--------------------------------------------------------------------------------
/hawp/fsl/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_dataset import TrainDataset
2 | from . import transforms
3 | from .build import build_train_dataset, build_test_dataset
4 | from .test_dataset import TestDatasetWithAnnotations
--------------------------------------------------------------------------------
/hawp/fsl/dataset/build.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .transforms import *
3 | from . import train_dataset
4 | from ..config.paths_catalog import DatasetCatalog
5 | from . import test_dataset
6 |
7 | def build_transform(cfg):
8 | transforms = Compose(
9 | [ResizeImage(cfg.DATASETS.IMAGE.HEIGHT,
10 | cfg.DATASETS.IMAGE.WIDTH),
11 | ToTensor(),
12 | Normalize(cfg.DATASETS.IMAGE.PIXEL_MEAN,
13 | cfg.DATASETS.IMAGE.PIXEL_STD,
14 | cfg.DATASETS.IMAGE.TO_255)
15 | ]
16 | )
17 |
18 | if cfg.MODEL.NAME == "PointLine":
19 | transforms = Compose(
20 | [ResizeImage(cfg.DATASETS.IMAGE.HEIGHT,
21 | cfg.DATASETS.IMAGE.WIDTH),
22 | ToTensor()
23 | ]
24 | )
25 |
26 | return transforms
27 | def build_train_dataset(cfg):
28 | assert len(cfg.DATASETS.TRAIN) == 1
29 | name = cfg.DATASETS.TRAIN[0]
30 | dargs = DatasetCatalog.get(name)
31 |
32 | factory = getattr(train_dataset,dargs['factory'])
33 | args = dargs['args']
34 | args['augmentation'] = cfg.DATASETS.AUGMENTATION
35 | args['transform'] = Compose(
36 | [Resize(cfg.DATASETS.IMAGE.HEIGHT,
37 | cfg.DATASETS.IMAGE.WIDTH,
38 | cfg.DATASETS.TARGET.HEIGHT,
39 | cfg.DATASETS.TARGET.WIDTH),
40 | ToTensor(),
41 | Normalize(cfg.DATASETS.IMAGE.PIXEL_MEAN,
42 | cfg.DATASETS.IMAGE.PIXEL_STD,
43 | cfg.DATASETS.IMAGE.TO_255)])
44 |
45 | if cfg.MODEL.NAME == "PointLine":
46 | args['transform'] = Compose(
47 | [Resize(cfg.DATASETS.IMAGE.HEIGHT,
48 | cfg.DATASETS.IMAGE.WIDTH,
49 | cfg.DATASETS.TARGET.HEIGHT,
50 | cfg.DATASETS.TARGET.WIDTH),
51 | ToTensor()])
52 |
53 |
54 | dataset = factory(**args)
55 |
56 | dataset = torch.utils.data.DataLoader(dataset,
57 | batch_size=cfg.SOLVER.IMS_PER_BATCH,
58 | collate_fn=train_dataset.collate_fn,
59 | shuffle = True,
60 | num_workers = cfg.DATALOADER.NUM_WORKERS)
61 | return dataset
62 |
63 | def build_test_dataset(cfg):
64 | transforms = Compose(
65 | [ResizeImage(cfg.DATASETS.IMAGE.HEIGHT,
66 | cfg.DATASETS.IMAGE.WIDTH),
67 | ToTensor(),
68 | Normalize(cfg.DATASETS.IMAGE.PIXEL_MEAN,
69 | cfg.DATASETS.IMAGE.PIXEL_STD,
70 | cfg.DATASETS.IMAGE.TO_255)
71 | ]
72 | )
73 |
74 | if cfg.MODEL.NAME == "PointLine":
75 | transforms = Compose(
76 | [ResizeImage(cfg.DATASETS.IMAGE.HEIGHT,
77 | cfg.DATASETS.IMAGE.WIDTH),
78 | ToTensor()
79 | ]
80 | )
81 |
82 | datasets = []
83 | for name in cfg.DATASETS.TEST:
84 | dargs = DatasetCatalog.get(name)
85 | factory = getattr(test_dataset,dargs['factory'])
86 | args = dargs['args']
87 | args['transform'] = transforms
88 | dataset = factory(**args)
89 | dataset = torch.utils.data.DataLoader(
90 | dataset, batch_size = 1,
91 | collate_fn = dataset.collate_fn,
92 | num_workers = cfg.DATALOADER.NUM_WORKERS,
93 | )
94 | datasets.append((name,dataset))
95 | return datasets
96 |
--------------------------------------------------------------------------------
/hawp/fsl/dataset/imagelist.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | try:
9 | import cv2 # pylint: disable=import-error
10 | except ImportError:
11 | cv2 = None
12 |
13 | import PIL
14 | try:
15 | import PIL.ImageGrab
16 | except ImportError:
17 | pass
18 |
19 | try:
20 | import mss
21 | except ImportError:
22 | mss = None
23 |
24 | import os
25 | import os.path as osp
26 | LOG = logging.getLogger(__name__)
27 |
28 | class ToTensor(object):
29 | def __call__(self,image):
30 | tensor = torch.from_numpy(image).float()/255.0
31 | tensor = tensor.permute((2,0,1)).contiguous()
32 | return tensor
33 |
34 | class Normalize(object):
35 | def __init__(self, mean, std):
36 | self.mean = mean
37 | self.std = std
38 | def __call__(self, image):
39 | image[0] = (image[0]-self.mean[0])/self.std[0]
40 | image[1] = (image[1]-self.mean[1])/self.std[1]
41 | image[2] = (image[2]-self.mean[2])/self.std[2]
42 | return image
43 |
44 | # pylint: disable=abstract-method
45 | class ImageList(torch.utils.data.Dataset):
46 | horizontal_flip = None
47 | rotate = None
48 | crop = None
49 | scale = 1.0
50 | start_frame = None
51 | start_msec = None
52 | max_frames = None
53 |
54 | def __init__(self, source, *,
55 | input_size,
56 | transforms,
57 | ):
58 | super().__init__()
59 |
60 | self.source = source
61 | self.input_size = input_size
62 | self.transforms = transforms
63 |
64 | self.filenames = sorted(os.listdir(source))
65 |
66 | def __len__(self):
67 | return len(self.filenames)
68 |
69 | # pylint: disable=unsubscriptable-object
70 | def preprocessing(self, image):
71 | if self.scale != 1.0:
72 | image = cv2.resize(image, None, fx=self.scale, fy=self.scale)
73 | LOG.debug('resized image size: %s', image.shape)
74 | if self.horizontal_flip:
75 | image = image[:, ::-1]
76 | if self.crop:
77 | if self.crop[0]:
78 | image = image[:, self.crop[0]:]
79 | if self.crop[1]:
80 | image = image[self.crop[1]:, :]
81 | if self.crop[2]:
82 | image = image[:, :-self.crop[2]]
83 | if self.crop[3]:
84 | image = image[:-self.crop[3], :]
85 | if self.rotate == 'left':
86 | image = np.swapaxes(image, 0, 1)
87 | image = np.flip(image, axis=0)
88 | elif self.rotate == 'right':
89 | image = np.swapaxes(image, 0, 1)
90 | image = np.flip(image, axis=1)
91 | elif self.rotate == '180':
92 | image = np.flip(image, axis=0)
93 | image = np.flip(image, axis=1)
94 |
95 | meta = {
96 | 'width': image.shape[1],
97 | 'height': image.shape[0],
98 | }
99 |
100 | processed_image = self.transforms(image)
101 |
102 | return image, processed_image, meta
103 |
104 |     def __getitem__(self, index):
105 |         fname = osp.join(self.source, self.filenames[index])
106 | 
107 |         image = cv2.imread(fname)
108 |         meta = {
109 |             'width': image.shape[1],
110 |             'height': image.shape[0],
111 |         }
112 |         meta['frame_i'] = index
113 |         meta['filename'] = '{:05d}'.format(index)
114 |         processed_image = self.transforms(image)
115 |         return image, processed_image, meta
116 |
--------------------------------------------------------------------------------
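A small usage sketch for `ImageList` with the transforms defined in this file (not a repository file; the directory path and the mean/std values are placeholders, and `Compose` is assumed from torchvision):

    from torchvision.transforms import Compose

    transforms = Compose([ToTensor(),                    # the ToTensor/Normalize defined above, not torchvision's
                          Normalize(mean=[0.43, 0.44, 0.45],
                                    std=[0.23, 0.22, 0.22])])
    dataset = ImageList('/path/to/images', input_size=512, transforms=transforms)
    image, processed, meta = dataset[0]                  # raw HxWx3 array, (3,H,W) tensor, size metadata
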
/hawp/fsl/dataset/iteration_based_batch_sampler.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.sampler import BatchSampler
2 |
3 |
4 | class IterationBasedBatchSampler(BatchSampler):
5 | """
6 | Wraps a BatchSampler, resampling from it until
7 | a specified number of iterations have been sampled
8 | """
9 |
10 | def __init__(self, batch_sampler, num_iterations, start_iter=0):
11 | self.batch_sampler = batch_sampler
12 | self.num_iterations = num_iterations
13 | self.start_iter = start_iter
14 |
15 | def __iter__(self):
16 | iteration = self.start_iter
17 | while iteration <= self.num_iterations:
18 | # if the underlying sampler has a set_epoch method, like
19 | # DistributedSampler, used for making each process see
20 | # a different split of the dataset, then set it
21 | if hasattr(self.batch_sampler.sampler, "set_epoch"):
22 | self.batch_sampler.sampler.set_epoch(iteration)
23 | for batch in self.batch_sampler:
24 | iteration += 1
25 | if iteration > self.num_iterations:
26 | break
27 | yield batch
28 |
29 | def __len__(self):
30 | return self.num_iterations
--------------------------------------------------------------------------------
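A self-contained sketch of the sampler in use; it yields exactly `num_iterations` batches regardless of the dataset length:

    import torch
    from torch.utils.data.sampler import BatchSampler, RandomSampler

    data = list(range(10))                               # stand-in dataset
    base = BatchSampler(RandomSampler(data), batch_size=4, drop_last=True)
    sampler = IterationBasedBatchSampler(base, num_iterations=100)
    loader = torch.utils.data.DataLoader(data, batch_sampler=sampler)
    assert len(sampler) == 100 and sum(1 for _ in loader) == 100
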
/hawp/fsl/dataset/test_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torch.utils.data.dataloader import default_collate
4 | import json
5 | import copy
6 | from PIL import Image
7 | from skimage import io
8 | import os
9 | import os.path as osp
10 | import numpy as np
11 | import cv2
12 | class TestDatasetWithAnnotations(Dataset):
13 | '''
14 | Format of the annotation file
15 | annotations[i] has the following dict items:
16 | - filename # of the input image, str
17 | - height # of the input image, int
18 | - width # of the input image, int
19 | - lines # of the input image, list of list, N*4
20 | - junc # of the input image, list of list, M*2
21 | '''
22 |
23 | def __init__(self, root, ann_file, transform = None):
24 | self.root = root
25 | with open(ann_file, 'r') as _:
26 | self.annotations = json.load(_)
27 | self.transform = transform
28 |
29 | def __len__(self):
30 | return len(self.annotations)
31 |
32 | def __getitem__(self, idx):
33 | ann = copy.deepcopy(self.annotations[idx])
34 | # image = Image.open(osp.join(self.root,ann['filename'])).convert('RGB')
35 | # image = io.imread(osp.join(self.root,ann['filename']))#.astype(float)[:,:,:3]
36 | image = cv2.imread(osp.join(self.root,ann['filename']), cv2.IMREAD_GRAYSCALE).astype(float)
37 | # image = cv2.imread(osp.join(self.root,ann['filename']))[:,:,::-1]
38 | if len(image.shape) == 2:
39 |             image = np.stack((image,image,image),axis=-1)  # grayscale -> 3 channels
40 |
41 | image = image.astype(float)[:,:,:3]
42 | # image = io.imread(osp.join(self.root,ann['filename'])).astype(float)[:,:,:3]
43 | for key, _type in (['junc',np.float32],
44 | ['lines', np.float32]):
45 | ann[key] = np.array(ann[key],dtype=_type)
46 |
47 | if self.transform is not None:
48 | return self.transform(image,ann)
49 | return image, ann
50 | def image(self, idx):
51 | ann = copy.deepcopy(self.annotations[idx])
52 | image = Image.open(osp.join(self.root,ann['filename'])).convert('RGB')
53 | return image
54 | @staticmethod
55 | def collate_fn(batch):
56 | return (default_collate([b[0] for b in batch]),
57 | [b[1] for b in batch])
--------------------------------------------------------------------------------
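A minimal end-to-end sketch of the annotation format the class expects (not a repository file; all values are synthetic, and the (x1,y1,x2,y2) endpoint layout is an assumption consistent with the docstring's N*4 shape):

    import json, os, tempfile
    import cv2
    import numpy as np

    root = tempfile.mkdtemp()
    cv2.imwrite(os.path.join(root, 'img.png'), np.zeros((128, 128, 3), np.uint8))
    anns = [{'filename': 'img.png', 'height': 128, 'width': 128,
             'lines': [[0, 0, 10, 10]],                  # N*4 endpoints
             'junc': [[0, 0], [10, 10]]}]                # M*2 junctions
    with open(os.path.join(root, 'test.json'), 'w') as f:
        json.dump(anns, f)

    ds = TestDatasetWithAnnotations(root, os.path.join(root, 'test.json'))
    image, ann = ds[0]                                   # (128,128,3) float image, float32 'lines'/'junc'
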
/hawp/fsl/dataset/train_dataset.bck.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 | import os.path as osp
5 | import json
6 | import cv2
7 | from skimage import io
8 | from PIL import Image
9 | import numpy as np
10 | import random
11 | from torch.utils.data.dataloader import default_collate
12 | from torch.utils.data.dataloader import DataLoader
13 | import matplotlib.pyplot as plt
14 | from torchvision.transforms import functional as F
15 | import copy
16 | class TrainDataset(Dataset):
17 | def __init__(self, root, ann_file, transform = None):
18 | self.root = root
19 | with open(ann_file,'r') as _:
20 | self.annotations = json.load(_)
21 | self.transform = transform
22 |
23 |     def __getitem__(self, idx_):
24 |         # print(idx_)
25 |         idx = idx_%len(self.annotations)        # which annotation
26 |         reminder = idx_//len(self.annotations)  # flip variant: 0 none, 1 h-flip, 2 v-flip, 3 both
27 | # idx = 0
28 | # reminder = 0
29 | ann = copy.deepcopy(self.annotations[idx])
30 | ann['reminder'] = reminder
31 | # image = io.imread(osp.join(self.root,ann['filename']))#.astype(float)[:,:,:3]
32 | # if len(image.shape) == 2:
33 | # image = np.stack((image,image,image),axis=-1)
34 |
35 | # image = image.astype(float)[:,:,:3]
36 |
37 | image = io.imread(osp.join(self.root,ann['filename']), as_gray=True).astype(float)#[:,:,:3]
38 | if len(image.shape) == 2:
39 | image = np.concatenate([image[...,None],image[...,None],image[...,None]],axis=-1)
40 | else:
41 | image = image[:,:,:3]
42 |
43 | # image = Image.open(osp.join(self.root,ann['filename'])).convert('RGB')
44 | for key,_type in (['junctions',np.float32],
45 | ['edges_positive',np.int32],
46 | ['edges_negative',np.int32]):
47 | ann[key] = np.array(ann[key],dtype=_type)
48 |
49 | width = ann['width']
50 | height = ann['height']
51 | if reminder == 1:
52 | image = image[:,::-1,:]
53 | # image = F.hflip(image)
54 | ann['junctions'][:,0] = width-ann['junctions'][:,0]
55 | elif reminder == 2:
56 | # image = F.vflip(image)
57 | image = image[::-1,:,:]
58 | ann['junctions'][:,1] = height-ann['junctions'][:,1]
59 | elif reminder == 3:
60 | # image = F.vflip(F.hflip(image))
61 | image = image[::-1,::-1,:]
62 | ann['junctions'][:,0] = width-ann['junctions'][:,0]
63 | ann['junctions'][:,1] = height-ann['junctions'][:,1]
64 | else:
65 | pass
66 |
67 | if self.transform is not None:
68 | return self.transform(image,ann)
69 | return image, ann
70 |
71 | def __len__(self):
72 | return len(self.annotations)*4
73 | # return 1000
74 |
75 | def collate_fn(batch):
76 | return (default_collate([b[0] for b in batch]),
77 | [b[1] for b in batch])
--------------------------------------------------------------------------------
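`__len__` above is 4x the number of annotations, so each annotation is served under four flip variants. The index arithmetic, spelled out (pure Python, no repo dependencies):

    N = 5000                                             # number of annotations, for illustration
    for idx_ in (0, 4999, 5000, 15000):
        idx = idx_ % N                                   # which annotation
        variant = idx_ // N                              # 0 original, 1 h-flip, 2 v-flip, 3 both
        print(idx_, '->', idx, variant)
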
/hawp/fsl/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_model
--------------------------------------------------------------------------------
/hawp/fsl/model/build.py:
--------------------------------------------------------------------------------
1 | from . import models
2 |
3 |
4 | def build_model(cfg):
5 | model = models.WireframeDetector(cfg)
6 |
7 | return model
--------------------------------------------------------------------------------
/hawp/fsl/model/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | def cross_entropy_loss_for_junction(logits, positive):
4 | nlogp = -F.log_softmax(logits, dim=1)
5 |
6 | loss = (positive * nlogp[:, None, 1] + (1 - positive) * nlogp[:, None, 0])
7 |
8 | return loss.mean()
9 |
10 | def focal_loss_for_junction(logits, positive, gamma=2.0):
11 | prob = F.softmax(logits, 1)
12 | ce_loss = F.cross_entropy(logits, positive, reduction='none')
13 | p_t = prob[:,1:]*positive + prob[:,:1]*(1-positive)
14 | loss = ce_loss * ((1-p_t)**gamma)
15 |
16 | return loss.mean()
17 |
18 | def sigmoid_l1_loss(logits, targets, offset = 0.0, mask=None):
19 | logp = torch.sigmoid(logits) + offset
20 | loss = torch.abs(logp-targets)
21 |
22 | if mask is not None:
23 | w = mask.mean(3, True).mean(2,True)
24 | w[w==0] = 1
25 | loss = loss*(mask/w)
26 |
27 | return loss.mean()
28 |
29 |
30 | def sigmoid_focal_loss(
31 | inputs: torch.Tensor,
32 | targets: torch.Tensor,
33 | alpha: float = -1,
34 | gamma: float = 2,
35 | reduction: str = "none",
36 | ) -> torch.Tensor:
37 | """
38 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
39 | Args:
40 | inputs: A float tensor of arbitrary shape.
41 | The predictions for each example.
42 | targets: A float tensor with the same shape as inputs. Stores the binary
43 | classification label for each element in inputs
44 | (0 for the negative class and 1 for the positive class).
45 | alpha: (optional) Weighting factor in range (0,1) to balance
46 | positive vs negative examples. Default = -1 (no weighting).
47 | gamma: Exponent of the modulating factor (1 - p_t) to
48 | balance easy vs hard examples.
49 | reduction: 'none' | 'mean' | 'sum'
50 | 'none': No reduction will be applied to the output.
51 | 'mean': The output will be averaged.
52 | 'sum': The output will be summed.
53 | Returns:
54 | Loss tensor with the reduction option applied.
55 | """
56 | inputs = inputs.float()
57 | targets = targets.float()
58 | p = torch.sigmoid(inputs)
59 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
60 | p_t = p * targets + (1 - p) * (1 - targets)
61 | loss = ce_loss * ((1 - p_t) ** gamma)
62 |
63 | if alpha >= 0:
64 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
65 | loss = alpha_t * loss
66 |
67 | if reduction == "mean":
68 | loss = loss.mean()
69 | elif reduction == "sum":
70 | loss = loss.sum()
71 |
72 | return loss
--------------------------------------------------------------------------------
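A quick sanity check of `sigmoid_focal_loss` on dummy tensors (alpha=0.25 and gamma=2 follow the RetinaNet paper; the shapes are arbitrary):

    import torch

    logits  = torch.randn(4, 1, 128, 128)                # raw predictions
    targets = torch.randint(0, 2, (4, 1, 128, 128)).float()
    loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2, reduction='mean')
    print(loss.item())                                   # single scalar after 'mean' reduction
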
/hawp/fsl/model/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 |
6 | def non_maximum_suppression(a):
7 |     ap = F.max_pool2d(a, 3, stride=1, padding=1)  # 3x3 local maxima
8 |     mask = (a == ap).float().clamp(min=0.0)       # keep pixels that equal their 3x3 maximum
9 | 
10 |     return a * mask
11 |
12 | def get_junctions(jloc, joff, topk = 300, th=0):
13 | height, width = jloc.size(1), jloc.size(2)
14 | jloc = jloc.reshape(-1)
15 | joff = joff.reshape(2, -1)
16 |
17 |
18 | scores, index = torch.topk(jloc, k=topk)
19 | # y = (index // width).float() + torch.gather(joff[1], 0, index) + 0.5
20 | y = torch.div(index,width,rounding_mode='trunc').float()+ torch.gather(joff[1], 0, index) + 0.5
21 | x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5
22 |
23 | junctions = torch.stack((x, y)).t()
24 |
25 | return junctions[scores>th], scores[scores>th]
26 |
27 | def plot_lines(lines, scale=1.0, color = 'red', **kwargs):
28 | if isinstance(lines, np.ndarray):
29 | plt.plot([lines[:,0]*scale,lines[:,2]*scale],[lines[:,1]*scale,lines[:,3]*scale],color=color,linestyle='-')
30 | else:
31 | lines_np = lines.detach().cpu().numpy()
32 | plt.plot([lines_np[:,0]*scale,lines_np[:,2]*scale],[lines_np[:,1]*scale,lines_np[:,3]*scale],color=color,linestyle='-')
33 |
34 |
--------------------------------------------------------------------------------
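A small sketch combining the two helpers above: suppress non-maxima on a junction heatmap, then decode up to `topk` sub-pixel junctions (dummy inputs):

    import torch

    jloc = torch.rand(1, 128, 128)                       # junction probability map, (1,H,W)
    joff = torch.zeros(2, 128, 128)                      # per-pixel (x,y) sub-pixel offsets
    jloc = non_maximum_suppression(jloc)                 # keep only 3x3 local maxima
    junctions, scores = get_junctions(jloc, joff, topk=300, th=0.5)
    print(junctions.shape, scores.shape)                 # (K,2) xy coordinates and scores, K <= 300
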
/hawp/fsl/point_model/point_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/hawp/fsl/point_model/point_model.pth
--------------------------------------------------------------------------------
/hawp/fsl/solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def make_optimizer(cfg, model, loss_reducer = None):
4 |
5 | params = []
6 |
7 | for key, value in model.named_parameters():
8 | if not value.requires_grad:
9 | continue
10 |
11 | lr=cfg.SOLVER.BASE_LR
12 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
13 | # if 'md_predictor' in key or 'st_predictor' in key or 'ed_predictor' in key:
14 | # lr = cfg.SOLVER.BASE_LR*100.0
15 |
16 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
17 |
18 |     if loss_reducer is not None:
19 |         for key, value in loss_reducer.named_parameters():
20 |             params += [{"params": [value], "lr": cfg.SOLVER.BASE_LR, "weight_decay": 0}]  # not the loop-local lr, which may be unset
21 |
22 |
23 | if cfg.SOLVER.OPTIMIZER == 'SGD':
24 | optimizer = torch.optim.SGD(params,
25 | cfg.SOLVER.BASE_LR,
26 | momentum=cfg.SOLVER.MOMENTUM,
27 | weight_decay=cfg.SOLVER.WEIGHT_DECAY)
28 | elif cfg.SOLVER.OPTIMIZER == 'ADAM':
29 | optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR,
30 | weight_decay=cfg.SOLVER.WEIGHT_DECAY,
31 | amsgrad=cfg.SOLVER.AMSGRAD)
32 | elif cfg.SOLVER.OPTIMIZER == 'ADAMW':
33 | optimizer = torch.optim.AdamW(params,
34 | cfg.SOLVER.BASE_LR,betas=(0.9,0.999),
35 | eps=1e-8,
36 | weight_decay=cfg.SOLVER.WEIGHT_DECAY,
37 | amsgrad=cfg.SOLVER.AMSGRAD,
38 | )
39 | else:
40 | raise NotImplementedError()
41 | return optimizer
42 |
43 | def make_lr_scheduler(cfg,optimizer):
44 | return torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=cfg.SOLVER.STEPS,gamma=cfg.SOLVER.GAMMA)
--------------------------------------------------------------------------------
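A sketch of driving `make_optimizer`/`make_lr_scheduler` from a hand-built yacs node carrying only the fields read above (values are illustrative, not the repo's defaults):

    import torch.nn as nn
    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.SOLVER = CN()
    cfg.SOLVER.BASE_LR = 4e-4
    cfg.SOLVER.WEIGHT_DECAY = 1e-4
    cfg.SOLVER.OPTIMIZER = 'ADAM'
    cfg.SOLVER.AMSGRAD = True
    cfg.SOLVER.MOMENTUM = 0.9                            # only read on the SGD path
    cfg.SOLVER.STEPS = (25,)
    cfg.SOLVER.GAMMA = 0.1

    model = nn.Linear(8, 2)                              # stand-in for the wireframe model
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
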
/hawp/ssl/__init__.py:
--------------------------------------------------------------------------------
1 | from . import config, datasets, models
2 |
--------------------------------------------------------------------------------
/hawp/ssl/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .project_config import Config
2 | from .utils import *
--------------------------------------------------------------------------------
/hawp/ssl/config/exports/wireframe-100iters.yaml:
--------------------------------------------------------------------------------
1 | ### General dataset parameters
2 | dataset_name: "wireframe"
3 | add_augmentation_to_all_splits: False
4 | gray_scale: True
5 | # Ground truth source ('official' or path to the exported h5 dataset.)
6 | # gt_source_train: "" # Fill with your own export file
7 | # gt_source_test: "" # Fill with your own export file
8 | # Return type: (1) single (to train the detector only)
9 | # or (2) paired_desc (to train the detector + descriptor)
10 | return_type: "single"
11 | random_seed: 0
12 |
13 | ### Descriptor training parameters
14 | # Number of points extracted per line
15 | max_num_samples: 10
16 | # Max number of training line points extracted in the whole image
17 | max_pts: 1000
18 | # Min distance between two points on a line (in pixels)
19 | min_dist_pts: 10
20 | # Small jittering of the sampled points during training
21 | jittering: 0
22 |
23 | ### Data preprocessing configuration
24 | preprocessing:
25 | resize: [512, 512]
26 | blur_size: 11
27 | augmentation:
28 | random_scaling:
29 | enable: True
30 | range: [0.7, 1.5]
31 | photometric:
32 | enable: true
33 | primitives: ['random_brightness', 'random_contrast',
34 | 'additive_speckle_noise', 'additive_gaussian_noise',
35 | 'additive_shade', 'motion_blur' ]
36 | params:
37 | random_brightness: {brightness: 0.2}
38 | random_contrast: {contrast: [0.3, 1.5]}
39 | additive_gaussian_noise: {stddev_range: [0, 10]}
40 | additive_speckle_noise: {prob_range: [0, 0.0035]}
41 | additive_shade:
42 | transparency_range: [-0.5, 0.5]
43 | kernel_size_range: [100, 150]
44 | motion_blur: {max_kernel_size: 3}
45 | random_order: True
46 | homographic:
47 | enable: true
48 | params:
49 | translation: true
50 | rotation: true
51 | scaling: true
52 | perspective: true
53 | scaling_amplitude: 0.2
54 | perspective_amplitude_x: 0.2
55 | perspective_amplitude_y: 0.2
56 | patch_ratio: 0.85
57 | max_angle: 1.57
58 | allow_artifacts: true
59 | valid_border_margin: 3
60 |
61 | ## Homography adaptation configuration
62 | homography_adaptation:
63 |     num_iter: 100
64 | valid_border_margin: 3
65 | min_counts: 30
66 | homographies:
67 | translation: true
68 | rotation: true
69 | scaling: true
70 | perspective: true
71 | scaling_amplitude: 0.2
72 | perspective_amplitude_x: 0.2
73 | perspective_amplitude_y: 0.2
74 | allow_artifacts: true
75 | patch_ratio: 0.85
--------------------------------------------------------------------------------
/hawp/ssl/config/exports/wireframe-10iters.yaml:
--------------------------------------------------------------------------------
1 | ### General dataset parameters
2 | dataset_name: "wireframe"
3 | add_augmentation_to_all_splits: False
4 | gray_scale: True
5 | # Ground truth source ('official' or path to the exported h5 dataset.)
6 | # gt_source_train: "" # Fill with your own export file
7 | # gt_source_test: "" # Fill with your own export file
8 | # Return type: (1) single (to train the detector only)
9 | # or (2) paired_desc (to train the detector + descriptor)
10 | return_type: "single"
11 | random_seed: 0
12 |
13 | ### Descriptor training parameters
14 | # Number of points extracted per line
15 | max_num_samples: 10
16 | # Max number of training line points extracted in the whole image
17 | max_pts: 1000
18 | # Min distance between two points on a line (in pixels)
19 | min_dist_pts: 10
20 | # Small jittering of the sampled points during training
21 | jittering: 0
22 |
23 | ### Data preprocessing configuration
24 | preprocessing:
25 | resize: [512, 512]
26 | blur_size: 11
27 | augmentation:
28 | random_scaling:
29 | enable: True
30 | range: [0.7, 1.5]
31 | photometric:
32 | enable: true
33 | primitives: ['random_brightness', 'random_contrast',
34 | 'additive_speckle_noise', 'additive_gaussian_noise',
35 | 'additive_shade', 'motion_blur' ]
36 | params:
37 | random_brightness: {brightness: 0.2}
38 | random_contrast: {contrast: [0.3, 1.5]}
39 | additive_gaussian_noise: {stddev_range: [0, 10]}
40 | additive_speckle_noise: {prob_range: [0, 0.0035]}
41 | additive_shade:
42 | transparency_range: [-0.5, 0.5]
43 | kernel_size_range: [100, 150]
44 | motion_blur: {max_kernel_size: 3}
45 | random_order: True
46 | homographic:
47 | enable: true
48 | params:
49 | translation: true
50 | rotation: true
51 | scaling: true
52 | perspective: true
53 | scaling_amplitude: 0.2
54 | perspective_amplitude_x: 0.2
55 | perspective_amplitude_y: 0.2
56 | patch_ratio: 0.85
57 | max_angle: 1.57
58 | allow_artifacts: true
59 | valid_border_margin: 3
60 |
61 | ## Homography adaptation configuration
62 | homography_adaptation:
63 | num_iter: 10
64 | valid_border_margin: 3
65 | min_counts: 3
66 | homographies:
67 | translation: true
68 | rotation: true
69 | scaling: true
70 | perspective: true
71 | scaling_amplitude: 0.2
72 | perspective_amplitude_x: 0.2
73 | perspective_amplitude_y: 0.2
74 | allow_artifacts: true
75 | patch_ratio: 0.85
--------------------------------------------------------------------------------
/hawp/ssl/config/hawpv3-hrheat.yaml:
--------------------------------------------------------------------------------
1 | ENCODER:
2 | ANG_TH: 0.0
3 | BACKGROUND_WEIGHT: 0.0
4 | DIS_TH: 2
5 | NUM_STATIC_NEG_LINES: 0
6 | NUM_STATIC_POS_LINES: 300
7 | MODEL:
8 | DEVICE: cuda
9 | USE_LINE_HEATMAP: True
10 | USE_HR_JUNCTION: True
11 | HEAD_SIZE:
12 | - - 3
13 | - - 1
14 | - - 1
15 | - - 2
16 | - - 2
17 | HGNETS:
18 | DEPTH: 4
19 | INPLANES: 64
20 | NUM_BLOCKS: 1
21 | NUM_FEATS: 128
22 | NUM_STACKS: 2
23 | LOI_POOLING:
24 | ACTIVATION: relu
25 | DIM_EDGE_FEATURE: 4
26 | DIM_FC: 1024
27 | DIM_JUNCTION_FEATURE: 128
28 | LAYER_NORM: false
29 | NUM_POINTS: 32
30 | TYPE: softmax
31 | LOSS_WEIGHTS:
32 | loss_aux: 1.0
33 | loss_dis: 1.0
34 | loss_jloc: 8.0
35 | loss_joff: 0.25
36 | # loss_joff: 0.0
37 | loss_johr: 0.25
38 | loss_lineness: 1.0
39 | loss_md: 1.0
40 | loss_neg: 1.0
41 | loss_pos: 1.0
42 | loss_res: 1.0
43 | loss_heatmap: 1.0
44 | loss_jmap: 1.0
45 | NAME: Hourglass
46 | OUT_FEATURE_CHANNELS: 256
47 | PARSING_HEAD:
48 | DIM_FC: 1024
49 | DIM_LOI: 128
50 | J2L_THRESHOLD: 10.0
51 | JMATCH_THRESHOLD: 1.5
52 | MATCHING_STRATEGY: junction
53 | MAX_DISTANCE: 5.0
54 | N_DYN_JUNC: 300
55 | N_DYN_NEGL: 40
56 | N_DYN_OTHR: 0
57 | N_DYN_OTHR2: 300
58 | N_DYN_POSL: 300
59 | N_OUT_JUNC: 250
60 | N_OUT_LINE: 2500
61 | N_PTS0: 32
62 | N_PTS1: 8
63 | N_STC_NEGL: 40
64 | N_STC_POSL: 300
65 | USE_RESIDUAL: 1
66 | RESNETS:
67 | BASENET: resnet50
68 | PRETRAIN: true
69 | SCALE: 1.0
70 | WEIGHTS: ''
71 | MODELING_PATH: ihawp-v2
72 | OUTPUT_DIR: output/ihawp
73 | SOLVER:
74 | AMSGRAD: true
75 | BACKBONE_LR_FACTOR: 1.0
76 | BASE_LR: 0.00004
77 | BIAS_LR_FACTOR: 1
78 | CHECKPOINT_PERIOD: 1
79 | GAMMA: 0.1
80 | IMS_PER_BATCH: 6
81 | MAX_EPOCH: 30
82 | MOMENTUM: 0.9
83 | OPTIMIZER: ADAM
84 | STEPS:
85 | - 25
86 | WEIGHT_DECAY: 0.0001
87 | WEIGHT_DECAY_BIAS: 0
88 |
--------------------------------------------------------------------------------
/hawp/ssl/config/hawpv3.yaml:
--------------------------------------------------------------------------------
1 | ENCODER:
2 | ANG_TH: 0.0
3 | BACKGROUND_WEIGHT: 0.0
4 | DIS_TH: 2
5 | NUM_STATIC_NEG_LINES: 0
6 | NUM_STATIC_POS_LINES: 300
7 | MODEL:
8 | DEVICE: cuda
9 | USE_LINE_HEATMAP: True
10 | HEAD_SIZE:
11 | - - 3
12 | - - 1
13 | - - 1
14 | - - 2
15 | - - 2
16 | HGNETS:
17 | DEPTH: 4
18 | INPLANES: 64
19 | NUM_BLOCKS: 1
20 | NUM_FEATS: 128
21 | NUM_STACKS: 2
22 | LOI_POOLING:
23 | ACTIVATION: relu
24 | DIM_EDGE_FEATURE: 4
25 | DIM_FC: 1024
26 | DIM_JUNCTION_FEATURE: 128
27 | LAYER_NORM: false
28 | NUM_POINTS: 32
29 | TYPE: softmax
30 | LOSS_WEIGHTS:
31 | loss_aux: 1.0
32 | loss_dis: 1.0
33 | loss_jloc: 8.0
34 | loss_joff: 0.25
35 | # loss_joff: 0.0
36 | loss_lineness: 1.0
37 | loss_md: 1.0
38 | loss_neg: 1.0
39 | loss_pos: 1.0
40 | loss_res: 1.0
41 | loss_heatmap: 1.0
42 | NAME: Hourglass
43 | OUT_FEATURE_CHANNELS: 256
44 | PARSING_HEAD:
45 | DIM_FC: 1024
46 | DIM_LOI: 128
47 | J2L_THRESHOLD: 10.0
48 | JMATCH_THRESHOLD: 1.5
49 | MATCHING_STRATEGY: junction
50 | MAX_DISTANCE: 5.0
51 | N_DYN_JUNC: 300
52 | N_DYN_NEGL: 40
53 | N_DYN_OTHR: 0
54 | N_DYN_OTHR2: 300
55 | N_DYN_POSL: 300
56 | N_OUT_JUNC: 250
57 | N_OUT_LINE: 2500
58 | N_PTS0: 32
59 | N_PTS1: 8
60 | N_STC_NEGL: 40
61 | N_STC_POSL: 300
62 | USE_RESIDUAL: 1
63 | RESNETS:
64 | BASENET: resnet50
65 | PRETRAIN: true
66 | SCALE: 1.0
67 | WEIGHTS: ''
68 | MODELING_PATH: ihawp-v2
69 | OUTPUT_DIR: output/ihawp
70 | SOLVER:
71 | AMSGRAD: true
72 | BACKBONE_LR_FACTOR: 1.0
73 | BASE_LR: 0.00004
74 | BIAS_LR_FACTOR: 1
75 | CHECKPOINT_PERIOD: 1
76 | GAMMA: 0.1
77 | IMS_PER_BATCH: 6
78 | MAX_EPOCH: 30
79 | MOMENTUM: 0.9
80 | OPTIMIZER: ADAM
81 | STEPS:
82 | - 25
83 | WEIGHT_DECAY: 0.0001
84 | WEIGHT_DECAY_BIAS: 0
85 |
--------------------------------------------------------------------------------
/hawp/ssl/config/project_config.py:
--------------------------------------------------------------------------------
1 | """
2 | Project configurations.
3 | """
4 | import os
5 |
6 |
7 | class Config(object):
8 | """ Datasets and experiments folders for the whole project. """
9 | #####################
10 | ## Dataset setting ##
11 | #####################
12 | default_dataroot = os.path.join(
13 | os.path.dirname(__file__),
14 | '..','..','..','data-ssl'
15 | )
16 | default_dataroot = os.path.abspath(default_dataroot)
17 | default_exproot = os.path.join(
18 | os.path.dirname(__file__),
19 | '..','..','..','exp-ssl'
20 | )
21 | default_exproot = os.path.abspath(default_exproot)
22 |
23 | DATASET_ROOT = os.getenv("DATASET_ROOT", default_dataroot) # TODO: path to your datasets folder
24 | if not os.path.exists(DATASET_ROOT):
25 | os.makedirs(DATASET_ROOT)
26 |
27 | # Synthetic shape dataset
28 | synthetic_dataroot = os.path.join(DATASET_ROOT, "synthetic_shapes")
29 | synthetic_cache_path = os.path.join(DATASET_ROOT, "synthetic_shapes")
30 | if not os.path.exists(synthetic_dataroot):
31 | os.makedirs(synthetic_dataroot)
32 |
33 | # Exported predictions dataset
34 | export_dataroot = os.path.join(DATASET_ROOT, "export_datasets")
35 | export_cache_path = os.path.join(DATASET_ROOT, "export_datasets")
36 | if not os.path.exists(export_dataroot):
37 | os.makedirs(export_dataroot)
38 |
39 | # York Urban dataset
40 | yorkurban_dataroot = os.path.join(DATASET_ROOT, "YorkUrbanDB")
41 | yorkurban_cache_path = os.path.join(DATASET_ROOT, "YorkUrbanDB")
42 |
43 | # Wireframe dataset
44 | wireframe_dataroot = os.path.join(DATASET_ROOT, "wireframe")
45 | wireframe_cache_path = os.path.join(DATASET_ROOT, "wireframe")
46 |
47 | # Holicity dataset
48 | holicity_dataroot = os.path.join(DATASET_ROOT, "Holicity")
49 | holicity_cache_path = os.path.join(DATASET_ROOT, "Holicity")
50 |
51 | ########################
52 | ## Experiment Setting ##
53 | ########################
54 | EXP_PATH = os.getenv("EXP_PATH", default_exproot) # TODO: path to your experiments folder
55 |
56 | if not os.path.exists(EXP_PATH):
57 | os.makedirs(EXP_PATH)
58 |
--------------------------------------------------------------------------------
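Both roots can be redirected without editing this file; the environment variables must be set before `Config` is first imported, because the paths are resolved (and the directories created) at import time. A sketch with placeholder paths:

    import os
    os.environ['DATASET_ROOT'] = '/data/hawp-ssl'
    os.environ['EXP_PATH'] = '/data/hawp-ssl-exp'

    from hawp.ssl.config import Config                   # directories are created as a side effect
    print(Config.wireframe_dataroot)                     # -> /data/hawp-ssl/wireframe
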
/hawp/ssl/config/synthetic_dataset-4k.yaml:
--------------------------------------------------------------------------------
1 | ### General dataset parameters
2 | dataset_name: "synthetic_shape"
3 | primitives: "all"
4 | add_augmentation_to_all_splits: True
5 | test_augmentation_seed: 200
6 | alias: 4k
7 | # Shape generation configuration
8 | generation:
9 | # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
10 | split_sizes: {'train': 4000, 'val': 2000, 'test': 400}
11 | random_seed: 10
12 | image_size: [960, 1280]
13 | min_len: 0.0985
14 | min_label_len: 0.099
15 | params:
16 | generate_background:
17 | min_kernel_size: 150
18 | max_kernel_size: 500
19 | min_rad_ratio: 0.02
20 | max_rad_ratio: 0.031
21 | draw_stripes:
22 | transform_params: [0.1, 0.1]
23 | draw_multiple_polygons:
24 | kernel_boundaries: [50, 100]
25 |
26 | ### Data preprocessing configuration.
27 | preprocessing:
28 | resize: [512, 512]
29 | blur_size: 11
30 | augmentation:
31 | photometric:
32 | enable: True
33 | primitives: 'all'
34 | params: {}
35 | random_order: True
36 | homographic:
37 | enable: True
38 | params:
39 | translation: true
40 | rotation: true
41 | scaling: true
42 | perspective: true
43 | scaling_amplitude: 0.2
44 | perspective_amplitude_x: 0.2
45 | perspective_amplitude_y: 0.2
46 | patch_ratio: 0.8
47 | max_angle: 1.57
48 | allow_artifacts: true
49 | translation_overflow: 0.05
50 | valid_border_margin: 0
51 |
--------------------------------------------------------------------------------
/hawp/ssl/config/synthetic_dataset.yaml:
--------------------------------------------------------------------------------
1 | ### General dataset parameters
2 | dataset_name: "synthetic_shape"
3 | primitives: "all"
4 | add_augmentation_to_all_splits: True
5 | test_augmentation_seed: 200
6 | # Shape generation configuration
7 | generation:
8 | # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
9 | split_sizes: {'train': 2000, 'val': 2000, 'test': 400}
10 | random_seed: 10
11 | image_size: [960, 1280]
12 | min_len: 0.0985
13 | min_label_len: 0.099
14 | params:
15 | generate_background:
16 | min_kernel_size: 150
17 | max_kernel_size: 500
18 | min_rad_ratio: 0.02
19 | max_rad_ratio: 0.031
20 | draw_stripes:
21 | transform_params: [0.1, 0.1]
22 | draw_multiple_polygons:
23 | kernel_boundaries: [50, 100]
24 |
25 | ### Data preprocessing configuration.
26 | preprocessing:
27 | resize: [512, 512]
28 | blur_size: 11
29 | augmentation:
30 | photometric:
31 | enable: True
32 | primitives: 'all'
33 | params: {}
34 | random_order: True
35 | homographic:
36 | enable: True
37 | params:
38 | translation: true
39 | rotation: true
40 | scaling: true
41 | perspective: true
42 | scaling_amplitude: 0.2
43 | perspective_amplitude_x: 0.2
44 | perspective_amplitude_y: 0.2
45 | patch_ratio: 0.8
46 | max_angle: 1.57
47 | allow_artifacts: true
48 | translation_overflow: 0.05
49 | valid_border_margin: 0
50 |
--------------------------------------------------------------------------------
/hawp/ssl/config/utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os
3 |
4 | def load_config(config_path):
5 | """ Load configurations from a given yaml file. """
6 | # Check file exists
7 | if not os.path.exists(config_path):
8 | raise ValueError("[Error] The provided config path is not valid.")
9 |
10 | # Load the configuration
11 | with open(config_path, "r") as f:
12 | config = yaml.safe_load(f)
13 |
14 | return config
15 |
16 | def update_config(path, model_cfg=None, dataset_cfg=None):
17 | """ Update configuration file from the resume path. """
18 | # Check we need to update or completely override.
19 | model_cfg = {} if model_cfg is None else model_cfg
20 | dataset_cfg = {} if dataset_cfg is None else dataset_cfg
21 |
22 | # Load saved configs
23 | with open(os.path.join(path, "model_cfg.yaml"), "r") as f:
24 | model_cfg_saved = yaml.safe_load(f)
25 | model_cfg.update(model_cfg_saved)
26 | with open(os.path.join(path, "dataset_cfg.yaml"), "r") as f:
27 | dataset_cfg_saved = yaml.safe_load(f)
28 | dataset_cfg.update(dataset_cfg_saved)
29 |
30 | # Update the saved yaml file
31 |     if model_cfg != model_cfg_saved:
32 |         with open(os.path.join(path, "model_cfg.yaml"), "w") as f:
33 |             yaml.dump(model_cfg, f)
34 |     if dataset_cfg != dataset_cfg_saved:
35 | with open(os.path.join(path, "dataset_cfg.yaml"), "w") as f:
36 | yaml.dump(dataset_cfg, f)
37 |
38 | return model_cfg, dataset_cfg
39 |
40 | def record_config(model_cfg, dataset_cfg, output_path):
41 | """ Record dataset config to the log path. """
42 | # Record model config
43 | with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f:
44 | yaml.safe_dump(model_cfg, f)
45 |
46 | # Record dataset config
47 | with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f:
48 | yaml.safe_dump(dataset_cfg, f)
--------------------------------------------------------------------------------
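A round-trip sketch of the three helpers (not a repository file; the yaml paths assume the repo root as the working directory):

    import tempfile

    model_cfg = load_config('hawp/ssl/config/hawpv3.yaml')
    dataset_cfg = load_config('hawp/ssl/config/synthetic_dataset.yaml')

    out = tempfile.mkdtemp()
    record_config(model_cfg, dataset_cfg, out)           # writes model_cfg.yaml / dataset_cfg.yaml
    model_cfg, dataset_cfg = update_config(out)          # reload both when resuming from `out`
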
/hawp/ssl/config/wireframe_official_gt.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: "wireframe"
2 | add_augmentation_to_all_splits: False
3 | gray_scale: True
4 | # Ground truth source (official or path to the exported h5 dataset.)
5 | gt_source_train: "official"
6 | gt_source_test: "official"
7 | # Data preprocessing configuration.
8 | preprocessing:
9 | resize: [512, 512]
10 | blur_size: 11
11 | augmentation:
12 | photometric:
13 | enable: True
14 | homographic:
15 | enable: True
16 | # The homography adaptation configuration
17 | homography_adaptation:
18 | num_iter: 100
19 | aggregation: 'sum'
20 | mode: 'ver1'
21 | valid_border_margin: 3
22 | min_counts: 30
23 | homographies:
24 | translation: true
25 | rotation: true
26 | scaling: true
27 | perspective: true
28 | scaling_amplitude: 0.2
29 | perspective_amplitude_x: 0.2
30 | perspective_amplitude_y: 0.2
31 | allow_artifacts: true
32 | patch_ratio: 0.85
33 | # Evaluation related config
34 | evaluation:
35 | repeatability:
36 | # Initial random seed used to sample homographic augmentation
37 | seed: 200
38 | # Parameter used to sample illumination change evaluation set.
39 | photometric:
40 | enable: False
41 | # Parameter used to sample viewpoint change evaluation set.
42 | homographic:
43 | enable: True
44 | num_samples: 2
45 | params:
46 | translation: true
47 | rotation: true
48 | scaling: true
49 | perspective: true
50 | scaling_amplitude: 0.2
51 | perspective_amplitude_x: 0.2
52 | perspective_amplitude_y: 0.2
53 | patch_ratio: 0.85
54 | max_angle: 1.57
55 | allow_artifacts: true
56 | valid_border_margin: 3
--------------------------------------------------------------------------------
/hawp/ssl/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/hawp/ssl/datasets/__init__.py
--------------------------------------------------------------------------------
/hawp/ssl/datasets/dataset_util.py:
--------------------------------------------------------------------------------
1 | """
2 | The interface of initializing different datasets.
3 | """
4 | from .synthetic_dataset import SyntheticShapes,synthetic_collate_fn
5 | from .wireframe_dataset import WireframeDataset,wireframe_collate_fn
6 | from .yorkurban_dataset import YorkUrbanDataset,yorkurban_collate_fn
7 | from .images_dataset import ImageCollections, images_collate_fn
8 | # from .holicity_dataset import HolicityDataset
9 | # from .merge_dataset import MergeDataset
10 |
11 |
12 | def get_dataset(mode="train", dataset_cfg=None):
13 | """ Initialize different dataset based on a configuration. """
14 | # Check dataset config is given
15 | if dataset_cfg is None:
16 | raise ValueError("[Error] The dataset config is required!")
17 |
18 | # Synthetic dataset
19 | if dataset_cfg["dataset_name"] == "synthetic_shape":
20 | dataset = SyntheticShapes(
21 | mode, dataset_cfg
22 | )
23 | # Get the collate_fn
24 | # from sold2.dataset.synthetic_dataset import synthetic_collate_fn
25 | collate_fn = synthetic_collate_fn
26 |
27 | # Wireframe dataset
28 | elif dataset_cfg["dataset_name"] == "wireframe":
29 | dataset = WireframeDataset(
30 | mode, dataset_cfg
31 | )
32 |
33 | # Get the collate_fn
34 | collate_fn = wireframe_collate_fn
35 | elif dataset_cfg["dataset_name"] == "yorkurban":
36 | dataset = YorkUrbanDataset(
37 | mode, dataset_cfg
38 | )
39 |
40 | # Get the collate_fn
41 | collate_fn = yorkurban_collate_fn
42 |     # Holicity dataset (requires the HolicityDataset import commented out above)
43 |     elif dataset_cfg["dataset_name"] == "holicity":
44 |         dataset = HolicityDataset(
45 |             mode, dataset_cfg
46 |         )
47 | 
48 |         # Get the collate_fn from the external sold2 package
49 |         from sold2.dataset.holicity_dataset import holicity_collate_fn
50 |         collate_fn = holicity_collate_fn
51 | 
52 |     # Dataset merging several datasets in one (requires the MergeDataset import commented out above)
53 |     elif dataset_cfg["dataset_name"] == "merge":
54 | dataset = MergeDataset(
55 | mode, dataset_cfg
56 | )
57 |
58 | # Get the collate_fn
59 | from sold2.dataset.holicity_dataset import holicity_collate_fn
60 | collate_fn = holicity_collate_fn
61 | elif dataset_cfg["dataset_name"] == "general":
62 | dataset = ImageCollections(mode, dataset_cfg)
63 | collate_fn = images_collate_fn
64 | else:
65 | raise ValueError(
66 | "[Error] The dataset '%s' is not supported" % dataset_cfg["dataset_name"])
67 |
68 | return dataset, collate_fn
69 |
--------------------------------------------------------------------------------
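Typical use of `get_dataset` together with the collate function it returns (a sketch; `load_config` is the helper from `config/utils.py` above):

    import torch

    dataset_cfg = load_config('hawp/ssl/config/synthetic_dataset.yaml')
    dataset, collate_fn = get_dataset('train', dataset_cfg)
    loader = torch.utils.data.DataLoader(dataset, batch_size=2,
                                         collate_fn=collate_fn, num_workers=0)
    batch = next(iter(loader))
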
/hawp/ssl/datasets/transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/hawp/ssl/datasets/transforms/__init__.py
--------------------------------------------------------------------------------
/hawp/ssl/datasets/wireframe.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 | import os.path as osp
5 | import json
6 | import cv2
7 | from skimage import io
8 | from PIL import Image
9 | import numpy as np
10 | import random
11 | from torch.utils.data.dataloader import default_collate
12 | from torch.utils.data.dataloader import DataLoader
13 | import matplotlib.pyplot as plt
14 | from torchvision.transforms import functional as F
15 | import copy
16 |
17 |
18 | # class WireframeDataset(Dataset):
19 | # def __init__(self, )
--------------------------------------------------------------------------------
/hawp/ssl/misc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sair-lab/PLNet/32b42b3b6cafdbd2267e4feac409c8810684cd78/hawp/ssl/misc/__init__.py
--------------------------------------------------------------------------------
/hawp/ssl/misc/geometry_utils.old.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | ### Point-related utils
6 |
7 | # Warp a list of points using a homography
8 | def warp_points(points, homography):
9 | # Convert to homogeneous and in xy format
10 | new_points = np.concatenate([points[..., [1, 0]],
11 | np.ones_like(points[..., :1])], axis=-1)
12 | # Warp
13 | new_points = (homography @ new_points.T).T
14 | # Convert back to inhomogeneous and hw format
15 | new_points = new_points[..., [1, 0]] / new_points[..., 2:]
16 | return new_points
17 |
18 |
19 | # Mask out the points that are outside of img_size
20 | def mask_points(points, img_size):
21 | mask = ((points[..., 0] >= 0)
22 | & (points[..., 0] < img_size[0])
23 | & (points[..., 1] >= 0)
24 | & (points[..., 1] < img_size[1]))
25 | return mask
26 |
27 |
28 | # Convert a tensor [N, 2] or batched tensor [B, N, 2] of N keypoints into
29 | # a grid in [-1, 1]² that can be used in torch.nn.functional.interpolate
30 | def keypoints_to_grid(keypoints, img_size):
31 | n_points = keypoints.size()[-2]
32 | device = keypoints.device
33 | grid_points = keypoints.float() * 2. / torch.tensor(
34 | img_size, dtype=torch.float, device=device) - 1.
35 | grid_points = grid_points[..., [1, 0]].view(-1, n_points, 1, 2)
36 | return grid_points
37 |
38 |
39 | # Return a 2D matrix indicating the local neighborhood of each point
40 | # for a given threshold and two lists of corresponding keypoints
41 | def get_dist_mask(kp0, kp1, valid_mask, dist_thresh):
42 | b_size, n_points, _ = kp0.size()
43 | dist_mask0 = torch.norm(kp0.unsqueeze(2) - kp0.unsqueeze(1), dim=-1)
44 | dist_mask1 = torch.norm(kp1.unsqueeze(2) - kp1.unsqueeze(1), dim=-1)
45 | dist_mask = torch.min(dist_mask0, dist_mask1)
46 | dist_mask = dist_mask <= dist_thresh
47 | dist_mask = dist_mask.repeat(1, 1, b_size).reshape(b_size * n_points,
48 | b_size * n_points)
49 | dist_mask = dist_mask[valid_mask, :][:, valid_mask]
50 | return dist_mask
51 |
52 |
53 | ### Line-related utils
54 |
55 | # Sample n points along lines of shape (num_lines, 2, 2)
56 | def sample_line_points(lines, n):
57 | line_points_x = np.linspace(lines[:, 0, 0], lines[:, 1, 0], n, axis=-1)
58 | line_points_y = np.linspace(lines[:, 0, 1], lines[:, 1, 1], n, axis=-1)
59 | line_points = np.stack([line_points_x, line_points_y], axis=2)
60 | return line_points
61 |
62 |
63 | # Return a mask of the valid lines that are within a valid mask of an image
64 | def mask_lines(lines, valid_mask):
65 | h, w = valid_mask.shape
66 | int_lines = np.clip(np.round(lines).astype(int), 0, [h - 1, w - 1])
67 | h_valid = valid_mask[int_lines[:, 0, 0], int_lines[:, 0, 1]]
68 | w_valid = valid_mask[int_lines[:, 1, 0], int_lines[:, 1, 1]]
69 | valid = h_valid & w_valid
70 | return valid
71 |
72 |
73 | # Return a 2D matrix indicating for each pair of points
74 | # if they are on the same line or not
75 | def get_common_line_mask(line_indices, valid_mask):
76 | b_size, n_points = line_indices.shape
77 | common_mask = line_indices[:, :, None] == line_indices[:, None, :]
78 | common_mask = common_mask.repeat(1, 1, b_size).reshape(b_size * n_points,
79 | b_size * n_points)
80 | common_mask = common_mask[valid_mask, :][:, valid_mask]
81 | return common_mask
82 |
--------------------------------------------------------------------------------
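A worked example for the two most-used helpers; note that `warp_points` takes and returns points in (row, col) order, converting to xy internally:

    import numpy as np

    H = np.array([[1., 0., 5.],                          # translation homography: x += 5, y += 2
                  [0., 1., 2.],
                  [0., 0., 1.]])
    pts = np.array([[10., 20.]])                         # one point at row 10, col 20
    print(warp_points(pts, H))                           # -> [[12. 25.]]

    lines = np.array([[[0., 0.], [0., 30.]]])            # (num_lines, 2, 2) endpoints
    print(sample_line_points(lines, 4).shape)            # -> (1, 4, 2)
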
/hawp/ssl/misc/train_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains some useful functions for train / val.
3 | """
4 | import os
5 | import numpy as np
6 | import torch
7 |
8 |
9 | #################
10 | ## image utils ##
11 | #################
12 | def convert_image(input_tensor, axis):
13 | """ Convert single channel images to 3-channel images. """
14 | image_lst = [input_tensor for _ in range(3)]
15 | outputs = np.concatenate(image_lst, axis)
16 | return outputs
17 |
18 |
19 | ######################
20 | ## checkpoint utils ##
21 | ######################
22 | def get_latest_checkpoint(checkpoint_root, checkpoint_name,
23 | device=torch.device("cuda")):
24 | """ Get the latest checkpoint or by filename. """
25 | # Load specific checkpoint
26 | if checkpoint_name is not None:
27 | checkpoint = torch.load(
28 | os.path.join(checkpoint_root, checkpoint_name),
29 | map_location=device)
30 | # Load the latest checkpoint
31 | else:
32 |         latest_checkpoint = sorted(f for f in os.listdir(
33 |             checkpoint_root) if f.endswith(".tar"))[-1]
34 |         checkpoint = torch.load(os.path.join(
35 |             checkpoint_root, latest_checkpoint), map_location=device)
36 | return checkpoint
37 |
38 |
39 | def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
40 | """ Remove the outdated checkpoints. """
41 | # Get sorted list of checkpoints
42 | checkpoint_list = sorted(
43 | [_ for _ in os.listdir(os.path.join(checkpoint_root))
44 | if _.endswith(".tar")])
45 |
46 | # Get the checkpoints to be removed
47 | if len(checkpoint_list) > max_ckpt:
48 | remove_list = checkpoint_list[:-max_ckpt]
49 | for _ in remove_list:
50 | full_name = os.path.join(checkpoint_root, _)
51 | os.remove(full_name)
52 | print("[Debug] Remove outdated checkpoint %s" % (full_name))
53 |
54 |
55 | ################
56 | ## HDF5 utils ##
57 | ################
58 | def parse_h5_data(h5_data):
59 | """ Parse h5 dataset. """
60 | output_data = {}
61 | for key in h5_data.keys():
62 | output_data[key] = np.array(h5_data[key])
63 |
64 | return output_data
65 |
--------------------------------------------------------------------------------
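A self-contained sketch of the checkpoint helpers (using the `.tar` filtering in `get_latest_checkpoint` above; the checkpoints here are dummies):

    import os, tempfile
    import torch

    ckpt_root = tempfile.mkdtemp()
    for epoch in range(20):                              # fake checkpoints 000.tar .. 019.tar
        torch.save({'epoch': epoch}, os.path.join(ckpt_root, '%03d.tar' % epoch))

    remove_old_checkpoints(ckpt_root, max_ckpt=15)       # removes the 5 oldest files
    ckpt = get_latest_checkpoint(ckpt_root, None, device=torch.device('cpu'))
    print(ckpt['epoch'])                                 # -> 19
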
/hawp/ssl/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # from . import detector, detect_with_heatmaps
3 | # from . import detect_with_heatmaps
4 | from .detector import HAWP
5 | from . import detector
6 | from . import detector_with_heatmap
7 | from . import detector_with_hrheat
8 | from .registry import MODELS
--------------------------------------------------------------------------------
/hawp/ssl/models/heatmap_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # The pixel shuffle decoder
6 | class PixelShuffleDecoder(nn.Module):
7 | def __init__(self, input_feat_dim=128, num_upsample=2, output_channel=2):
8 | super(PixelShuffleDecoder, self).__init__()
9 | # Get channel parameters
10 | self.channel_conf = self.get_channel_conf(num_upsample)
11 |
12 | # Define the pixel shuffle
13 | self.pixshuffle = nn.PixelShuffle(2)
14 |
15 | # Process the feature
16 | self.conv_block_lst = []
17 | # The input block
18 | self.conv_block_lst.append(
19 | nn.Sequential(
20 | nn.Conv2d(input_feat_dim, self.channel_conf[0], kernel_size=3, stride=1, padding=1),
21 | nn.BatchNorm2d(self.channel_conf[0]),
22 | nn.ReLU(inplace=True)
23 | ))
24 |
25 | # Intermediate block
26 | for channel in self.channel_conf[1:-1]:
27 | self.conv_block_lst.append(
28 | nn.Sequential(
29 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
30 | nn.BatchNorm2d(channel),
31 | nn.ReLU(inplace=True)
32 | ))
33 |
34 | # Output block
35 | self.conv_block_lst.append(
36 | nn.Conv2d(self.channel_conf[-1], output_channel, kernel_size=1, stride=1, padding=0)
37 | )
38 | self.conv_block_lst = nn.ModuleList(self.conv_block_lst)
39 |
40 | # Get num of channels based on number of upsampling.
41 | def get_channel_conf(self, num_upsample):
42 |         if num_upsample == 2:
43 |             return [256, 64, 16]
44 |         assert num_upsample == 3, "only 2 or 3 upsampling stages are supported"
45 |         return [256, 64, 16, 4]
46 |
47 | def forward(self, input_features):
48 | # Iterate til output block
49 | out = input_features
50 | for block in self.conv_block_lst[:-1]:
51 | out = block(out)
52 | out = self.pixshuffle(out)
53 |
54 | # Output layer
55 | out = self.conv_block_lst[-1](out)
56 |
57 | return out
58 |
--------------------------------------------------------------------------------
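Shape check for the decoder: each upsampling stage is a 2x `PixelShuffle`, so `num_upsample=2` maps quarter-resolution features back to full resolution (dummy input):

    import torch

    decoder = PixelShuffleDecoder(input_feat_dim=128, num_upsample=2, output_channel=2)
    feats = torch.randn(1, 128, 64, 64)                  # e.g. backbone features for a 256x256 image
    print(decoder(feats).shape)                          # -> torch.Size([1, 2, 256, 256])
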
/hawp/ssl/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | def cross_entropy_loss_for_junction(logits, positive):
4 | nlogp = -F.log_softmax(logits, dim=1)
5 |
6 | loss = (positive * nlogp[:, None, 1] + (1 - positive) * nlogp[:, None, 0])
7 |
8 | return loss.mean()
9 |
10 | def focal_loss_for_junction(logits, positive, gamma=2.0):
11 | prob = F.softmax(logits, 1)
12 | ce_loss = F.cross_entropy(logits, positive, reduction='none')
13 | p_t = prob[:,1:]*positive + prob[:,:1]*(1-positive)
14 | loss = ce_loss * ((1-p_t)**gamma)
15 |
16 | return loss.mean()
17 |
18 | def sigmoid_l1_loss(logits, targets, offset = 0.0, mask=None):
19 | logp = torch.sigmoid(logits) + offset
20 | loss = torch.abs(logp-targets)
21 |
22 | if mask is not None:
23 | w = mask.mean(3, True).mean(2,True)
24 | w[w==0] = 1
25 | loss = loss*(mask/w)
26 |
27 | return loss.mean()
28 |
29 | def sigmoid_focal_loss(
30 | inputs: torch.Tensor,
31 | targets: torch.Tensor,
32 | alpha: float = -1,
33 | gamma: float = 2,
34 | reduction: str = "none",
35 | ) -> torch.Tensor:
36 | """
37 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
38 | Args:
39 | inputs: A float tensor of arbitrary shape.
40 | The predictions for each example.
41 | targets: A float tensor with the same shape as inputs. Stores the binary
42 | classification label for each element in inputs
43 | (0 for the negative class and 1 for the positive class).
44 | alpha: (optional) Weighting factor in range (0,1) to balance
45 | positive vs negative examples. Default = -1 (no weighting).
46 | gamma: Exponent of the modulating factor (1 - p_t) to
47 | balance easy vs hard examples.
48 | reduction: 'none' | 'mean' | 'sum'
49 | 'none': No reduction will be applied to the output.
50 | 'mean': The output will be averaged.
51 | 'sum': The output will be summed.
52 | Returns:
53 | Loss tensor with the reduction option applied.
54 | """
55 | inputs = inputs.float()
56 | targets = targets.float()
57 | p = torch.sigmoid(inputs)
58 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
59 | p_t = p * targets + (1 - p) * (1 - targets)
60 | loss = ce_loss * ((1 - p_t) ** gamma)
61 |
62 | if alpha >= 0:
63 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
64 | loss = alpha_t * loss
65 |
66 | if reduction == "mean":
67 | loss = loss.mean()
68 | elif reduction == "sum":
69 | loss = loss.sum()
70 |
71 | return loss
--------------------------------------------------------------------------------
/hawp/ssl/models/registry.py:
--------------------------------------------------------------------------------
1 | from hawp.base.utils.registry import Registry
2 |
3 | MODELS = Registry()
4 |
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | cython
3 | matplotlib
4 | yacs
5 | scikit-image
6 | tqdm
7 | python-json-logger
8 | h5py
9 | shapely
10 | pycolmap
11 | seaborn
12 | kornia
13 | easydict
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | from setuptools import find_packages
5 | from setuptools import setup
6 |
7 | setup(
8 | name="hawp",
9 | version="1.0",
10 | author="nxue",
11 | description="Holistically-Attracted Wireframe Parsing",
12 | packages=find_packages(),
13 | install_requires=[
14 | "torch",
15 | "torchvision",
16 | "opencv-python",
17 | "cython",
18 | "matplotlib",
19 | "yacs",
20 | "scikit-image",
21 | "tqdm",
22 | "python-json-logger",
23 | "h5py",
24 | "shapely",
25 | "seaborn",
26 | "easydict",
27 | ],
28 | extras_require={
29 | "dev": [
30 | "pycolmap",
31 | ]
32 | }
33 | )
34 |
--------------------------------------------------------------------------------