├── .idea
│   ├── .gitignore
│   ├── CurriculumLoc.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── ProcessData
│   ├── gen_gt_files.py
│   ├── gen_images_file.py
│   └── rgb_norm.py
├── README.md
├── __pycache__
│   ├── cnnmatchs.cpython-310.pyc
│   ├── cnnmatchs.cpython-38.pyc
│   ├── plotmatch.cpython-310.pyc
│   └── plotmatch.cpython-38.pyc
├── environment.yaml
├── fig
│   ├── dataset.png
│   └── outline.png
├── lib
│   ├── __pycache__
│   │   ├── cnn_feature.cpython-310.pyc
│   │   ├── cnn_feature.cpython-38.pyc
│   │   ├── dataset.cpython-38.pyc
│   │   ├── dataset_terra.cpython-310.pyc
│   │   ├── dataset_terra.cpython-38.pyc
│   │   ├── exceptions.cpython-310.pyc
│   │   ├── exceptions.cpython-38.pyc
│   │   ├── loss.cpython-310.pyc
│   │   ├── loss.cpython-38.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── model_d2.cpython-38.pyc
│   │   ├── model_swin.cpython-310.pyc
│   │   ├── model_swin.cpython-38.pyc
│   │   ├── model_swin_unet_d2.cpython-38.pyc
│   │   ├── model_test.cpython-310.pyc
│   │   ├── model_test.cpython-38.pyc
│   │   ├── pyramid.cpython-310.pyc
│   │   ├── pyramid.cpython-38.pyc
│   │   ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc
│   │   ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc
│   │   ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc.139907302345344
│   │   ├── swin_unet.cpython-310.pyc
│   │   ├── swin_unet.cpython-38.pyc
│   │   ├── utils.cpython-310.pyc
│   │   └── utils.cpython-38.pyc
│   ├── backbone_model
│   │   ├── __pycache__
│   │   │   ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc
│   │   │   ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc
│   │   │   ├── swin_unet.cpython-310.pyc
│   │   │   ├── swin_unet.cpython-38.pyc
│   │   │   └── unet.cpython-38.pyc
│   │   ├── erfnet.py
│   │   ├── swin_transformer_unet_skip_expand_decoder_sys.py
│   │   ├── swin_unet.py
│   │   └── unet.py
│   ├── dataset_terra.py
│   ├── exceptions.py
│   ├── full_model
│   │   ├── __pycache__
│   │   │   ├── model_d2.cpython-310.pyc
│   │   │   ├── model_d2.cpython-38.pyc
│   │   │   ├── model_swin_unet_d2.cpython-310.pyc
│   │   │   ├── model_swin_unet_d2.cpython-38.pyc
│   │   │   └── model_unet.cpython-38.pyc
│   │   ├── model_erf.py
│   │   ├── model_swin_unet_d2.py
│   │   ├── model_unet.py
│   │   └── model_vit.py
│   ├── model_swin.py
│   ├── model_test.py
│   ├── swin_transformer_unet_skip_expand_decoder_sys.py
│   ├── swin_unet.py
│   └── utils.py
├── match_localization.py
├── matching.py
├── performance.ini
├── plotmatch.py
├── result
│   ├── cv-match.py
│   └── plot_rs.py
├── terratrack_utils
│   ├── compute_mean_std.py
│   ├── depth_bin2array.py
│   ├── depth_png2h5.py
│   ├── preprocess_terra.py
│   ├── train_scenes_500.txt
│   ├── train_scenes_origin.txt
│   ├── undistort_reconstructions_terra.py
│   ├── valid_scenes_500.txt
│   └── valid_scenes_origin.txt
└── train_terra.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/CurriculumLoc.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ProcessData/gen_gt_files.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import csv
3 | from os.path import join
4 | import pandas as pd
5 |
6 | # data_q = []
7 | # data_db = []
8 | #
9 | path = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/'
10 | #
11 | # with open(join(path, 'reference.csv'), 'r') as csv_file_db:
12 | # csv_data = csv.DictReader(csv_file_db)
13 | # for row in csv_data:
14 | # data_db.append(row)
15 | #
16 | # with open(join(path, 'query.csv'), 'r') as csv_file_q:
17 | # csv_data_q = csv.DictReader(csv_file_q)
18 | # for row in csv_data_q:
19 | # data_q.append(row)
20 | #
21 | # numQ = len(data_q)
22 | # numDb = len(data_db)
23 | #
24 | # utmDb = [[float(f["easting"]), float(f["northing"])] for f in data_db]
25 | # utmQ = [[float(f["easting"]), float(f["northing"])] for f in data_q]
26 | posDistThr = 50
27 | qData = pd.read_csv(join(path, 'query.csv'))
28 | qData = qData.sort_values(by='name', ascending=True)
29 | dbData = pd.read_csv(join(path, 'reference.csv'))
30 | dbData = dbData.sort_values(by='name', ascending=True)
31 | utmQ = qData[['easting', 'northing']].values.reshape(-1, 2)
32 | utmDb = dbData[['easting', 'northing']].values.reshape(-1, 2)
33 |
34 | np.savez('../patchnetvlad/dataset_gt_files/mavic_npu_dist50', utmQ=utmQ, utmDb=utmDb, posDistThr=posDistThr)
35 |
36 | print("Save Done!")
37 |
--------------------------------------------------------------------------------
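The script above only stores `utmQ`, `utmDb`, and `posDistThr`. A minimal sketch of how such a ground-truth `.npz` file might be consumed downstream, assuming a Patch-NetVLAD-style radius check (the file path mirrors the one saved above; the lookup itself is not taken from this repository):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Load the ground-truth file written by ProcessData/gen_gt_files.py.
gt = np.load('../patchnetvlad/dataset_gt_files/mavic_npu_dist50.npz')
utm_q, utm_db = gt['utmQ'], gt['utmDb']
pos_dist_thr = float(gt['posDistThr'])

# For each query, every reference image whose UTM position lies within
# pos_dist_thr metres is treated as a correct match.
knn = NearestNeighbors(n_jobs=-1).fit(utm_db)
_, positives = knn.radius_neighbors(utm_q, radius=pos_dist_thr)
print(f"{len(utm_q)} queries; query 0 has {len(positives[0])} positive references")
```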
/ProcessData/gen_images_file.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 |
5 | parser = argparse.ArgumentParser(description='origin data path')
6 | parser.add_argument('--data_path', type=str, default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/reference_images_500', help='origin images path')
7 | parser.add_argument('--save_path', type=str, default='../patchnetvlad/dataset_imagenames/mavic_npu_reference_imageNames_index.txt',
8 | help='save images txt path')
9 |
10 |
11 | opt = parser.parse_args()
12 | print(opt)
13 |
14 |
15 | for dir in os.listdir(opt.data_path):
16 | sub_data_path = os.path.join(opt.data_path, dir)
17 | if os.path.isdir(sub_data_path):
18 | for f in os.listdir(sub_data_path):
19 | image_name = os.path.join(sub_data_path, f)
20 | image_name_w = os.path.join(*(image_name.split('/')[-3:]))
21 | with open(opt.save_path, 'a') as txt:
22 | txt.write(image_name_w)
23 | txt.write('\n')
24 | else:
25 | image_name = os.path.join(opt.data_path, dir)
26 | # image_name_w = os.path.join(*(image_name.split('/')[-2:]))
27 | with open(opt.save_path, 'a') as txt:
28 | # txt.write(image_name_w)
29 | txt.write(image_name)
30 | txt.write('\n')
31 |
--------------------------------------------------------------------------------
/ProcessData/rgb_norm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | def compute(img, min_percentile, max_percentile):
6 | """计算分位点,目的是去掉图1的直方图两头的异常情况"""
7 |
8 |
9 | max_percentile_pixel = np.percentile(img, max_percentile)
10 | min_percentile_pixel = np.percentile(img, min_percentile)
11 |
12 | return max_percentile_pixel, min_percentile_pixel
13 |
14 |
15 | def aug(src):
16 | """图像亮度增强"""
17 | if get_lightness(src) > 130:
18 | print("图片亮度足够,不做增强")
19 | # 先计算分位点,去掉像素值中少数异常值,这个分位点可以自己配置。
20 | # 比如1中直方图的红色在0到255上都有值,但是实际上像素值主要在0到20内。
21 |
22 |
23 | max_percentile_pixel, min_percentile_pixel = compute(src, 1, 99)
24 |
25 |     # Clip values outside the percentile range
26 | src[src >= max_percentile_pixel] = max_percentile_pixel
27 | src[src <= min_percentile_pixel] = min_percentile_pixel
28 |
29 |     # Stretch the percentile range to 0-255; 255*0.1 and 255*0.9 are used instead of 0 and 255 to avoid pixel-value overflow.
30 | out = np.zeros(src.shape, src.dtype)
31 | cv2.normalize(src, out, 255 * 0.1, 255 * 0.9, cv2.NORM_MINMAX)
32 |
33 | return out
34 |
35 |
36 | def get_lightness(src):
37 |     # Compute brightness as the mean of the V channel in HSV space
38 | hsv_image = cv2.cvtColor(src, cv2.COLOR_BGR2HSV)
39 | lightness = hsv_image[:, :, 2].mean()
40 |
41 | return lightness
42 |
43 |
44 |
45 | img = cv2.imread("/home/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/query_images/000000.png", cv2.IMREAD_UNCHANGED)
46 | img = aug(img)
47 | print("ppppp")
48 | cv2.imwrite('/home/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/output/000000.png', img)
49 |
50 |
--------------------------------------------------------------------------------
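A hedged batch-processing sketch for the `aug` function above. It assumes `aug` (and `get_lightness`) are in scope, e.g. by pasting these lines at the end of `rgb_norm.py` in place of its hard-coded single-image example; the folder names are placeholders:

```python
import os
import cv2

in_dir = 'query_images'                  # placeholder input folder
out_dir = 'query_images_normalized'      # placeholder output folder
os.makedirs(out_dir, exist_ok=True)

for name in sorted(os.listdir(in_dir)):
    img = cv2.imread(os.path.join(in_dir, name), cv2.IMREAD_UNCHANGED)
    if img is None:                      # skip anything OpenCV cannot read
        continue
    cv2.imwrite(os.path.join(out_dir, name), aug(img))   # aug() defined above
```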
/README.md:
--------------------------------------------------------------------------------
1 | # CurriculumLoc for Visual Geo-localization
2 |
3 |
4 |
5 | ## Introduction
6 | CurriculumLoc is a PyTorch implementation for our paper ["CurriculumLoc: Enhancing Cross-Domain Geolocalization through Multi-Stage Refinement"](https://arxiv.org/abs/2311.11604). If you use this code for your research, please cite our paper. For additional questions contact us via huboni@mail.nwpu.edu.cn or huboni7@gmail.com.
7 |
8 |
9 | ## Installation
10 |
11 | We tested this repo with Python 3.10, PyTorch 1.12.1, and CUDA 11.3; however, it should also run with more recent PyTorch versions. You can create the conda environment from the provided environment.yaml:
12 |
13 | ```shell
14 | conda env create -f environment.yaml
15 | ```
16 |
17 |
18 | ## Preparation
19 |
20 | We test our models on two datasets. The first is [ALTO](https://github.com/MetaSLAM/ALTO), which can be downloaded [here](https://github.com/MetaSLAM/ALTO). The second is our own TerraTrack dataset, which is still being prepared. Both datasets contain challenging environmental variations, as illustrated below (see fig/dataset.png).
21 |
22 |
23 |
24 |
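For reference, `ProcessData/gen_gt_files.py` builds the ground-truth file from per-image geo-tags stored in `query.csv` and `reference.csv`. A minimal sketch of the fields it expects (column names are taken from that script; the paths are placeholders):

```python
import pandas as pd

q = pd.read_csv('query.csv')          # placeholder path, one row per query image
db = pd.read_csv('reference.csv')     # placeholder path, one row per reference image
for df in (q, db):
    assert {'name', 'easting', 'northing'} <= set(df.columns)
```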
25 | ## Training
26 |
27 | ```
28 | python train_terra.py
29 | ```
30 |
31 |
32 |
33 | ## Testing
34 |
35 | ```
36 | python match_localization.py
37 | ```
38 |
39 |
40 |
41 | ## Citation
42 |
43 | If you're using CurriculumLoc in your research or applications, please cite using this BibTeX:
44 |
45 | ```bibtex
46 | @misc{hu2023curriculumloc,
47 | title={CurriculumLoc: Enhancing Cross-Domain Geolocalization through Multi-Stage Refinement},
48 | author={Boni Hu and Lin Chen and Runjian Chen and Shuhui Bu and Pengcheng Han and Haowei Li},
49 | year={2023},
50 | eprint={2311.11604},
51 | archivePrefix={arXiv},
52 | primaryClass={cs.CV}
53 | }
54 | ```
55 |
56 |
--------------------------------------------------------------------------------
/__pycache__/cnnmatchs.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/__pycache__/cnnmatchs.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/cnnmatchs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/__pycache__/cnnmatchs.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/plotmatch.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/__pycache__/plotmatch.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/plotmatch.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/__pycache__/plotmatch.cpython-38.pyc
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: cnn-matching
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - blosc=1.21.3=h6a678d5_0
9 | - boost-cpp=1.73.0=h7f8727e_12
10 | - bzip2=1.0.8=h7b6447c_0
11 | - c-ares=1.18.1=h7f8727e_0
12 | - ca-certificates=2023.01.10=h06a4308_0
13 | - cairo=1.16.0=hb05425b_3
14 | - cfitsio=3.470=hf0d0db6_6
15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
16 | - cudatoolkit=11.3.1=h2bc3f7f_2
17 | - curl=7.87.0=h5eee18b_0
18 | - expat=2.4.9=h6a678d5_0
19 | - ffmpeg=4.3=hf484d3e_0
20 | - flit-core=3.6.0=pyhd3eb1b0_0
21 | - fontconfig=2.14.1=h52c9d5c_1
22 | - freetype=2.12.1=h4a9f257_0
23 | - freexl=1.0.6=h27cfd23_0
24 | - geos=3.8.0=he6710b0_0
25 | - geotiff=1.7.0=hd69d5b1_0
26 | - giflib=5.2.1=h5eee18b_1
27 | - glib=2.69.1=he621ea3_2
28 | - gmp=6.2.1=h295c915_3
29 | - gnutls=3.6.15=he1e5248_0
30 | - hdf4=4.2.13=h3ca952b_2
31 | - hdf5=1.10.6=hb1b8bf9_0
32 | - icu=58.2=he6710b0_3
33 | - intel-openmp=2022.1.0=h9e868ea_3769
34 | - jpeg=9e=h7f8727e_0
35 | - json-c=0.16=h5eee18b_0
36 | - kealib=1.5.0=hd940352_0
37 | - krb5=1.19.4=h568e23c_0
38 | - lame=3.100=h7b6447c_0
39 | - lcms2=2.12=h3be6417_0
40 | - ld_impl_linux-64=2.38=h1181459_1
41 | - lerc=3.0=h295c915_0
42 | - libboost=1.73.0=h28710b8_12
43 | - libcurl=7.87.0=h91b91d3_0
44 | - libdeflate=1.8=h7f8727e_5
45 | - libedit=3.1.20221030=h5eee18b_0
46 | - libev=4.33=h7f8727e_1
47 | - libffi=3.4.2=h6a678d5_6
48 | - libgcc-ng=11.2.0=h1234567_1
49 | - libgdal=3.6.0=hc0e11bb_0
50 | - libgfortran-ng=7.5.0=ha8ba4b0_17
51 | - libgfortran4=7.5.0=ha8ba4b0_17
52 | - libgomp=11.2.0=h1234567_1
53 | - libiconv=1.16=h7f8727e_2
54 | - libidn2=2.3.2=h7f8727e_0
55 | - libkml=1.3.0=h096b73e_6
56 | - libnetcdf=4.8.1=h8322cc2_2
57 | - libnghttp2=1.46.0=hce63b2e_0
58 | - libpng=1.6.37=hbc83047_0
59 | - libpq=12.9=h16c4e8d_3
60 | - libspatialite=4.3.0a=h71b31bf_21
61 | - libssh2=1.10.0=h8f2d780_0
62 | - libstdcxx-ng=11.2.0=h1234567_1
63 | - libtasn1=4.16.0=h27cfd23_0
64 | - libtiff=4.5.0=h6a678d5_1
65 | - libunistring=0.9.10=h27cfd23_0
66 | - libuuid=1.41.5=h5eee18b_0
67 | - libwebp=1.2.4=h11a3e52_0
68 | - libwebp-base=1.2.4=h5eee18b_0
69 | - libxcb=1.15=h7f8727e_0
70 | - libxml2=2.9.14
71 | - libzip=1.8.0=h5cef20c_0
72 | - lz4-c=1.9.4=h6a678d5_0
73 | - mkl=2020.2=256
74 | - ncurses=6.4=h6a678d5_0
75 | - nettle=3.7.3=hbbd107a_1
76 | - nspr=4.33=h295c915_0
77 | - nss=3.74=h0370c37_0
78 | - openh264=2.1.1=h4ff587b_0
79 | - openjpeg=2.4.0=h3ad879b_0
80 | - openssl=1.1.1t=h7f8727e_0
81 | - pcre=8.45=h295c915_0
82 | - pcre2=10.37=he7ceb23_1
83 | - pixman=0.40.0=h7f8727e_1
84 | - poppler=22.12.0=h381b16e_0
85 | - poppler-data=0.4.11=h06a4308_1
86 | - postgresql=12.9=h16c4e8d_3
87 | - proj=6.2.1=h05a3930_0
88 | - pycparser=2.21=pyhd3eb1b0_0
89 | - pyopenssl=22.0.0=pyhd3eb1b0_0
90 | - python=3.10.9=h7a1cb2a_0
91 | - pytorch=1.12.1=py3.10_cuda11.3_cudnn8.3.2_0
92 | - pytorch-mutex=1.0=cuda
93 | - qhull=2020.2=hdb19cb5_2
94 | - readline=8.2=h5eee18b_0
95 | - sqlite=3.40.1=h5082296_0
96 | - tiledb=2.3.3=h1132f93_2
97 | - tk=8.6.12=h1ccaba5_0
98 | - typing_extensions=4.4.0=py310h06a4308_0
99 | - tzdata=2022g=h04d1e81_0
100 | - xerces-c=3.2.4=h94c2ce2_0
101 | - xz=5.2.10=h5eee18b_1
102 | - zlib=1.2.13=h5eee18b_0
103 | - zstd=1.5.2=ha4553b6_0
104 | - pip:
105 | - brotlipy==0.7.0
106 | - certifi==2022.12.7
107 | - cffi==1.15.1
108 | - contourpy==1.0.7
109 | - cryptography==38.0.4
110 | - cycler==0.11.0
111 | - einops==0.7.0
112 | - filelock==3.12.4
113 | - fonttools==4.38.0
114 | - fsspec==2023.9.2
115 | - gdal==3.6.0
116 | - huggingface-hub==0.17.3
117 | - idna==3.4
118 | - imageio==2.25.1
119 | - joblib==1.2.0
120 | - kiwisolver==1.4.4
121 | - matplotlib==3.7.0
122 | - networkx==3.0
123 | - numpy==1.22.3
124 | - opencv-python==4.7.0.68
125 | - packaging==23.0
126 | - pandas==1.5.3
127 | - pillow==9.3.0
128 | - pip==22.3.1
129 | - pyparsing==3.0.9
130 | - pysocks==1.7.1
131 | - python-dateutil==2.8.2
132 | - pytz==2022.7.1
133 | - pywavelets==1.4.1
134 | - pyyaml==6.0.1
135 | - requests==2.28.1
136 | - safetensors==0.4.0
137 | - scikit-image==0.19.3
138 | - scikit-learn==1.2.1
139 | - scipy==1.10.0
140 | - six==1.16.0
141 | - threadpoolctl==3.1.0
142 | - tifffile==2023.2.3
143 | - timm==0.9.7
144 | - torch==1.12.1
145 | - torchaudio==0.12.1
146 | - torchvision==0.13.1
147 | - tqdm==4.64.1
148 | - typing-extensions==4.4.0
149 | - urllib3==1.26.14
150 | prefix: /home/a409_home/anaconda3/envs/cnn-matching
151 |
--------------------------------------------------------------------------------
/fig/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/fig/dataset.png
--------------------------------------------------------------------------------
/fig/outline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/fig/outline.png
--------------------------------------------------------------------------------
/lib/__pycache__/cnn_feature.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/cnn_feature.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/cnn_feature.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/cnn_feature.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataset_terra.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/dataset_terra.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataset_terra.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/dataset_terra.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/exceptions.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/exceptions.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/exceptions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/exceptions.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/loss.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/loss.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/model_d2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_d2.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/model_swin.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_swin.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/model_swin.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_swin.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/model_swin_unet_d2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_swin_unet_d2.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/model_test.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_test.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/model_test.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_test.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/pyramid.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/pyramid.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/pyramid.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/pyramid.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc.139907302345344:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc.139907302345344
--------------------------------------------------------------------------------
/lib/__pycache__/swin_unet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_unet.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/swin_unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_unet.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/backbone_model/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/backbone_model/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/backbone_model/__pycache__/swin_unet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/swin_unet.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/backbone_model/__pycache__/swin_unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/swin_unet.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/backbone_model/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/backbone_model/erfnet.py:
--------------------------------------------------------------------------------
1 | # ERFNet full model definition for Pytorch
2 | # Sept 2017
3 | # Eduardo Romera
4 | #######################
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.init as init
9 | import torch.nn.functional as F
10 |
11 |
12 | class DownsamplerBlock (nn.Module):
13 | def __init__(self, ninput, noutput):
14 | super().__init__()
15 |
16 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True)
17 | self.pool = nn.MaxPool2d(2, stride=2)
18 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3)
19 |
20 | def forward(self, input):
21 | output = torch.cat([self.conv(input), self.pool(input)], 1)
22 | output = self.bn(output)
23 | return F.relu(output)
24 |
25 |
26 | class non_bottleneck_1d (nn.Module):
27 | def __init__(self, chann, dropprob, dilated):
28 | super().__init__()
29 |
30 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)
31 |
32 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)
33 |
34 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)
35 |
36 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(
37 | 1*dilated, 0), bias=True, dilation=(dilated, 1))
38 |
39 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(
40 | 0, 1*dilated), bias=True, dilation=(1, dilated))
41 |
42 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)
43 |
44 | self.dropout = nn.Dropout2d(dropprob)
45 |
46 | def forward(self, input):
47 |
48 | output = self.conv3x1_1(input)
49 | output = F.relu(output)
50 | output = self.conv1x3_1(output)
51 | output = self.bn1(output)
52 | output = F.relu(output)
53 |
54 | output = self.conv3x1_2(output)
55 | output = F.relu(output)
56 | output = self.conv1x3_2(output)
57 | output = self.bn2(output)
58 |
59 | if (self.dropout.p != 0):
60 | output = self.dropout(output)
61 |
62 | return F.relu(output+input) # +input = identity (residual connection)
63 |
64 |
65 | class Encoder(nn.Module):
66 | def __init__(self, num_classes):
67 | super().__init__()
68 | self.initial_block = DownsamplerBlock(3, 16)
69 |
70 | self.layers = nn.ModuleList()
71 |
72 | self.layers.append(DownsamplerBlock(16, 64))
73 |
74 | for x in range(0, 5): # 5 times
75 | self.layers.append(non_bottleneck_1d(64, 0.03, 1))
76 |
77 | self.layers.append(DownsamplerBlock(64, 128))
78 |
79 | for x in range(0, 2): # 2 times
80 | self.layers.append(non_bottleneck_1d(128, 0.3, 2))
81 | self.layers.append(non_bottleneck_1d(128, 0.3, 4))
82 | self.layers.append(non_bottleneck_1d(128, 0.3, 8))
83 | self.layers.append(non_bottleneck_1d(128, 0.3, 16))
84 |
85 | # Only in encoder mode:
86 | # self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True)
87 |
88 | def forward(self, input):
89 | output = self.initial_block(input)
90 | for layer in self.layers:
91 | output = layer(output)
92 | # print(output.size())
93 |
94 | return output
95 |
96 |
97 | class UpsamplerBlock (nn.Module):
98 | def __init__(self, ninput, noutput):
99 | super().__init__()
100 | self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2,
101 | padding=1, output_padding=1, bias=True)
102 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3)
103 |
104 | def forward(self, input):
105 | output = self.conv(input)
106 | output = self.bn(output)
107 | return F.relu(output)
108 |
109 |
110 | class Decoder (nn.Module):
111 | def __init__(self, num_classes):
112 | super().__init__()
113 |
114 | self.layers = nn.ModuleList()
115 |
116 | self.layers.append(UpsamplerBlock(128, 64))
117 | self.layers.append(non_bottleneck_1d(64, 0, 1))
118 | self.layers.append(non_bottleneck_1d(64, 0, 1))
119 |
120 | self.layers.append(UpsamplerBlock(64, 16))
121 | self.layers.append(non_bottleneck_1d(16, 0, 1))
122 | self.layers.append(non_bottleneck_1d(16, 0, 1))
123 |
124 | self.output_conv = nn.ConvTranspose2d(
125 | 16, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True)
126 |
127 | def forward(self, input):
128 | output = input
129 |
130 | for layer in self.layers:
131 | output = layer(output)
132 |
133 | output = self.output_conv(output)
134 |
135 | return output
136 |
137 | # ERFNet
138 |
139 |
140 | class Net(nn.Module):
141 | def __init__(self, num_classes): # use encoder to pass pretrained encoder
142 | super().__init__()
143 |
144 | self.encoder = Encoder(num_classes)
145 | self.decoder = Decoder(num_classes)
146 |
147 | def forward(self, input):
148 | output = self.encoder(input)
149 | return self.decoder.forward(output)
150 |
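# Note: the encoder downsamples the input by a factor of 8 (three stride-2
# DownsamplerBlocks: 3 -> 16 -> 64 -> 128 channels); the decoder's two
# UpsamplerBlocks plus the final stride-2 ConvTranspose2d restore full resolution.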
--------------------------------------------------------------------------------
/lib/backbone_model/swin_transformer_unet_skip_expand_decoder_sys.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | from einops import rearrange
5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 |
7 |
8 | class Mlp(nn.Module):
9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
10 | super().__init__()
11 | out_features = out_features or in_features
12 | hidden_features = hidden_features or in_features
13 | self.fc1 = nn.Linear(in_features, hidden_features)
14 | self.act = act_layer()
15 | self.fc2 = nn.Linear(hidden_features, out_features)
16 | self.drop = nn.Dropout(drop)
17 |
18 | def forward(self, x):
19 | x = self.fc1(x)
20 | x = self.act(x)
21 | x = self.drop(x)
22 | x = self.fc2(x)
23 | x = self.drop(x)
24 | return x
25 |
26 |
27 | def window_partition(x, window_size):
28 | """
29 | Args:
30 | x: (B, H, W, C)
31 | window_size (int): window size
32 |
33 | Returns:
34 | windows: (num_windows*B, window_size, window_size, C)
35 | """
36 | B, H, W, C = x.shape
37 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
38 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
39 | return windows
40 |
41 |
42 | def window_reverse(windows, window_size, H, W):
43 | """
44 | Args:
45 | windows: (num_windows*B, window_size, window_size, C)
46 | window_size (int): Window size
47 | H (int): Height of image
48 | W (int): Width of image
49 |
50 | Returns:
51 | x: (B, H, W, C)
52 | """
53 | B = int(windows.shape[0] / (H * W / window_size / window_size))
54 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
55 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
56 | return x
57 |
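# Shape example for the two helpers above (values are illustrative): with
# x of shape (B=2, H=56, W=56, C=96) and window_size=7, each image yields
# 8*8 = 64 windows, so window_partition(x, 7) returns (128, 7, 7, 96);
# window_reverse(windows, 7, 56, 56) inverts this back to (2, 56, 56, 96).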
58 |
59 | class WindowAttention(nn.Module):
60 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
61 | It supports both of shifted and non-shifted window.
62 |
63 | Args:
64 | dim (int): Number of input channels.
65 | window_size (tuple[int]): The height and width of the window.
66 | num_heads (int): Number of attention heads.
67 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
68 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
69 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
70 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
71 | """
72 |
73 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
74 |
75 | super().__init__()
76 | self.dim = dim
77 | self.window_size = window_size # Wh, Ww
78 | self.num_heads = num_heads
79 | head_dim = dim // num_heads
80 | self.scale = qk_scale or head_dim ** -0.5
81 |
82 | # define a parameter table of relative position bias
83 | self.relative_position_bias_table = nn.Parameter(
84 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
85 |
86 | # get pair-wise relative position index for each token inside the window
87 | coords_h = torch.arange(self.window_size[0])
88 | coords_w = torch.arange(self.window_size[1])
89 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
90 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
91 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
92 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
93 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
94 | relative_coords[:, :, 1] += self.window_size[1] - 1
95 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
96 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
97 | self.register_buffer("relative_position_index", relative_position_index)
98 |
99 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
100 | self.attn_drop = nn.Dropout(attn_drop)
101 | self.proj = nn.Linear(dim, dim)
102 | self.proj_drop = nn.Dropout(proj_drop)
103 |
104 | trunc_normal_(self.relative_position_bias_table, std=.02)
105 | self.softmax = nn.Softmax(dim=-1)
106 |
107 | def forward(self, x, mask=None):
108 | """
109 | Args:
110 | x: input features with shape of (num_windows*B, N, C)
111 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
112 | """
113 | B_, N, C = x.shape
114 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
115 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
116 |
117 | q = q * self.scale
118 | attn = (q @ k.transpose(-2, -1))
119 |
120 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
121 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
122 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
123 | attn = attn + relative_position_bias.unsqueeze(0)
124 |
125 | if mask is not None:
126 | nW = mask.shape[0]
127 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
128 | attn = attn.view(-1, self.num_heads, N, N)
129 | attn = self.softmax(attn)
130 | else:
131 | attn = self.softmax(attn)
132 |
133 | attn = self.attn_drop(attn)
134 |
135 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
136 | x = self.proj(x)
137 | x = self.proj_drop(x)
138 | return x
139 |
140 | def extra_repr(self) -> str:
141 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
142 |
143 | def flops(self, N):
144 | # calculate flops for 1 window with token length of N
145 | flops = 0
146 | # qkv = self.qkv(x)
147 | flops += N * self.dim * 3 * self.dim
148 | # attn = (q @ k.transpose(-2, -1))
149 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
150 | # x = (attn @ v)
151 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
152 | # x = self.proj(x)
153 | flops += N * self.dim * self.dim
154 | return flops
155 |
156 |
157 | class SwinTransformerBlock(nn.Module):
158 | r""" Swin Transformer Block.
159 |
160 | Args:
161 | dim (int): Number of input channels.
162 |         input_resolution (tuple[int]): Input resolution.
163 | num_heads (int): Number of attention heads.
164 | window_size (int): Window size.
165 | shift_size (int): Shift size for SW-MSA.
166 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
167 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
168 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
169 | drop (float, optional): Dropout rate. Default: 0.0
170 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
171 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
172 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
173 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
174 | """
175 |
176 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
177 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
178 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
179 | super().__init__()
180 | self.dim = dim
181 | self.input_resolution = input_resolution
182 | self.num_heads = num_heads
183 | self.window_size = window_size
184 | self.shift_size = shift_size
185 | self.mlp_ratio = mlp_ratio
186 | if min(self.input_resolution) <= self.window_size:
187 | # if window size is larger than input resolution, we don't partition windows
188 | self.shift_size = 0
189 | self.window_size = min(self.input_resolution)
190 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
191 |
192 | self.norm1 = norm_layer(dim)
193 | self.attn = WindowAttention(
194 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
195 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
196 |
197 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
198 | self.norm2 = norm_layer(dim)
199 | mlp_hidden_dim = int(dim * mlp_ratio)
200 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
201 |
202 | if self.shift_size > 0:
203 | # calculate attention mask for SW-MSA
204 | H, W = self.input_resolution
205 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
206 | h_slices = (slice(0, -self.window_size),
207 | slice(-self.window_size, -self.shift_size),
208 | slice(-self.shift_size, None))
209 | w_slices = (slice(0, -self.window_size),
210 | slice(-self.window_size, -self.shift_size),
211 | slice(-self.shift_size, None))
212 | cnt = 0
213 | for h in h_slices:
214 | for w in w_slices:
215 | img_mask[:, h, w, :] = cnt
216 | cnt += 1
217 |
218 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
219 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
220 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
221 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
222 | else:
223 | attn_mask = None
224 |
225 | self.register_buffer("attn_mask", attn_mask)
226 |
227 | def forward(self, x):
228 | H, W = self.input_resolution
229 | B, L, C = x.shape
230 | assert L == H * W, "input feature has wrong size"
231 |
232 | shortcut = x
233 | x = self.norm1(x)
234 | x = x.view(B, H, W, C)
235 |
236 | # cyclic shift
237 | if self.shift_size > 0:
238 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
239 | else:
240 | shifted_x = x
241 |
242 | # partition windows
243 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
244 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
245 |
246 | # W-MSA/SW-MSA
247 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
248 |
249 | # merge windows
250 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
251 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
252 |
253 | # reverse cyclic shift
254 | if self.shift_size > 0:
255 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
256 | else:
257 | x = shifted_x
258 | x = x.view(B, H * W, C)
259 |
260 | # FFN
261 | x = shortcut + self.drop_path(x)
262 | x = x + self.drop_path(self.mlp(self.norm2(x)))
263 |
264 | return x
265 |
266 | def extra_repr(self) -> str:
267 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
268 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
269 |
270 | def flops(self):
271 | flops = 0
272 | H, W = self.input_resolution
273 | # norm1
274 | flops += self.dim * H * W
275 | # W-MSA/SW-MSA
276 | nW = H * W / self.window_size / self.window_size
277 | flops += nW * self.attn.flops(self.window_size * self.window_size)
278 | # mlp
279 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
280 | # norm2
281 | flops += self.dim * H * W
282 | return flops
283 |
284 |
285 | class PatchMerging(nn.Module):
286 | r""" Patch Merging Layer.
287 |
288 | Args:
289 | input_resolution (tuple[int]): Resolution of input feature.
290 | dim (int): Number of input channels.
291 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
292 | """
293 |
294 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
295 | super().__init__()
296 | self.input_resolution = input_resolution
297 | self.dim = dim
298 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
299 | self.norm = norm_layer(4 * dim)
300 |
301 | def forward(self, x):
302 | """
303 | x: B, H*W, C
304 | """
305 | H, W = self.input_resolution
306 | B, L, C = x.shape
307 | assert L == H * W, "input feature has wrong size"
308 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
309 |
310 | x = x.view(B, H, W, C)
311 |
312 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
313 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
314 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
315 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
316 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
317 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
318 |
319 | x = self.norm(x)
320 | x = self.reduction(x)
321 |
322 | return x
323 |
324 | def extra_repr(self) -> str:
325 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
326 |
327 | def flops(self):
328 | H, W = self.input_resolution
329 | flops = H * W * self.dim
330 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
331 | return flops
332 |
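# Shape example (illustrative): PatchMerging over a 56x56 token grid with C=96
# concatenates each 2x2 neighbourhood into 4C=384 channels and projects it to
# 2C=192, i.e. (B, 3136, 96) -> (B, 784, 192).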
333 | class PatchExpand(nn.Module):
334 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
335 | super().__init__()
336 | self.input_resolution = input_resolution
337 | self.dim = dim
338 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
339 | self.norm = norm_layer(dim // dim_scale)
340 |
341 | def forward(self, x):
342 | """
343 | x: B, H*W, C
344 | """
345 | H, W = self.input_resolution
346 | x = self.expand(x)
347 | B, L, C = x.shape
348 | assert L == H * W, "input feature has wrong size"
349 |
350 | x = x.view(B, H, W, C)
351 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
352 | x = x.view(B,-1,C//4)
353 | x= self.norm(x)
354 |
355 | return x
356 |
357 | class FinalPatchExpand_X4(nn.Module):
358 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
359 | super().__init__()
360 | self.input_resolution = input_resolution
361 | self.dim = dim
362 | self.dim_scale = dim_scale
363 | self.expand = nn.Linear(dim, 16*dim, bias=False)
364 | self.output_dim = dim
365 | self.norm = norm_layer(self.output_dim)
366 |
367 | def forward(self, x):
368 | """
369 | x: B, H*W, C
370 | """
371 | H, W = self.input_resolution
372 | x = self.expand(x)
373 | B, L, C = x.shape
374 | assert L == H * W, "input feature has wrong size"
375 |
376 | x = x.view(B, H, W, C)
377 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
378 | x = x.view(B,-1,self.output_dim)
379 | x= self.norm(x)
380 |
381 | return x
382 |
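# Shape example (illustrative): FinalPatchExpand_X4 with dim_scale=4 first expands
# C to 16C, then rearranges every token into a 4x4 patch of C-dim pixels, so a
# (B, 56*56, 96) input becomes (B, 224*224, 96) before the final output conv.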
383 | class BasicLayer(nn.Module):
384 | """ A basic Swin Transformer layer for one stage.
385 |
386 | Args:
387 | dim (int): Number of input channels.
388 | input_resolution (tuple[int]): Input resolution.
389 | depth (int): Number of blocks.
390 | num_heads (int): Number of attention heads.
391 | window_size (int): Local window size.
392 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
393 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
394 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
395 | drop (float, optional): Dropout rate. Default: 0.0
396 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
397 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
398 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
399 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
400 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
401 | """
402 |
403 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
404 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
405 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
406 |
407 | super().__init__()
408 | self.dim = dim
409 | self.input_resolution = input_resolution
410 | self.depth = depth
411 | self.use_checkpoint = use_checkpoint
412 |
413 | # build blocks
414 | self.blocks = nn.ModuleList([
415 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
416 | num_heads=num_heads, window_size=window_size,
417 | shift_size=0 if (i % 2 == 0) else window_size // 2,
418 | mlp_ratio=mlp_ratio,
419 | qkv_bias=qkv_bias, qk_scale=qk_scale,
420 | drop=drop, attn_drop=attn_drop,
421 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
422 | norm_layer=norm_layer)
423 | for i in range(depth)])
424 |
425 | # patch merging layer
426 | if downsample is not None:
427 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
428 | else:
429 | self.downsample = None
430 |
431 | def forward(self, x):
432 | for blk in self.blocks:
433 | if self.use_checkpoint:
434 | x = checkpoint.checkpoint(blk, x)
435 | else:
436 | x = blk(x)
437 | if self.downsample is not None:
438 | x = self.downsample(x)
439 | return x
440 |
441 | def extra_repr(self) -> str:
442 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
443 |
444 | def flops(self):
445 | flops = 0
446 | for blk in self.blocks:
447 | flops += blk.flops()
448 | if self.downsample is not None:
449 | flops += self.downsample.flops()
450 | return flops
451 |
452 | class BasicLayer_up(nn.Module):
453 | """ A basic Swin Transformer layer for one stage.
454 |
455 | Args:
456 | dim (int): Number of input channels.
457 | input_resolution (tuple[int]): Input resolution.
458 | depth (int): Number of blocks.
459 | num_heads (int): Number of attention heads.
460 | window_size (int): Local window size.
461 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
462 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
463 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
464 | drop (float, optional): Dropout rate. Default: 0.0
465 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
466 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
467 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
468 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
469 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
470 | """
471 |
472 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
473 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
474 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
475 |
476 | super().__init__()
477 | self.dim = dim
478 | self.input_resolution = input_resolution
479 | self.depth = depth
480 | self.use_checkpoint = use_checkpoint
481 |
482 | # build blocks
483 | self.blocks = nn.ModuleList([
484 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
485 | num_heads=num_heads, window_size=window_size,
486 | shift_size=0 if (i % 2 == 0) else window_size // 2,
487 | mlp_ratio=mlp_ratio,
488 | qkv_bias=qkv_bias, qk_scale=qk_scale,
489 | drop=drop, attn_drop=attn_drop,
490 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
491 | norm_layer=norm_layer)
492 | for i in range(depth)])
493 |
494 | # patch merging layer
495 | if upsample is not None:
496 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
497 | else:
498 | self.upsample = None
499 |
500 | def forward(self, x):
501 | for blk in self.blocks:
502 | if self.use_checkpoint:
503 | x = checkpoint.checkpoint(blk, x)
504 | else:
505 | x = blk(x)
506 | if self.upsample is not None:
507 | x = self.upsample(x)
508 | return x
509 |
510 | class PatchEmbed(nn.Module):
511 | r""" Image to Patch Embedding
512 |
513 | Args:
514 | img_size (int): Image size. Default: 224.
515 | patch_size (int): Patch token size. Default: 4.
516 | in_chans (int): Number of input image channels. Default: 3.
517 | embed_dim (int): Number of linear projection output channels. Default: 96.
518 | norm_layer (nn.Module, optional): Normalization layer. Default: None
519 | """
520 |
521 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
522 | super().__init__()
523 | img_size = to_2tuple(img_size)
524 | patch_size = to_2tuple(patch_size)
525 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
526 | self.img_size = img_size
527 | self.patch_size = patch_size
528 | self.patches_resolution = patches_resolution
529 | self.num_patches = patches_resolution[0] * patches_resolution[1]
530 |
531 | self.in_chans = in_chans
532 | self.embed_dim = embed_dim
533 |
534 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
535 | if norm_layer is not None:
536 | self.norm = norm_layer(embed_dim)
537 | else:
538 | self.norm = None
539 |
540 | def forward(self, x):
541 | B, C, H, W = x.shape
542 | # FIXME look at relaxing size constraints
543 | assert H == self.img_size[0] and W == self.img_size[1], \
544 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
545 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
546 | if self.norm is not None:
547 | x = self.norm(x)
548 | return x
549 |
550 | def flops(self):
551 | Ho, Wo = self.patches_resolution
552 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
553 | if self.norm is not None:
554 | flops += Ho * Wo * self.embed_dim
555 | return flops
556 |
557 |
558 | class SwinTransformerSys(nn.Module):
559 | r""" Swin Transformer
560 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
561 | https://arxiv.org/pdf/2103.14030
562 |
563 | Args:
564 | img_size (int | tuple(int)): Input image size. Default 224
565 | patch_size (int | tuple(int)): Patch size. Default: 4
566 | in_chans (int): Number of input image channels. Default: 3
567 | num_classes (int): Number of classes for classification head. Default: 1000
568 | embed_dim (int): Patch embedding dimension. Default: 96
569 | depths (tuple(int)): Depth of each Swin Transformer layer.
570 | num_heads (tuple(int)): Number of attention heads in different layers.
571 | window_size (int): Window size. Default: 7
572 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
573 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
574 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
575 | drop_rate (float): Dropout rate. Default: 0
576 | attn_drop_rate (float): Attention dropout rate. Default: 0
577 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
578 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
579 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
580 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
581 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
582 | """
583 |
584 |     # FIXME: check whether embed_dim should be 96 or 512
585 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
586 | embed_dim=512, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
587 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
588 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
589 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
590 | use_checkpoint=False, final_upsample="expand_first", **kwargs):
591 | super().__init__()
592 |
593 | print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,
594 | depths_decoder,drop_path_rate,num_classes))
595 |
596 | self.num_classes = num_classes
597 | self.num_layers = len(depths)
598 | self.embed_dim = embed_dim
599 | self.ape = ape
600 | self.patch_norm = patch_norm
601 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
602 | self.num_features_up = int(embed_dim * 2)
603 | self.mlp_ratio = mlp_ratio
604 | self.final_upsample = final_upsample
605 |
606 | # split image into non-overlapping patches
607 | self.patch_embed = PatchEmbed(
608 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
609 | norm_layer=norm_layer if self.patch_norm else None)
610 | num_patches = self.patch_embed.num_patches
611 | patches_resolution = self.patch_embed.patches_resolution
612 | self.patches_resolution = patches_resolution
613 |
614 | # absolute position embedding
615 | if self.ape:
616 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
617 | trunc_normal_(self.absolute_pos_embed, std=.02)
618 |
619 | self.pos_drop = nn.Dropout(p=drop_rate)
620 |
621 | # stochastic depth
622 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
623 |
624 | # build encoder and bottleneck layers
625 | self.layers = nn.ModuleList()
626 | for i_layer in range(self.num_layers):
627 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
628 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
629 | patches_resolution[1] // (2 ** i_layer)),
630 | depth=depths[i_layer],
631 | num_heads=num_heads[i_layer],
632 | window_size=window_size,
633 | mlp_ratio=self.mlp_ratio,
634 | qkv_bias=qkv_bias, qk_scale=qk_scale,
635 | drop=drop_rate, attn_drop=attn_drop_rate,
636 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
637 | norm_layer=norm_layer,
638 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
639 | use_checkpoint=use_checkpoint)
640 | self.layers.append(layer)
641 |
642 | # build decoder layers
643 | self.layers_up = nn.ModuleList()
644 | self.concat_back_dim = nn.ModuleList()
645 | for i_layer in range(self.num_layers):
646 | concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),
647 | int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity()
648 | if i_layer ==0 :
649 | layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
650 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
651 | else:
652 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
653 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
654 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
655 | depth=depths[(self.num_layers-1-i_layer)],
656 | num_heads=num_heads[(self.num_layers-1-i_layer)],
657 | window_size=window_size,
658 | mlp_ratio=self.mlp_ratio,
659 | qkv_bias=qkv_bias, qk_scale=qk_scale,
660 | drop=drop_rate, attn_drop=attn_drop_rate,
661 | drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],
662 | norm_layer=norm_layer,
663 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
664 | use_checkpoint=use_checkpoint)
665 | self.layers_up.append(layer_up)
666 | self.concat_back_dim.append(concat_linear)
667 |
668 | self.norm = norm_layer(self.num_features)
669 | self.norm_up= norm_layer(self.embed_dim)
670 | # self.norm_up= norm_layer(192)
671 |
672 | if self.final_upsample == "expand_first":
673 | print("---final upsample expand_first---")
674 | self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim)
675 | # self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False)
676 | self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=3, stride=2, padding=1, bias=False)
677 |
678 | self.output_2 = nn.Conv2d(in_channels=self.num_classes,out_channels=self.num_classes,kernel_size=3, stride=2, padding=1, bias=False)
679 |
680 | self.apply(self._init_weights)
681 |
682 | def _init_weights(self, m):
683 | if isinstance(m, nn.Linear):
684 | trunc_normal_(m.weight, std=.02)
685 | if isinstance(m, nn.Linear) and m.bias is not None:
686 | nn.init.constant_(m.bias, 0)
687 | elif isinstance(m, nn.LayerNorm):
688 | nn.init.constant_(m.bias, 0)
689 | nn.init.constant_(m.weight, 1.0)
690 |
691 | @torch.jit.ignore
692 | def no_weight_decay(self):
693 | return {'absolute_pos_embed'}
694 |
695 | @torch.jit.ignore
696 | def no_weight_decay_keywords(self):
697 | return {'relative_position_bias_table'}
698 |
699 | #Encoder and Bottleneck
700 | def forward_features(self, x):
701 | x = self.patch_embed(x)
702 | if self.ape:
703 | x = x + self.absolute_pos_embed
704 | x = self.pos_drop(x)
705 | x_downsample = []
706 |
707 | for layer in self.layers:
708 | x_downsample.append(x)
709 | x = layer(x)
710 |
711 | x = self.norm(x) # B L C [1,49,768]
712 |
713 | return x, x_downsample
714 |
715 | # Decoder and skip connections
716 | ## FIXME: adjust the network structure so the pretrained parameters remain usable; skip some layers for real-time use
717 | def forward_up_features(self, x, x_downsample):
718 | for inx, layer_up in enumerate(self.layers_up):
719 | # print("swin decoder",self.layers_up)
720 | if inx == 0:
721 | x = layer_up(x)
722 | else:
723 | x = torch.cat([x,x_downsample[3-inx]],-1)
724 | x = self.concat_back_dim[inx](x)
725 | x = layer_up(x)
726 |
727 | x = self.norm_up(x) # B L C [1,3136,96]
728 |
729 | return x
730 |
731 | def up_x4(self, x):
732 | H, W = self.patches_resolution
733 |
734 | B, L, C = x.shape
735 | assert L == H*W, "input features have wrong size"
736 |
737 | if self.final_upsample=="expand_first":
738 | x = self.up(x)
739 | x = x.view(B,4*H,4*W,-1)
740 | x = x.permute(0,3,1,2) #B,C,H,W
741 | x = self.output(x) # 1/2
742 | x = self.output_2(x) # 1/4
743 | x = self.output_2(x) # 1/8
744 |
745 | ## FIXME without patch expanding
746 | # x = x.view(B, H, W, -1) # 1/4
747 | # x = x.permute(0,3,1,2) #B,C,H,W
748 | # x = self.output(x) # 1/8
749 |
750 |
751 | return x
752 |
753 | def forward(self, x):
754 | x, x_downsample = self.forward_features(x)
755 | x = self.forward_up_features(x,x_downsample)
756 | x = self.up_x4(x)
757 | # print("out shape", x.shape)
758 | # x = x.view(1,28,28,-1)
759 | # x = x.permute(0,3,2,1)
760 | # x = self.output(x)
761 | # print("output:", x.shape)
762 | # x = self.up_x4(x)
763 |
764 | return x
765 |
766 | def flops(self):
767 | flops = 0
768 | flops += self.patch_embed.flops()
769 | for i, layer in enumerate(self.layers):
770 | flops += layer.flops()
771 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
772 | flops += self.num_features * self.num_classes
773 | return flops
774 |
--------------------------------------------------------------------------------
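
A minimal usage sketch for the encoder-decoder above (not part of the repository; it assumes the repo root is on PYTHONPATH and mirrors the hyper-parameters that the SwinUnet wrapper in the next file passes in). With the two stride-2 output convolutions, the returned feature map is 1/8 of the input resolution.

import torch
from lib.backbone_model.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys

# Hypothetical settings matching the SwinUnet wrapper in the next file.
model = SwinTransformerSys(img_size=224, patch_size=4, in_chans=3,
                           num_classes=512, embed_dim=96,
                           depths=[2, 2, 2, 2], num_heads=[3, 6, 12, 24],
                           window_size=7)
x = torch.randn(1, 3, 224, 224)   # input size must equal img_size (PatchEmbed asserts this)
with torch.no_grad():
    features = model(x)
print(features.shape)             # torch.Size([1, 512, 28, 28]) -- 224 / 8
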
/lib/backbone_model/swin_unet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import copy
7 | import logging
8 | import math
9 |
10 | from os.path import join as pjoin
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 |
16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17 | from torch.nn.modules.utils import _pair
18 | from scipy import ndimage
19 | from .swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | class SwinUnet(nn.Module):
25 | def __init__(self, pretrained_path, img_size=224, num_classes=21843, zero_head=False, vis=False):
26 | super(SwinUnet, self).__init__()
27 | self.num_classes = num_classes
28 | self.zero_head = zero_head
29 |
30 | self.swin_unet = SwinTransformerSys(img_size=224,
31 | patch_size=4,
32 | in_chans=3,
33 | num_classes=512,
34 | embed_dim=96,
35 | depths=[2, 2, 2, 2],
36 | num_heads=[3, 6, 12, 24],
37 | window_size=7,
38 | mlp_ratio=4,
39 | qkv_bias=True,
40 | qk_scale=None,
41 | drop_rate=0,
42 | drop_path_rate=0.1,
43 | ape=True,
44 | patch_norm=True,
45 | use_checkpoint=True)
46 |
47 | def forward(self, x):
48 | if x.size()[1] == 1:
49 | x = x.repeat(1, 3, 1, 1)
50 | logits = self.swin_unet(x)
51 | return logits
52 |
53 | def load_from(self, pretrained_path):
54 | pretrained_path = pretrained_path
55 | if pretrained_path is not None:
56 | print("pretrained_path:{}".format(pretrained_path))
57 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58 | pretrained_dict = torch.load(pretrained_path, map_location=device)
59 | if "model" not in pretrained_dict:
60 | print("---start load pretrained model by splitting---")
61 | pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
62 | for k in list(pretrained_dict.keys()):
63 | if "output" in k:
64 | print("delete key:{}".format(k))
65 | del pretrained_dict[k]
66 | msg = self.swin_unet.load_state_dict(pretrained_dict, strict=False)
67 | # print(msg)
68 | return
69 | pretrained_dict = pretrained_dict['model']
70 | print("---start load pretrained model of swin encoder---")
71 |
72 | model_dict = self.swin_unet.state_dict()
73 | full_dict = copy.deepcopy(pretrained_dict)
74 | for k, v in pretrained_dict.items():
75 | if "layers." in k:
76 | current_layer_num = 3 - int(k[7:8])
77 | current_k = "layers_up." + str(current_layer_num) + k[8:]
78 | full_dict.update({current_k: v})
79 | for k in list(full_dict.keys()):
80 | if k in model_dict:
81 | if full_dict[k].shape != model_dict[k].shape:
82 | print("delete:{};shape pretrain:{};shape model:{}".format(k, full_dict[k].shape, model_dict[k].shape))
83 | del full_dict[k]
84 |
85 | msg = self.swin_unet.load_state_dict(full_dict, strict=False)
86 | # print(msg)
87 | else:
88 | print("none pretrain")
89 |
--------------------------------------------------------------------------------
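
A hedged sketch of exercising the SwinUnet wrapper above. Passing pretrained_path=None is an assumption for illustration: load_from() then prints "none pretrain" and the weights stay randomly initialised, and single-channel inputs are repeated to three channels inside forward().

import torch
from lib.backbone_model.swin_unet import SwinUnet

net = SwinUnet(pretrained_path=None, img_size=224)
net.load_from(None)                 # no checkpoint supplied: keeps random weights
gray = torch.randn(1, 1, 224, 224)  # 1-channel input is repeated to 3 channels in forward()
logits = net(gray)
print(logits.shape)                 # torch.Size([1, 512, 28, 28]) with num_classes=512
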
/lib/backbone_model/unet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Baseline U-Net: a plain baseline for testing the pipeline without any additional operations
3 | '''
4 |
5 | import torch.nn as nn
6 | import torch
7 |
8 | class conv_block_nested(nn.Module):
9 | def __init__(self, in_ch, mid_ch, out_ch):
10 | super(conv_block_nested, self).__init__()
11 | self.activation = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
13 | self.bn1 = nn.BatchNorm2d(mid_ch)
14 | self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
15 | self.bn2 = nn.BatchNorm2d(out_ch)
16 |
17 | def forward(self, x):
18 | x = self.conv1(x)
19 | identity = x
20 | x = self.bn1(x)
21 | x = self.activation(x)
22 |
23 | x = self.conv2(x)
24 | x = self.bn2(x)
25 | output = self.activation(x + identity)
26 | return output
27 |
28 |
29 | class up(nn.Module):
30 | def __init__(self, in_ch, bilinear=False):
31 | super(up, self).__init__()
32 |
33 | if bilinear:
34 | self.up = nn.Upsample(scale_factor=2,
35 | mode='bilinear',
36 | align_corners=True)
37 | else:
38 | self.up = nn.ConvTranspose2d(in_ch, in_ch, 2, stride=2)
39 |
40 | def forward(self, x):
41 | x = self.up(x)
42 | return x
43 |
44 | class ChannelAttention(nn.Module):
45 | def __init__(self, in_channels, ratio = 16):
46 | super(ChannelAttention, self).__init__()
47 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
48 | self.max_pool = nn.AdaptiveMaxPool2d(1)
49 | self.fc1 = nn.Conv2d(in_channels,in_channels//ratio,1,bias=False)
50 | self.relu1 = nn.ReLU()
51 | self.fc2 = nn.Conv2d(in_channels//ratio, in_channels,1,bias=False)
52 | self.sigmod = nn.Sigmoid()
53 | def forward(self,x):
54 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
55 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
56 | out = avg_out + max_out
57 | return self.sigmod(out)
58 |
59 |
60 | n1 = 32
61 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
62 | class Encoder(nn.Module):
63 | def __init__(self):
64 | super().__init__()
65 | in_ch = 3
66 | self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
67 | self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
68 | self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
69 | self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
70 | self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])
71 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
72 |
73 | def forward(self, input):
74 | x0_0 = self.conv0_0(input)
75 | x1_0 = self.conv1_0(self.pool(x0_0))
76 | x2_0 = self.conv2_0(self.pool(x1_0))
77 | x3_0 = self.conv3_0(self.pool(x2_0)) # [1, 256, 32, 32]
78 | x4_0 = self.conv4_0(self.pool(x3_0)) # [1, 512, 16, 16]
79 | # print("x30 shape:", x3_0.shape)
80 | return [x0_0, x1_0, x2_0, x3_0, x4_0]
81 |
82 |
83 | class Decoder(nn.Module):
84 | def __init__(self):
85 | super().__init__()
86 | out_ch = 2
87 | # self.conv3_1 = conv_block_nested(filters[3] * 2 + filters[4], filters[3], filters[3])
88 | self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[4], filters[4])
89 | self.conv2_2 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2])
90 | self.conv1_3 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1])
91 | self.conv0_4 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0])
92 | self.Up4_0 = up(filters[4])
93 | self.Up3_1 = up(filters[3])
94 | self.Up2_2 = up(filters[2])
95 | self.Up1_3 = up(filters[1])
96 | self.conv_final = nn.Conv2d(filters[0], out_ch, kernel_size=1)
97 |
98 | def forward(self, x_A):
99 | self.x0_0A, self.x1_0A, self.x2_0A, self.x3_0A, self.x4_0A = x_A
100 | # self.x0_0B, self.x1_0B, self.x2_0B, self.x3_0B, self.x4_0B = x_B
101 | x3_1 = self.conv3_1(torch.cat([self.x3_0A, self.Up4_0(self.x4_0A)], 1)) # [1,512,32,32]
102 | print("x31 shape:", x3_1.shape)
103 | # x2_2 = self.conv2_2(torch.cat([self.x2_0A, self.x2_0B, self.Up3_1(x3_1)], 1)) # [1, 128, 64, 64]
104 | # x1_3 = self.conv1_3(torch.cat([self.x1_0A, self.x1_0B, self.Up2_2(x2_2)], 1)) # [1, 64, 128, 128]
105 | # x0_4 = self.conv0_4(torch.cat([self.x0_0A, self.x0_0B, self.Up1_3(x1_3)], 1)) # [1, 64, 256, 128]
106 | # print("x04:", x0_4.shape)
107 | # out = self.conv_final(x0_4)
108 | return x3_1
109 |
110 |
111 | class UNet(nn.Module):
112 | def __init__(self):
113 | super(UNet, self).__init__()
114 | torch.nn.Module.dump_patches = True
115 |
116 | # Encoder returns multi-scale feature maps that feed the decoder
117 | self.encoder = Encoder()
118 | self.decoder = Decoder()
119 |
120 | # Initialise conv / norm weights. This must run after the submodules are
121 | # created, otherwise self.modules() contains no Conv2d/BatchNorm2d layers.
122 | for m in self.modules():
123 | if isinstance(m, nn.Conv2d):
124 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
125 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
126 | nn.init.constant_(m.weight, 1)
127 | nn.init.constant_(m.bias, 0)
128 |
129 | def forward(self, batch):
130 | encode = self.encoder(batch)
131 | output = self.decoder.forward(encode)
132 |
133 | return output
134 | # x0_B = self.encoder(xB)
135 | # out = self.decoder.forward(x0_A, x0_B)
136 | # return (out, ), [out]
137 | #
138 | # if __name__ == "__main__":
139 | # num_classes = [2,2]
140 | # current_task = 1
141 | # unet = UNet()
142 | # image_A = torch.randn((4, 3, 256, 256))
143 | # image_B = torch.randn((4, 3, 256, 256))
144 | # for name, m in unet.named_parameters():
145 | # print(name)
146 | # outputs_change, feature_map = unet(image_A, image_B, current_task)
147 |
--------------------------------------------------------------------------------
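
A small smoke test for the baseline U-Net above (illustrative only; the 256x256 input is an assumption, any size divisible by 16 works). Since Decoder.forward currently stops at x3_1, the output is a 512-channel map at 1/8 resolution.

import torch
from lib.backbone_model.unet import UNet

net = UNet()
x = torch.randn(1, 3, 256, 256)
feat = net(x)          # Decoder.forward also prints the x3_1 shape
print(feat.shape)      # torch.Size([1, 512, 32, 32])
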
/lib/dataset_terra.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | from PIL import Image
4 | import os
5 | import torch
6 | from torch.utils.data import Dataset
7 | import time
8 | from tqdm import tqdm
9 | from lib.utils import preprocess_image
10 |
11 | class TerraDataset(Dataset):
12 | def __init__(
13 | self,
14 | scene_list_path='terratrack_utils/train_scenes_500.txt',
15 | scene_info_path='/home/a409/users/huboni/Projects/dataset/TerraTrack/process_output_query_ref_500',
16 | base_path='/home/a409/users/huboni/Projects/dataset/TerraTrack',
17 | train=True,
18 | preprocessing=None,
19 | min_overlap_ratio=.5,
20 | max_overlap_ratio=1,
21 | max_scale_ratio=np.inf,
22 | pairs_per_scene=5,
23 | image_size=224
24 | ):
25 | self.scenes = []
26 | with open(scene_list_path, 'r') as f:
27 | lines = f.readlines()
28 | for line in lines:
29 | self.scenes.append(line.strip('\n'))
30 | print("scenes:", self.scenes)
31 |
32 | self.scene_info_path = scene_info_path
33 | self.base_path = base_path
34 |
35 | self.train = train
36 |
37 | self.preprocessing = preprocessing
38 |
39 | self.min_overlap_ratio = min_overlap_ratio
40 | self.max_overlap_ratio = max_overlap_ratio
41 | self.max_scale_ratio = max_scale_ratio
42 |
43 | self.pairs_per_scene = pairs_per_scene
44 |
45 | self.image_size = image_size
46 |
47 | self.dataset = []
48 |
49 | def build_dataset(self):
50 | self.dataset = []
51 | if not self.train:
52 | np_random_state = np.random.get_state()
53 | np.random.seed(42)
54 | print('Building the validation dataset...')
55 | else:
56 | print('Building a new training dataset...')
57 | for scene in tqdm(self.scenes, total=len(self.scenes)):
58 | scene_info_path = os.path.join(
59 | self.scene_info_path, '%s.npz' % scene
60 | )
61 | if not os.path.exists(scene_info_path):
62 | continue
63 | scene_info = np.load(scene_info_path, allow_pickle=True)
64 | overlap_matrix = scene_info['overlap_matrix']
65 | scale_ratio_matrix = scene_info['scale_ratio_matrix']
66 |
67 | valid = np.logical_and(
68 | np.logical_and(
69 | overlap_matrix >= self.min_overlap_ratio,
70 | overlap_matrix <= self.max_overlap_ratio
71 | ),
72 | scale_ratio_matrix <= self.max_scale_ratio
73 | )
74 | # Collect the matching pairs
75 | pairs = np.vstack(np.where(valid))
76 | # If the matching pairs in this scene ...
143 | image_path1 = os.path.join(
144 | self.base_path, pair_metadata['image_path1']
145 | )
146 | image1 = Image.open(image_path1)
147 | if image1.mode != 'RGB':
148 | image1 = image1.convert('RGB')
149 | image1 = np.array(image1)
150 | # make sure image and depth have the same width and height
151 | assert(image1.shape[0] == depth1.shape[0] and image1.shape[1] == depth1.shape[1])
152 | assert(image1.shape[0]>100)
153 | intrinsics1 = pair_metadata['intrinsics1']
154 | pose1 = pair_metadata['pose1']
155 |
156 | depth_path2 = os.path.join(
157 | self.base_path, pair_metadata['depth_path2']
158 | )
159 | with h5py.File(depth_path2, 'r') as hdf5_file:
160 | depth2 = np.array(hdf5_file['/depth'])
161 | assert(np.min(depth2) >= 0)
162 | image_path2 = os.path.join(
163 | self.base_path, pair_metadata['image_path2']
164 | )
165 | image2 = Image.open(image_path2)
166 | if image2.mode != 'RGB':
167 | image2 = image2.convert('RGB')
168 | image2 = np.array(image2)
169 | assert(image2.shape[0] == depth2.shape[0] and image2.shape[1] == depth2.shape[1])
170 | assert(image2.shape[0]>100)
171 | intrinsics2 = pair_metadata['intrinsics2']
172 | pose2 = pair_metadata['pose2']
173 |
174 | central_match = pair_metadata['central_match']
175 | # Use central_match to determine the region of each image to keep and compute the corresponding bounding boxes
176 | image1, bbox1, image2, bbox2 = self.crop(image1, image2, central_match)
177 |
178 | depth1 = depth1[
179 | bbox1[0] : bbox1[0] + self.image_size,
180 | bbox1[1] : bbox1[1] + self.image_size
181 | ]
182 | depth2 = depth2[
183 | bbox2[0] : bbox2[0] + self.image_size,
184 | bbox2[1] : bbox2[1] + self.image_size
185 | ]
186 |
187 | return (
188 | image1, depth1, intrinsics1, pose1, bbox1,
189 | image2, depth2, intrinsics2, pose2, bbox2
190 | )
191 |
192 | def crop(self, image1, image2, central_match):
193 | bbox1_i = max(int(central_match[0]) - self.image_size // 2, 0)
194 | if bbox1_i + self.image_size >= image1.shape[0]:
195 | bbox1_i = image1.shape[0] - self.image_size
196 | bbox1_j = max(int(central_match[1]) - self.image_size // 2, 0)
197 | if bbox1_j + self.image_size >= image1.shape[1]:
198 | bbox1_j = image1.shape[1] - self.image_size
199 |
200 | bbox2_i = max(int(central_match[2]) - self.image_size // 2, 0)
201 | if bbox2_i + self.image_size >= image2.shape[0]:
202 | bbox2_i = image2.shape[0] - self.image_size
203 | bbox2_j = max(int(central_match[3]) - self.image_size // 2, 0)
204 | if bbox2_j + self.image_size >= image2.shape[1]:
205 | bbox2_j = image2.shape[1] - self.image_size
206 |
207 | return (
208 | image1[
209 | bbox1_i : bbox1_i + self.image_size,
210 | bbox1_j : bbox1_j + self.image_size
211 | ],
212 | np.array([bbox1_i, bbox1_j]),
213 | image2[
214 | bbox2_i : bbox2_i + self.image_size,
215 | bbox2_j : bbox2_j + self.image_size
216 | ],
217 | np.array([bbox2_i, bbox2_j])
218 | )
219 |
220 | def __getitem__(self, idx):
221 | (
222 | image1, depth1, intrinsics1, pose1, bbox1,
223 | image2, depth2, intrinsics2, pose2, bbox2
224 | ) = self.recover_pair(self.dataset[idx])
225 |
226 | # if using the ViT model, skip preprocess_image
227 | image1 = preprocess_image(image1, preprocessing=self.preprocessing)
228 | image2 = preprocess_image(image2, preprocessing=self.preprocessing)
229 |
230 | return {
231 | 'image1': torch.from_numpy(image1.astype(np.float32)),
232 | 'depth1': torch.from_numpy(depth1.astype(np.float32)),
233 | 'intrinsics1': torch.from_numpy(intrinsics1.astype(np.float32)),
234 | 'pose1': torch.from_numpy(pose1.astype(np.float32)),
235 | 'bbox1': torch.from_numpy(bbox1.astype(np.float32)),
236 | 'image2': torch.from_numpy(image2.astype(np.float32)),
237 | 'depth2': torch.from_numpy(depth2.astype(np.float32)),
238 | 'intrinsics2': torch.from_numpy(intrinsics2.astype(np.float32)),
239 | 'pose2': torch.from_numpy(pose2.astype(np.float32)),
240 | 'bbox2': torch.from_numpy(bbox2.astype(np.float32))
241 | }
242 |
--------------------------------------------------------------------------------
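
A hedged sketch of how the dataset above is typically driven. The preprocessing='caffe' value follows the D2-Net convention of lib.utils.preprocess_image and is an assumption here, and the default scene_info_path/base_path must point at a local TerraTrack copy for build_dataset() to find any pairs.

from lib.dataset_terra import TerraDataset

dataset = TerraDataset(train=True, preprocessing='caffe',  # 'caffe' is assumed, see lib/utils.py
                       pairs_per_scene=5, image_size=224)
dataset.build_dataset()        # re-samples pairs; typically called once per epoch
sample = dataset[0]            # dict of tensors built by __getitem__
print(sample['image1'].shape)  # torch.Size([3, 224, 224]) after preprocess_image
print(sample['depth1'].shape)  # torch.Size([224, 224]) cropped depth map
print(sorted(sample.keys()))   # images, depths, intrinsics, poses and crop bboxes for both views
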
/lib/exceptions.py:
--------------------------------------------------------------------------------
1 | class EmptyTensorError(Exception):
2 | pass
3 |
4 |
5 | class NoGradientError(Exception):
6 | pass
7 |
--------------------------------------------------------------------------------
/lib/full_model/__pycache__/model_d2.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_d2.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/full_model/__pycache__/model_d2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_d2.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/full_model/__pycache__/model_swin_unet_d2.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_swin_unet_d2.cpython-310.pyc
--------------------------------------------------------------------------------
/lib/full_model/__pycache__/model_swin_unet_d2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_swin_unet_d2.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/full_model/__pycache__/model_unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_unet.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/full_model/model_erf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from transformers import ViTFeatureExtractor, ViTModel
5 | import torchvision.models as models
6 | class DenseFeatureExtractionModule(nn.Module):
7 | def __init__(self, finetune_feature_extraction=False, use_cuda=True):
8 | super(DenseFeatureExtractionModule, self).__init__()
9 | # VGG16
10 | model = models.vgg16()
11 | vgg16_layers = [
12 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2',
13 | 'pool1',
14 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2',
15 | 'pool2',
16 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3',
17 | 'pool3',
18 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
19 | 'pool4',
20 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
21 | 'pool5'
22 | ]
23 | conv4_3_idx = vgg16_layers.index('conv4_3')
24 |
25 | self.model = nn.Sequential(
26 | *list(model.features.children())[: conv4_3_idx + 1]
27 | )
28 | self.num_channels = 512
29 |
30 |
31 | # Freeze the feature-extraction parameters
32 | for param in self.model.parameters():
33 | param.requires_grad = False
34 | if finetune_feature_extraction:
35 | # Unlock conv4_3
36 | for param in list(self.model.parameters())[-2 :]:
37 | param.requires_grad = True
38 |
39 | if use_cuda:
40 | self.model = self.model.cuda()
41 |
42 | def forward(self, batch):
43 |
44 | # VGG
45 | output = self.model(batch)
46 | return output
47 |
48 |
49 | class SoftDetectionModule(nn.Module):
50 | def __init__(self, soft_local_max_size=3):
51 | super(SoftDetectionModule, self).__init__()
52 |
53 | self.soft_local_max_size = soft_local_max_size
54 |
55 | self.pad = self.soft_local_max_size // 2
56 |
57 | def forward(self, batch):
58 | b = batch.size(0)
59 |
60 | batch = F.relu(batch) # [2,512,28,28]
61 |
62 | max_per_sample = torch.max(batch.view(b, -1), dim=1)[0] # [1,2]
63 | exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1)) # [2,512,28,28]
64 |
65 | sum_exp = (
66 | self.soft_local_max_size ** 2 *
67 | F.avg_pool2d(
68 | F.pad(exp, [self.pad] * 4, mode='constant', value=1.),# [2,512,30,30]
69 | self.soft_local_max_size, stride=1
70 | ) # [2,512,28,28]
71 | ) # [2, 512,28,28]
72 | local_max_score = exp / sum_exp # alpha
73 |
74 | depth_wise_max = torch.max(batch, dim=1)[0] # [2,28,28]
75 |
76 | depth_wise_max_score = batch / depth_wise_max.unsqueeze(1) # beta [2, 512, 28, 28]
77 |
78 | all_scores = local_max_score * depth_wise_max_score # [2, 512, 28, 28]
79 | score = torch.max(all_scores, dim=1)[0] # r [2,28,28]
80 | score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1) # s [2,28,28]
81 |
82 | return score
83 |
84 |
85 | class D2Net(nn.Module):
86 | def __init__(self, model_file=None, use_cuda=True):
87 | super(D2Net, self).__init__()
88 |
89 | self.dense_feature_extraction = DenseFeatureExtractionModule(
90 | finetune_feature_extraction=True,
91 | use_cuda=use_cuda
92 | )
93 |
94 | self.detection = SoftDetectionModule()
95 |
96 | if model_file is not None:
97 | if use_cuda:
98 | self.load_state_dict(torch.load(model_file)['model'])
99 | else:
100 | self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
101 |
102 | def forward(self, batch):
103 | b = batch['image1'].size(0)
104 |
105 | dense_features = self.dense_feature_extraction(
106 | torch.cat([batch['image1'], batch['image2']], dim=0)
107 | )
108 |
109 | dense_features1 = dense_features[: b, :, :, :]
110 | dense_features2 = dense_features[b :, :, :, :]
111 | scores = self.detection(dense_features)
112 | scores1 = scores[: b, :, :]
113 | scores2 = scores[b :, :, :]
114 |
115 | return {
116 | 'dense_features1': dense_features1,
117 | 'scores1': scores1,
118 | 'dense_features2': dense_features2,
119 | 'scores2': scores2
120 | }
--------------------------------------------------------------------------------
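
A forward-pass sketch for the VGG-backed D2Net above. It assumes a CUDA device (DenseFeatureExtractionModule defaults to use_cuda=True), the transformers package installed since the module imports it, and random 224x224 tensors in place of a real image pair.

import torch
from lib.full_model.model_erf import D2Net

net = D2Net(model_file=None, use_cuda=True)
batch = {
    'image1': torch.randn(1, 3, 224, 224).cuda(),
    'image2': torch.randn(1, 3, 224, 224).cuda(),
}
out = net(batch)
print(out['dense_features1'].shape)  # torch.Size([1, 512, 28, 28]) -- conv4_3 features
print(out['scores1'].shape)          # torch.Size([1, 28, 28])      -- soft detection scores
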
/lib/full_model/model_swin_unet_d2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from transformers import ViTFeatureExtractor, ViTModel
5 | # import torchvision.models as models
6 | from ..backbone_model.swin_unet import SwinUnet
7 |
8 | # FIXME: replace the feature-extraction backbone with Swin-Unet; keep the score computation and match-pair selection parts unchanged
9 | class DenseFeatureExtractionModule(nn.Module):
10 | def __init__(self, finetune_feature_extraction=False, use_cuda=True):
11 | super(DenseFeatureExtractionModule, self).__init__()
12 |
13 | # TODO model:transformer
14 | ## VGG16
15 | # model = models.vgg16()
16 | # vgg16_layers = [
17 | # 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2',
18 | # 'pool1',
19 | # 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2',
20 | # 'pool2',
21 | # 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3',
22 | # 'pool3',
23 | # 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
24 | # 'pool4',
25 | # 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
26 | # 'pool5'
27 | # ]
28 | # conv4_3_idx = vgg16_layers.index('conv4_3')
29 | #
30 | # self.model = nn.Sequential(
31 | # *list(model.features.children())[: conv4_3_idx + 1]
32 | # )
33 | # self.num_channels = 512
34 |
35 | ## TODO Swin Unet
36 | trans_pretrained_path = '/home/a409/users/huboni/Projects/code/d2-net/models/swin_tiny_patch4_window7_224.pth'
37 | trans_pretrained_path = None  # overrides the path above: no pretrained checkpoint is loaded here
38 | self.model = SwinUnet(trans_pretrained_path).cuda()
39 | self.model.load_from(trans_pretrained_path)
40 |
41 |
42 | # # Fix forward parameters
43 | # for param in self.model.parameters():
44 | # param.requires_grad = False
45 | # if finetune_feature_extraction:
46 | # # Unlock conv4_3
47 | # for param in list(self.model.parameters())[-2 :]:
48 | # param.requires_grad = True
49 |
50 | if use_cuda:
51 | self.model = self.model.cuda()
52 |
53 | def forward(self, batch):
54 |
55 | output = self.model(batch)
56 |
57 | return output
58 |
59 |
60 | class SoftDetectionModule(nn.Module):
61 | def __init__(self, soft_local_max_size=3):
62 | super(SoftDetectionModule, self).__init__()
63 |
64 | self.soft_local_max_size = soft_local_max_size
65 |
66 | self.pad = self.soft_local_max_size // 2
67 |
68 | def forward(self, batch):
69 | b = batch.size(0)
70 | batch = F.relu(batch)
71 | max_per_sample = torch.max(batch.reshape(b, -1), dim=1)[0]
72 | # print("max pre sample:", max_per_sample)
73 | exp = torch.exp(batch / max_per_sample.reshape(b, 1, 1, 1))
74 |
75 | sum_exp = (
76 | self.soft_local_max_size ** 2 *
77 | F.avg_pool2d(
78 | F.pad(exp, [self.pad] * 4, mode='constant', value=1.),
79 | self.soft_local_max_size, stride=1
80 | )
81 | )
82 | local_max_score = exp / sum_exp
83 |
84 | depth_wise_max = torch.max(batch, dim=1)[0]
85 | depth_wise_max_score = batch / depth_wise_max.unsqueeze(1)
86 |
87 | all_scores = local_max_score * depth_wise_max_score
88 | score = torch.max(all_scores, dim=1)[0]
89 |
90 | score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1)
91 |
92 |
93 | return score
94 |
95 |
96 | class Swin_D2UNet(nn.Module):
97 | def __init__(self, use_cuda=True):
98 | super(Swin_D2UNet, self).__init__()
99 |
100 | self.dense_feature_extraction = DenseFeatureExtractionModule(
101 | finetune_feature_extraction=True,
102 | use_cuda=use_cuda
103 | )
104 |
105 | self.detection = SoftDetectionModule()
106 |
107 | # if model_file is not None:
108 | # if use_cuda:
109 | # self.load_state_dict(torch.load(model_file)['model'])
110 | # else:
111 | # self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
112 |
113 | def forward(self, batch):
114 | b = batch['image1'].size(0)
115 |
116 | # dense_features = self.dense_feature_extraction(
117 | # torch.cat([batch['image1'], batch['image2']], dim=0)
118 | # )
119 | # scores = self.detection(dense_features)
120 | # dense_features1 = dense_features[: b, :, :, :]
121 | # dense_features2 = dense_features[b :, :, :, :]
122 | # scores1 = scores[: b, :, :]
123 | # scores2 = scores[b :, :, :]
124 |
125 | dense_features1 = self.dense_feature_extraction(batch['image1']) # [1,1000,224,224]
126 | # print("dense feature1:", dense_features1)
127 | # print("feature max:", torch.max(dense_features1))
128 | # print("feature min:", torch.min(dense_features1))
129 | dense_features2 = self.dense_feature_extraction(batch['image2'])
130 |
131 | # scores = self.detection(torch.cat([dense_features1, dense_features2], dim=0))
132 | # scores2 = scores[b :, :, :]
133 | # scores1 = scores[: b, :, :]
134 | scores1 = self.detection(dense_features1)
135 | scores2 = self.detection(dense_features2)
136 |
137 | return {
138 | 'dense_features1': dense_features1,
139 | 'scores1': scores1,
140 | 'dense_features2': dense_features2,
141 | 'scores2': scores2
142 | }
--------------------------------------------------------------------------------
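
The analogous sketch for the Swin-Unet-backed detector above. CUDA is required because the backbone is moved to the GPU in its constructor, and with trans_pretrained_path overridden to None the backbone starts from random weights.

import torch
from lib.full_model.model_swin_unet_d2 import Swin_D2UNet

net = Swin_D2UNet(use_cuda=True)
batch = {
    'image1': torch.randn(1, 3, 224, 224).cuda(),
    'image2': torch.randn(1, 3, 224, 224).cuda(),
}
out = net(batch)
print(out['dense_features1'].shape)  # torch.Size([1, 512, 28, 28]) -- Swin-Unet dense descriptors
print(out['scores1'].shape)          # torch.Size([1, 28, 28])      -- per-image detection scores
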
/lib/full_model/model_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from transformers import ViTFeatureExtractor, ViTModel
5 | import torchvision.models as models
6 | from ..backbone_model.unet import UNet
7 | class DenseFeatureExtractionModule(nn.Module):
8 | def __init__(self, finetune_feature_extraction=False, use_cuda=True):
9 | super(DenseFeatureExtractionModule, self).__init__()
10 | # VGG16
11 | # model = models.vgg16()
12 | # vgg16_layers = [
13 | # 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2',
14 | # 'pool1',
15 | # 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2',
16 | # 'pool2',
17 | # 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3',
18 | # 'pool3',
19 | # 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
20 | # 'pool4',
21 | # 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
22 | # 'pool5'
23 | # ]
24 | # conv4_3_idx = vgg16_layers.index('conv4_3')
25 | #
26 | # self.model = nn.Sequential(
27 | # *list(model.features.children())[: conv4_3_idx + 1]
28 | # )
29 | # self.num_channels = 512
30 |
31 | self.model = UNet()
32 | print("model list", self.model)
33 |
34 |
35 | # Fix forward parameters
36 | # for param in self.model.parameters():
37 | # param.requires_grad = False
38 | # if finetune_feature_extraction:
39 | # # Unlock conv4_3
40 | # for param in list(self.model.parameters())[-2 :]:
41 | # param.requires_grad = True
42 |
43 | if use_cuda:
44 | self.model = self.model.cuda()
45 |
46 | def forward(self, batch):
47 |
48 | # VGG
49 | output = self.model(batch)
50 | return output
51 |
52 |
53 | class SoftDetectionModule(nn.Module):
54 | def __init__(self, soft_local_max_size=3):
55 | super(SoftDetectionModule, self).__init__()
56 |
57 | self.soft_local_max_size = soft_local_max_size
58 |
59 | self.pad = self.soft_local_max_size // 2
60 |
61 | def forward(self, batch):
62 | b = batch.size(0)
63 |
64 | batch = F.relu(batch) # [2,512,28,28]
65 |
66 | max_per_sample = torch.max(batch.view(b, -1), dim=1)[0] # [1,2]
67 | exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1)) # [2,512,28,28]
68 |
69 | sum_exp = (
70 | self.soft_local_max_size ** 2 *
71 | F.avg_pool2d(
72 | F.pad(exp, [self.pad] * 4, mode='constant', value=1.),# [2,512,30,30]
73 | self.soft_local_max_size, stride=1
74 | ) # [2,512,28,28]
75 | ) # [2, 512,28,28]
76 | local_max_score = exp / sum_exp # alpha
77 |
78 | depth_wise_max = torch.max(batch, dim=1)[0] # [2,28,28]
79 |
80 | depth_wise_max_score = batch / depth_wise_max.unsqueeze(1) # beta [2, 512, 28, 28]
81 |
82 | all_scores = local_max_score * depth_wise_max_score # [2, 512, 28, 28]
83 | score = torch.max(all_scores, dim=1)[0] # r [2,28,28]
84 | score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1) # s [2,28,28]
85 |
86 | return score
87 |
88 |
89 | class U2Net(nn.Module):
90 | def __init__(self, model_file=None, use_cuda=True):
91 | super(U2Net, self).__init__()
92 |
93 | self.dense_feature_extraction = DenseFeatureExtractionModule(
94 | finetune_feature_extraction=True,
95 | use_cuda=use_cuda
96 | )
97 |
98 | self.detection = SoftDetectionModule()
99 |
100 | # if model_file is not None:
101 | # if use_cuda:
102 | # self.load_state_dict(torch.load(model_file)['model'])
103 | # else:
104 | # self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
105 |
106 | def forward(self, batch):
107 | b = batch['image1'].size(0)
108 |
109 | # print("image1 origin", batch['image1'])
110 | dense_features1 = self.dense_feature_extraction(batch['image1'])
111 | dense_features2 = self.dense_feature_extraction(batch['image2'])
112 | print("dense_features1 max:", dense_features1.max())
113 | print("dense_features1 min:", dense_features1.min())
114 |
115 | scores = self.detection(torch.cat([dense_features1, dense_features2], dim=0))
116 |
117 | # dense_features = self.dense_feature_extraction(
118 | # torch.cat([batch['image1'], batch['image2']], dim=0)
119 | # )
120 | # dense_features1 = dense_features[: b, :, :, :]
121 | # dense_features2 = dense_features[b :, :, :, :]
122 | # scores = self.detection(dense_features)
123 |
124 | scores1 = scores[: b, :, :]
125 | scores2 = scores[b :, :, :]
126 |
127 |
128 | return {
129 | 'dense_features1': dense_features1,
130 | 'scores1': scores1,
131 | 'dense_features2': dense_features2,
132 | 'scores2': scores2
133 | }
--------------------------------------------------------------------------------
/lib/full_model/model_vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from transformers import ViTFeatureExtractor, ViTModel
5 | import torchvision.models as models
6 | class DenseFeatureExtractionModule(nn.Module):
7 | def __init__(self, finetune_feature_extraction=False, use_cuda=True):
8 | super(DenseFeatureExtractionModule, self).__init__()
9 |
10 | # TODO model:transformer
11 | feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
12 | model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
13 |
14 | self.model = model
15 | self.feature_extractor = feature_extractor
16 | ## VGG16
17 | # model = models.vgg16()
18 | # vgg16_layers = [
19 | # 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2',
20 | # 'pool1',
21 | # 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2',
22 | # 'pool2',
23 | # 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3',
24 | # 'pool3',
25 | # 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
26 | # 'pool4',
27 | # 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
28 | # 'pool5'
29 | # ]
30 | # conv4_3_idx = vgg16_layers.index('conv4_3')
31 | #
32 | # self.model = nn.Sequential(
33 | # *list(model.features.children())[: conv4_3_idx + 1]
34 | # )
35 | # self.num_channels = 512
36 | #
37 | #
38 | # # Fix forward parameters
39 | # for param in self.model.parameters():
40 | # param.requires_grad = False
41 | # if finetune_feature_extraction:
42 | # # Unlock conv4_3
43 | # for param in list(self.model.parameters())[-2 :]:
44 | # param.requires_grad = True
45 |
46 | if use_cuda:
47 | self.model = self.model.cuda()
48 |
49 | def forward(self, batch):
50 |
51 | # VGG
52 | # output = self.model(batch)
53 | # transformer
54 | batch = self.feature_extractor(batch, return_tensors='pt')
55 | batch.to(device="cuda")
56 | output = self.model(**batch).last_hidden_state
57 | return output
58 |
59 |
60 | class SoftDetectionModule(nn.Module):
61 | def __init__(self, soft_local_max_size=3):
62 | super(SoftDetectionModule, self).__init__()
63 |
64 | self.soft_local_max_size = soft_local_max_size
65 |
66 | self.pad = self.soft_local_max_size // 2
67 |
68 | def forward(self, batch):
69 | b = batch.size(0)
70 |
71 | batch = F.relu(batch) # [2,512,28,28]
72 |
73 | max_per_sample = torch.max(batch.view(b, -1), dim=1)[0] # [1,2]
74 | exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1)) # [2,512,28,28]
75 |
76 | sum_exp = (
77 | self.soft_local_max_size ** 2 *
78 | F.avg_pool2d(
79 | F.pad(exp, [self.pad] * 4, mode='constant', value=1.),# [2,512,30,30]
80 | self.soft_local_max_size, stride=1
81 | ) # [2,512,28,28]
82 | ) # [2, 512,28,28]
83 | local_max_score = exp / sum_exp # alpha
84 |
85 | depth_wise_max = torch.max(batch, dim=1)[0] # [2,28,28]
86 |
87 | depth_wise_max_score = batch / depth_wise_max.unsqueeze(1) # beta [2, 512, 28, 28]
88 |
89 | all_scores = local_max_score * depth_wise_max_score # [2, 512, 28, 28]
90 | score = torch.max(all_scores, dim=1)[0] # r [2,28,28]
91 | score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1) # s [2,28,28]
92 |
93 | return score
94 |
95 |
96 | class D2Net(nn.Module):
97 | def __init__(self, model_file=None, use_cuda=True):
98 | super(D2Net, self).__init__()
99 |
100 | self.dense_feature_extraction = DenseFeatureExtractionModule(
101 | finetune_feature_extraction=True,
102 | use_cuda=use_cuda
103 | )
104 |
105 | self.detection = SoftDetectionModule()
106 | self.conv_out = nn.Conv2d(in_channels=768, out_channels=512, kernel_size=3, stride=2, padding=1)
107 |
108 |
109 | # if model_file is not None:
110 | # if use_cuda:
111 | # self.load_state_dict(torch.load(model_file)['model'])
112 | # else:
113 | # self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
114 |
115 | def forward(self, batch):
116 | b = batch['image1'].size(0)
117 |
118 | # dense_features = self.dense_feature_extraction(
119 | # torch.cat([batch['image1'], batch['image2']], dim=0)
120 | # )
121 | dense_features1 = self.dense_feature_extraction(batch['image1'])[:,1:, :]
122 | dense_features2 = self.dense_feature_extraction(batch['image2'])[:,1:, :]
123 | dense_features1 = dense_features1.view(1, 14, 14, 768)
124 | dense_features2 = dense_features2.view(1, 14, 14, 768)
125 | # dense_features1 = dense_features1.permute(0,3,1,2)
126 | # dense_features2 = dense_features2.permute(0,3,1,2)
127 | print("dense_features1 max 1", dense_features1.max())
128 | print("dense_features1 min 1", dense_features1.min())
129 |
130 | # dense_features1 = self.conv_out(dense_features1)
131 | # print("dense_features1 shape conv", dense_features1.shape)
132 | # dense_features2 = self.conv_out(dense_features2)
133 |
134 |
135 | scores = self.detection(torch.cat([dense_features1, dense_features2], dim=0))
136 | print("scores shape:", scores.shape)
137 | # scores1 = self.detection(dense_features1)
138 | # scores2 = self.detection(dense_features2)
139 |
140 |
141 | # dense_features1 = dense_features[: b, :, :, :]
142 | # print("dense features:", dense_features1)
143 | print("feature max:", torch.max(dense_features1))
144 | print("feature min:", torch.min(dense_features1))
145 |
146 | # print("dense feaures 1 shape:", dense_features1.shape) [1,512,28,28]
147 |
148 | # dense_features2 = dense_features[b :, :, :, :]
149 | scores1 = scores[: b, :, :]
150 | scores2 = scores[b :, :, :]
151 |
152 | return {
153 | 'dense_features1': dense_features1,
154 | 'scores1': scores1,
155 | 'dense_features2': dense_features2,
156 | 'scores2': scores2
157 | }
158 |
--------------------------------------------------------------------------------
/lib/model_swin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .swin_unet import SwinUnet
5 |
6 | class DenseFeatureExtractionModule(nn.Module):
7 | def __init__(self, use_relu=True, use_cuda=True):
8 | super(DenseFeatureExtractionModule, self).__init__()
9 | self.model = SwinUnet(pretrained_path=None).cuda()  # SwinUnet requires a pretrained_path argument; None keeps random weights
10 |
11 | if use_cuda:
12 | self.model = self.model.cuda()
13 |
14 | self.num_channels = 512
15 |
16 | self.use_relu = use_relu
17 |
18 | def forward(self, batch):
19 | # img_size = batch.size()[-1]
20 | # print("img size:", img_size)
21 | output = self.model(batch)
22 | if self.use_relu:
23 | output = F.relu(output)
24 | return output
25 |
26 |
27 | class SwinU2Net(nn.Module):
28 | def __init__(self, model_file=None, use_relu=True, use_cuda=True):
29 | super(SwinU2Net, self).__init__()
30 |
31 | self.dense_feature_extraction = DenseFeatureExtractionModule(
32 | use_relu=use_relu, use_cuda=use_cuda
33 | )
34 |
35 | self.detection = HardDetectionModule()
36 |
37 | self.localization = HandcraftedLocalizationModule()
38 |
39 | if model_file is not None:
40 | # self.load_state_dict(torch.load(model_file)['model'])
41 | self.load_state_dict(torch.load(model_file)['model'])
42 | def forward(self, batch):
43 | _, _, h, w = batch.size()
44 | dense_features = self.dense_feature_extraction(batch)
45 |
46 | detections = self.detection(dense_features)
47 |
48 | displacements = self.localization(dense_features)
49 |
50 | return {
51 | 'dense_features': dense_features,
52 | 'detections': detections,
53 | 'displacements': displacements
54 | }
55 |
56 |
57 | class HardDetectionModule(nn.Module):
58 | def __init__(self, edge_threshold=5):
59 | super(HardDetectionModule, self).__init__()
60 |
61 | self.edge_threshold = edge_threshold
62 |
63 | self.dii_filter = torch.tensor(
64 | [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
65 | ).view(1, 1, 3, 3)
66 | self.dij_filter = 0.25 * torch.tensor(
67 | [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
68 | ).view(1, 1, 3, 3)
69 | self.djj_filter = torch.tensor(
70 | [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
71 | ).view(1, 1, 3, 3)
72 |
73 | def forward(self, batch):
74 | b, c, h, w = batch.size()
75 | device = batch.device
76 |
77 | depth_wise_max = torch.max(batch, dim=1)[0]
78 | is_depth_wise_max = (batch == depth_wise_max)
79 | del depth_wise_max
80 |
81 | local_max = F.max_pool2d(batch, 3, stride=1, padding=1)
82 | is_local_max = (batch == local_max)
83 | del local_max
84 |
85 | dii = F.conv2d(
86 | batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
87 | ).view(b, c, h, w)
88 | dij = F.conv2d(
89 | batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
90 | ).view(b, c, h, w)
91 | djj = F.conv2d(
92 | batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
93 | ).view(b, c, h, w)
94 |
95 | det = dii * djj - dij * dij
96 | tr = dii + djj
97 | del dii, dij, djj
98 |
99 | threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold
100 | is_not_edge = torch.min(tr * tr / det <= threshold, det > 0)
101 |
102 | detected = torch.min(
103 | is_depth_wise_max,
104 | torch.min(is_local_max, is_not_edge)
105 | )
106 | del is_depth_wise_max, is_local_max, is_not_edge
107 |
108 | return detected
109 |
110 |
111 | class HandcraftedLocalizationModule(nn.Module):
112 | def __init__(self):
113 | super(HandcraftedLocalizationModule, self).__init__()
114 |
115 | self.di_filter = torch.tensor(
116 | [[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]]
117 | ).view(1, 1, 3, 3)
118 | self.dj_filter = torch.tensor(
119 | [[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]]
120 | ).view(1, 1, 3, 3)
121 |
122 | self.dii_filter = torch.tensor(
123 | [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
124 | ).view(1, 1, 3, 3)
125 | self.dij_filter = 0.25 * torch.tensor(
126 | [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
127 | ).view(1, 1, 3, 3)
128 | self.djj_filter = torch.tensor(
129 | [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
130 | ).view(1, 1, 3, 3)
131 |
132 | def forward(self, batch):
133 | b, c, h, w = batch.size()
134 | device = batch.device
135 |
136 | dii = F.conv2d(
137 | batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
138 | ).view(b, c, h, w)
139 | dij = F.conv2d(
140 | batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
141 | ).view(b, c, h, w)
142 | djj = F.conv2d(
143 | batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
144 | ).view(b, c, h, w)
145 | det = dii * djj - dij * dij
146 |
147 | inv_hess_00 = djj / det
148 | inv_hess_01 = -dij / det
149 | inv_hess_11 = dii / det
150 | del dii, dij, djj, det
151 |
152 | di = F.conv2d(
153 | batch.view(-1, 1, h, w), self.di_filter.to(device), padding=1
154 | ).view(b, c, h, w)
155 | dj = F.conv2d(
156 | batch.view(-1, 1, h, w), self.dj_filter.to(device), padding=1
157 | ).view(b, c, h, w)
158 |
159 | step_i = -(inv_hess_00 * di + inv_hess_01 * dj)
160 | step_j = -(inv_hess_01 * di + inv_hess_11 * dj)
161 | del inv_hess_00, inv_hess_01, inv_hess_11, di, dj
162 |
163 | return torch.stack([step_i, step_j], dim=1)
164 |
--------------------------------------------------------------------------------
/lib/model_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class DenseFeatureExtractionModule(nn.Module):
7 | def __init__(self, use_relu=True, use_cuda=True):
8 | super(DenseFeatureExtractionModule, self).__init__()
9 |
10 | self.model = nn.Sequential(
11 | nn.Conv2d(3, 64, 3, padding=1),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(64, 64, 3, padding=1),
14 | nn.ReLU(inplace=True),
15 | nn.MaxPool2d(2, stride=2),
16 | nn.Conv2d(64, 128, 3, padding=1),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(128, 128, 3, padding=1),
19 | nn.ReLU(inplace=True),
20 | nn.MaxPool2d(2, stride=2),
21 | nn.Conv2d(128, 256, 3, padding=1),
22 | nn.ReLU(inplace=True),
23 | nn.Conv2d(256, 256, 3, padding=1),
24 | nn.ReLU(inplace=True),
25 | nn.Conv2d(256, 256, 3, padding=1),
26 | nn.ReLU(inplace=True),
27 | nn.AvgPool2d(2, stride=1),
28 | nn.Conv2d(256, 512, 3, padding=2, dilation=2),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(512, 512, 3, padding=2, dilation=2),
31 | nn.ReLU(inplace=True),
32 | nn.Conv2d(512, 512, 3, padding=2, dilation=2),
33 | )
34 | self.num_channels = 512
35 |
36 | self.use_relu = use_relu
37 |
38 | if use_cuda:
39 | self.model = self.model.cuda()
40 |
41 | def forward(self, batch):
42 | output = self.model(batch)
43 | if self.use_relu:
44 | output = F.relu(output)
45 | return output
46 |
47 |
48 | class D2Net(nn.Module):
49 | def __init__(self, model_file=None, use_relu=True, use_cuda=True):
50 | super(D2Net, self).__init__()
51 |
52 | self.dense_feature_extraction = DenseFeatureExtractionModule(
53 | use_relu=use_relu, use_cuda=use_cuda
54 | )
55 |
56 | self.detection = HardDetectionModule()
57 |
58 | self.localization = HandcraftedLocalizationModule()
59 |
60 | if model_file is not None:
61 | if use_cuda:
62 | self.load_state_dict(torch.load(model_file)['model'])
63 | else:
64 | self.load_state_dict(torch.load(model_file, map_location='cpu')['model'])
65 |
66 | def forward(self, batch):
67 | _, _, h, w = batch.size()
68 | dense_features = self.dense_feature_extraction(batch)
69 |
70 | detections = self.detection(dense_features)
71 |
72 | displacements = self.localization(dense_features)
73 |
74 | return {
75 | 'dense_features': dense_features,
76 | 'detections': detections,
77 | 'displacements': displacements
78 | }
79 |
80 |
81 | class HardDetectionModule(nn.Module):
82 | def __init__(self, edge_threshold=5):
83 | super(HardDetectionModule, self).__init__()
84 |
85 | self.edge_threshold = edge_threshold
86 |
87 | self.dii_filter = torch.tensor(
88 | [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
89 | ).view(1, 1, 3, 3)
90 | self.dij_filter = 0.25 * torch.tensor(
91 | [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
92 | ).view(1, 1, 3, 3)
93 | self.djj_filter = torch.tensor(
94 | [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
95 | ).view(1, 1, 3, 3)
96 |
97 | def forward(self, batch):
98 | b, c, h, w = batch.size()
99 | device = batch.device
100 |
101 | depth_wise_max = torch.max(batch, dim=1)[0]
102 | is_depth_wise_max = (batch == depth_wise_max)
103 | del depth_wise_max
104 |
105 | local_max = F.max_pool2d(batch, 3, stride=1, padding=1)
106 | is_local_max = (batch == local_max)
107 | del local_max
108 |
109 | dii = F.conv2d(
110 | batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
111 | ).view(b, c, h, w)
112 | dij = F.conv2d(
113 | batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
114 | ).view(b, c, h, w)
115 | djj = F.conv2d(
116 | batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
117 | ).view(b, c, h, w)
118 |
119 | det = dii * djj - dij * dij
120 | tr = dii + djj
121 | del dii, dij, djj
122 |
123 | threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold
124 | is_not_edge = torch.min(tr * tr / det <= threshold, det > 0)
125 |
126 | detected = torch.min(
127 | is_depth_wise_max,
128 | torch.min(is_local_max, is_not_edge)
129 | )
130 | del is_depth_wise_max, is_local_max, is_not_edge
131 |
132 | return detected
133 |
134 |
135 | class HandcraftedLocalizationModule(nn.Module):
136 | def __init__(self):
137 | super(HandcraftedLocalizationModule, self).__init__()
138 |
139 | self.di_filter = torch.tensor(
140 | [[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]]
141 | ).view(1, 1, 3, 3)
142 | self.dj_filter = torch.tensor(
143 | [[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]]
144 | ).view(1, 1, 3, 3)
145 |
146 | self.dii_filter = torch.tensor(
147 | [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
148 | ).view(1, 1, 3, 3)
149 | self.dij_filter = 0.25 * torch.tensor(
150 | [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
151 | ).view(1, 1, 3, 3)
152 | self.djj_filter = torch.tensor(
153 | [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
154 | ).view(1, 1, 3, 3)
155 |
156 | def forward(self, batch):
157 | b, c, h, w = batch.size()
158 | device = batch.device
159 |
160 | dii = F.conv2d(
161 | batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1
162 | ).view(b, c, h, w)
163 | dij = F.conv2d(
164 | batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1
165 | ).view(b, c, h, w)
166 | djj = F.conv2d(
167 | batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1
168 | ).view(b, c, h, w)
169 | det = dii * djj - dij * dij
170 |
171 | inv_hess_00 = djj / det
172 | inv_hess_01 = -dij / det
173 | inv_hess_11 = dii / det
174 | del dii, dij, djj, det
175 |
176 | di = F.conv2d(
177 | batch.view(-1, 1, h, w), self.di_filter.to(device), padding=1
178 | ).view(b, c, h, w)
179 | dj = F.conv2d(
180 | batch.view(-1, 1, h, w), self.dj_filter.to(device), padding=1
181 | ).view(b, c, h, w)
182 |
183 | step_i = -(inv_hess_00 * di + inv_hess_01 * dj)
184 | step_j = -(inv_hess_01 * di + inv_hess_11 * dj)
185 | del inv_hess_00, inv_hess_01, inv_hess_11, di, dj
186 |
187 | return torch.stack([step_i, step_j], dim=1)
188 |
--------------------------------------------------------------------------------
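
A CPU sketch of the test-time model above with untrained weights, showing the three outputs produced per image: dense descriptors, the hard detection mask, and the handcrafted sub-pixel displacements. The 55x55 spatial size follows from the two max-pools plus the stride-1 average pool in the feature extractor.

import torch
from lib.model_test import D2Net

net = D2Net(model_file=None, use_relu=True, use_cuda=False)
image = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = net(image)
print(out['dense_features'].shape)  # torch.Size([1, 512, 55, 55]) dense descriptors
print(out['detections'].shape)      # torch.Size([1, 512, 55, 55]) boolean keypoint mask
print(out['displacements'].shape)   # torch.Size([1, 2, 512, 55, 55]) sub-pixel (i, j) offsets
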
/lib/swin_transformer_unet_skip_expand_decoder_sys.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | from einops import rearrange
5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 |
7 |
8 | class Mlp(nn.Module):
9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
10 | super().__init__()
11 | out_features = out_features or in_features
12 | hidden_features = hidden_features or in_features
13 | self.fc1 = nn.Linear(in_features, hidden_features)
14 | self.act = act_layer()
15 | self.fc2 = nn.Linear(hidden_features, out_features)
16 | self.drop = nn.Dropout(drop)
17 |
18 | def forward(self, x):
19 | x = self.fc1(x)
20 | x = self.act(x)
21 | x = self.drop(x)
22 | x = self.fc2(x)
23 | x = self.drop(x)
24 | return x
25 |
26 |
27 | def window_partition(x, window_size):
28 | """
29 | Args:
30 | x: (B, H, W, C)
31 | window_size (int): window size
32 |
33 | Returns:
34 | windows: (num_windows*B, window_size, window_size, C)
35 | """
36 | B, H, W, C = x.shape
37 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
38 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
39 | return windows
40 |
41 |
42 | def window_reverse(windows, window_size, H, W):
43 | """
44 | Args:
45 | windows: (num_windows*B, window_size, window_size, C)
46 | window_size (int): Window size
47 | H (int): Height of image
48 | W (int): Width of image
49 |
50 | Returns:
51 | x: (B, H, W, C)
52 | """
53 | B = int(windows.shape[0] / (H * W / window_size / window_size))
54 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
55 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
56 | return x
57 |
58 |
59 | class WindowAttention(nn.Module):
60 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
61 | It supports both of shifted and non-shifted window.
62 |
63 | Args:
64 | dim (int): Number of input channels.
65 | window_size (tuple[int]): The height and width of the window.
66 | num_heads (int): Number of attention heads.
67 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
68 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
69 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
70 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
71 | """
72 |
73 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
74 |
75 | super().__init__()
76 | self.dim = dim
77 | self.window_size = window_size # Wh, Ww
78 | self.num_heads = num_heads
79 | head_dim = dim // num_heads
80 | self.scale = qk_scale or head_dim ** -0.5
81 |
82 | # define a parameter table of relative position bias
83 | self.relative_position_bias_table = nn.Parameter(
84 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
85 |
86 | # get pair-wise relative position index for each token inside the window
87 | coords_h = torch.arange(self.window_size[0])
88 | coords_w = torch.arange(self.window_size[1])
89 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
90 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
91 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
92 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
93 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
94 | relative_coords[:, :, 1] += self.window_size[1] - 1
95 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
96 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
97 | self.register_buffer("relative_position_index", relative_position_index)
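# How this index works: for a (Wh, Ww) window the per-axis offsets lie in
# [-(Wh-1), Wh-1] and [-(Ww-1), Ww-1]; shifting by (Wh-1, Ww-1) and multiplying
# the row offset by (2*Ww-1) before summing maps every offset pair to a unique
# index in [0, (2*Wh-1)*(2*Ww-1)), which selects a row of the bias table in forward().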
98 |
99 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
100 | self.attn_drop = nn.Dropout(attn_drop)
101 | self.proj = nn.Linear(dim, dim)
102 | self.proj_drop = nn.Dropout(proj_drop)
103 |
104 | trunc_normal_(self.relative_position_bias_table, std=.02)
105 | self.softmax = nn.Softmax(dim=-1)
106 |
107 | def forward(self, x, mask=None):
108 | """
109 | Args:
110 | x: input features with shape of (num_windows*B, N, C)
111 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
112 | """
113 | B_, N, C = x.shape
114 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
115 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
116 |
117 | q = q * self.scale
118 | attn = (q @ k.transpose(-2, -1))
119 |
120 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
121 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
122 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
123 | attn = attn + relative_position_bias.unsqueeze(0)
124 |
125 | if mask is not None:
126 | nW = mask.shape[0]
127 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
128 | attn = attn.view(-1, self.num_heads, N, N)
129 | attn = self.softmax(attn)
130 | else:
131 | attn = self.softmax(attn)
132 |
133 | attn = self.attn_drop(attn)
134 |
135 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
136 | x = self.proj(x)
137 | x = self.proj_drop(x)
138 | return x
139 |
140 | def extra_repr(self) -> str:
141 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
142 |
143 | def flops(self, N):
144 | # calculate flops for 1 window with token length of N
145 | flops = 0
146 | # qkv = self.qkv(x)
147 | flops += N * self.dim * 3 * self.dim
148 | # attn = (q @ k.transpose(-2, -1))
149 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
150 | # x = (attn @ v)
151 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
152 | # x = self.proj(x)
153 | flops += N * self.dim * self.dim
154 | return flops
155 |
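# Minimal usage sketch (hypothetical shapes, matching the 7x7 / dim=96 / 3-head
# configuration used further below):
#   attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
#   out = attn(torch.randn(64, 49, 96))  # (num_windows*B, N, C) -> same shape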
156 |
157 | class SwinTransformerBlock(nn.Module):
158 | r""" Swin Transformer Block.
159 |
160 | Args:
161 | dim (int): Number of input channels.
162 | input_resolution (tuple[int]): Input resolution.
163 | num_heads (int): Number of attention heads.
164 | window_size (int): Window size.
165 | shift_size (int): Shift size for SW-MSA.
166 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
167 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
168 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
169 | drop (float, optional): Dropout rate. Default: 0.0
170 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
171 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
172 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
173 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
174 | """
175 |
176 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
177 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
178 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
179 | super().__init__()
180 | self.dim = dim
181 | self.input_resolution = input_resolution
182 | self.num_heads = num_heads
183 | self.window_size = window_size
184 | self.shift_size = shift_size
185 | self.mlp_ratio = mlp_ratio
186 | if min(self.input_resolution) <= self.window_size:
187 | # if window size is larger than input resolution, we don't partition windows
188 | self.shift_size = 0
189 | self.window_size = min(self.input_resolution)
190 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
191 |
192 | self.norm1 = norm_layer(dim)
193 | self.attn = WindowAttention(
194 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
195 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
196 |
197 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
198 | self.norm2 = norm_layer(dim)
199 | mlp_hidden_dim = int(dim * mlp_ratio)
200 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
201 |
202 | if self.shift_size > 0:
203 | # calculate attention mask for SW-MSA
204 | H, W = self.input_resolution
205 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
206 | h_slices = (slice(0, -self.window_size),
207 | slice(-self.window_size, -self.shift_size),
208 | slice(-self.shift_size, None))
209 | w_slices = (slice(0, -self.window_size),
210 | slice(-self.window_size, -self.shift_size),
211 | slice(-self.shift_size, None))
212 | cnt = 0
213 | for h in h_slices:
214 | for w in w_slices:
215 | img_mask[:, h, w, :] = cnt
216 | cnt += 1
217 |
218 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
219 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
220 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
221 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
222 | else:
223 | attn_mask = None
224 |
225 | self.register_buffer("attn_mask", attn_mask)
226 |
227 | def forward(self, x):
228 | H, W = self.input_resolution
229 | B, L, C = x.shape
230 | assert L == H * W, "input feature has wrong size"
231 |
232 | shortcut = x
233 | x = self.norm1(x)
234 | x = x.view(B, H, W, C)
235 |
236 | # cyclic shift
237 | if self.shift_size > 0:
238 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
239 | else:
240 | shifted_x = x
241 |
242 | # partition windows
243 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
244 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
245 |
246 | # W-MSA/SW-MSA
247 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
248 |
249 | # merge windows
250 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
251 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
252 |
253 | # reverse cyclic shift
254 | if self.shift_size > 0:
255 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
256 | else:
257 | x = shifted_x
258 | x = x.view(B, H * W, C)
259 |
260 | # FFN
261 | x = shortcut + self.drop_path(x)
262 | x = x + self.drop_path(self.mlp(self.norm2(x)))
263 |
264 | return x
265 |
266 | def extra_repr(self) -> str:
267 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
268 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
269 |
270 | def flops(self):
271 | flops = 0
272 | H, W = self.input_resolution
273 | # norm1
274 | flops += self.dim * H * W
275 | # W-MSA/SW-MSA
276 | nW = H * W / self.window_size / self.window_size
277 | flops += nW * self.attn.flops(self.window_size * self.window_size)
278 | # mlp
279 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
280 | # norm2
281 | flops += self.dim * H * W
282 | return flops
283 |
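# Note: when shift_size > 0 the feature map is cyclically rolled before window
# partitioning (SW-MSA) and rolled back afterwards; the precomputed attn_mask
# stops tokens that wrapped around the border from attending to each other.
# With shift_size == 0 the block reduces to plain windowed attention (W-MSA).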
284 |
285 | class PatchMerging(nn.Module):
286 | r""" Patch Merging Layer.
287 |
288 | Args:
289 | input_resolution (tuple[int]): Resolution of input feature.
290 | dim (int): Number of input channels.
291 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
292 | """
293 |
294 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
295 | super().__init__()
296 | self.input_resolution = input_resolution
297 | self.dim = dim
298 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
299 | self.norm = norm_layer(4 * dim)
300 |
301 | def forward(self, x):
302 | """
303 | x: B, H*W, C
304 | """
305 | H, W = self.input_resolution
306 | B, L, C = x.shape
307 | assert L == H * W, "input feature has wrong size"
308 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
309 |
310 | x = x.view(B, H, W, C)
311 |
312 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
313 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
314 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
315 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
316 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
317 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
318 |
319 | x = self.norm(x)
320 | x = self.reduction(x)
321 |
322 | return x
323 |
324 | def extra_repr(self) -> str:
325 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
326 |
327 | def flops(self):
328 | H, W = self.input_resolution
329 | flops = H * W * self.dim
330 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
331 | return flops
332 |
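# Shape sketch (hypothetical sizes): on a (B, 56*56, 96) input with
# input_resolution (56, 56), PatchMerging concatenates each 2x2 neighbourhood
# into 4*96 = 384 channels and reduces them to 192, giving (B, 28*28, 192).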
333 | class PatchExpand(nn.Module):
334 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
335 | super().__init__()
336 | self.input_resolution = input_resolution
337 | self.dim = dim
338 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
339 | self.norm = norm_layer(dim // dim_scale)
340 |
341 | def forward(self, x):
342 | """
343 | x: B, H*W, C
344 | """
345 | H, W = self.input_resolution
346 | x = self.expand(x)
347 | B, L, C = x.shape
348 | assert L == H * W, "input feature has wrong size"
349 |
350 | x = x.view(B, H, W, C)
351 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
352 | x = x.view(B,-1,C//4)
353 | x= self.norm(x)
354 |
355 | return x
356 |
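# Shape sketch (hypothetical sizes): PatchExpand with dim=192 on a (B, 28*28, 192)
# input expands channels to 384, then unfolds each token into a 2x2 patch of
# 96-channel tokens, i.e. output (B, 56*56, 96): 2x spatial resolution, C/2 channels.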
357 | class FinalPatchExpand_X4(nn.Module):
358 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
359 | super().__init__()
360 | self.input_resolution = input_resolution
361 | self.dim = dim
362 | self.dim_scale = dim_scale
363 | self.expand = nn.Linear(dim, 16*dim, bias=False)
364 | self.output_dim = dim
365 | self.norm = norm_layer(self.output_dim)
366 |
367 | def forward(self, x):
368 | """
369 | x: B, H*W, C
370 | """
371 | H, W = self.input_resolution
372 | x = self.expand(x)
373 | B, L, C = x.shape
374 | assert L == H * W, "input feature has wrong size"
375 |
376 | x = x.view(B, H, W, C)
377 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
378 | x = x.view(B,-1,self.output_dim)
379 | x= self.norm(x)
380 |
381 | return x
382 |
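# Shape sketch (hypothetical sizes): with dim=96 and dim_scale=4 the linear layer
# expands channels to 16*96, which are folded into a 4x4 spatial patch per token,
# so a (B, 56*56, 96) input becomes (B, 224*224, 96); the channel count is kept
# (output_dim == dim).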
383 | class BasicLayer(nn.Module):
384 | """ A basic Swin Transformer layer for one stage.
385 |
386 | Args:
387 | dim (int): Number of input channels.
388 | input_resolution (tuple[int]): Input resolution.
389 | depth (int): Number of blocks.
390 | num_heads (int): Number of attention heads.
391 | window_size (int): Local window size.
392 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
393 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
394 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
395 | drop (float, optional): Dropout rate. Default: 0.0
396 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
397 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
398 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
399 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
400 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
401 | """
402 |
403 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
404 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
405 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
406 |
407 | super().__init__()
408 | self.dim = dim
409 | self.input_resolution = input_resolution
410 | self.depth = depth
411 | self.use_checkpoint = use_checkpoint
412 |
413 | # build blocks
414 | self.blocks = nn.ModuleList([
415 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
416 | num_heads=num_heads, window_size=window_size,
417 | shift_size=0 if (i % 2 == 0) else window_size // 2,
418 | mlp_ratio=mlp_ratio,
419 | qkv_bias=qkv_bias, qk_scale=qk_scale,
420 | drop=drop, attn_drop=attn_drop,
421 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
422 | norm_layer=norm_layer)
423 | for i in range(depth)])
424 |
425 | # patch merging layer
426 | if downsample is not None:
427 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
428 | else:
429 | self.downsample = None
430 |
431 | def forward(self, x):
432 | for blk in self.blocks:
433 | if self.use_checkpoint:
434 | x = checkpoint.checkpoint(blk, x)
435 | else:
436 | x = blk(x)
437 | if self.downsample is not None:
438 | x = self.downsample(x)
439 | return x
440 |
441 | def extra_repr(self) -> str:
442 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
443 |
444 | def flops(self):
445 | flops = 0
446 | for blk in self.blocks:
447 | flops += blk.flops()
448 | if self.downsample is not None:
449 | flops += self.downsample.flops()
450 | return flops
451 |
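# Note: within a stage the blocks alternate between regular and shifted windows
# (shift_size = 0 for even block indices, window_size // 2 for odd ones), and the
# optional downsample (PatchMerging in the encoder below) is applied once at the end.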
452 | class BasicLayer_up(nn.Module):
453 | """ A basic Swin Transformer decoder (upsampling) layer for one stage.
454 |
455 | Args:
456 | dim (int): Number of input channels.
457 | input_resolution (tuple[int]): Input resolution.
458 | depth (int): Number of blocks.
459 | num_heads (int): Number of attention heads.
460 | window_size (int): Local window size.
461 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
462 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
463 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
464 | drop (float, optional): Dropout rate. Default: 0.0
465 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
466 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
467 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
468 | upsample (nn.Module | None, optional): Upsample layer at the end of the layer. Default: None
469 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
470 | """
471 |
472 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
473 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
474 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
475 |
476 | super().__init__()
477 | self.dim = dim
478 | self.input_resolution = input_resolution
479 | self.depth = depth
480 | self.use_checkpoint = use_checkpoint
481 |
482 | # build blocks
483 | self.blocks = nn.ModuleList([
484 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
485 | num_heads=num_heads, window_size=window_size,
486 | shift_size=0 if (i % 2 == 0) else window_size // 2,
487 | mlp_ratio=mlp_ratio,
488 | qkv_bias=qkv_bias, qk_scale=qk_scale,
489 | drop=drop, attn_drop=attn_drop,
490 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
491 | norm_layer=norm_layer)
492 | for i in range(depth)])
493 |
494 | # patch expanding (upsampling) layer
495 | if upsample is not None:
496 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
497 | else:
498 | self.upsample = None
499 |
500 | def forward(self, x):
501 | for blk in self.blocks:
502 | if self.use_checkpoint:
503 | x = checkpoint.checkpoint(blk, x)
504 | else:
505 | x = blk(x)
506 | if self.upsample is not None:
507 | x = self.upsample(x)
508 | return x
509 |
510 | class PatchEmbed(nn.Module):
511 | r""" Image to Patch Embedding
512 |
513 | Args:
514 | img_size (int): Image size. Default: 224.
515 | patch_size (int): Patch token size. Default: 4.
516 | in_chans (int): Number of input image channels. Default: 3.
517 | embed_dim (int): Number of linear projection output channels. Default: 96.
518 | norm_layer (nn.Module, optional): Normalization layer. Default: None
519 | """
520 |
521 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
522 | super().__init__()
523 | img_size = to_2tuple(img_size)
524 | patch_size = to_2tuple(patch_size)
525 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
526 | self.img_size = img_size
527 | self.patch_size = patch_size
528 | self.patches_resolution = patches_resolution
529 | self.num_patches = patches_resolution[0] * patches_resolution[1]
530 |
531 | self.in_chans = in_chans
532 | self.embed_dim = embed_dim
533 |
534 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
535 | if norm_layer is not None:
536 | self.norm = norm_layer(embed_dim)
537 | else:
538 | self.norm = None
539 |
540 | def forward(self, x):
541 | B, C, H, W = x.shape
542 | # FIXME look at relaxing size constraints
543 | assert H == self.img_size[0] and W == self.img_size[1], \
544 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
545 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
546 | if self.norm is not None:
547 | x = self.norm(x)
548 | return x
549 |
550 | def flops(self):
551 | Ho, Wo = self.patches_resolution
552 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
553 | if self.norm is not None:
554 | flops += Ho * Wo * self.embed_dim
555 | return flops
556 |
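# Shape sketch (hypothetical sizes): with img_size=224, patch_size=4 and
# embed_dim=96, the 4x4 stride-4 convolution turns a (B, 3, 224, 224) image into
# (B, 56*56, 96) patch tokens.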
557 |
558 | class SwinTransformerSys(nn.Module):
559 | r""" Swin Transformer
560 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
561 | https://arxiv.org/pdf/2103.14030
562 |
563 | Args:
564 | img_size (int | tuple(int)): Input image size. Default 224
565 | patch_size (int | tuple(int)): Patch size. Default: 4
566 | in_chans (int): Number of input image channels. Default: 3
567 | num_classes (int): Number of classes for classification head. Default: 1000
568 | embed_dim (int): Patch embedding dimension. Default: 512 (the SwinUnet wrapper passes 96)
569 | depths (tuple(int)): Depth of each Swin Transformer layer.
570 | num_heads (tuple(int)): Number of attention heads in different layers.
571 | window_size (int): Window size. Default: 7
572 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
573 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
574 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
575 | drop_rate (float): Dropout rate. Default: 0
576 | attn_drop_rate (float): Attention dropout rate. Default: 0
577 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
578 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
579 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
580 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
581 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
582 | """
583 |
584 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
585 | embed_dim=512, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
586 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
587 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
588 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
589 | use_checkpoint=False, final_upsample="expand_first", **kwargs):
590 | super().__init__()
591 |
592 | # print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,
593 | # depths_decoder,drop_path_rate,num_classes))
594 | # print("img --- size:", img_size)
595 |
596 | self.num_classes = num_classes
597 | self.num_layers = len(depths)
598 | self.embed_dim = embed_dim
599 | self.ape = ape
600 | self.patch_norm = patch_norm
601 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
602 | self.num_features_up = int(embed_dim * 2)
603 | self.mlp_ratio = mlp_ratio
604 | self.final_upsample = final_upsample
605 |
606 |
607 | # split image into non-overlapping patches
608 | self.patch_embed = PatchEmbed(
609 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
610 | norm_layer=norm_layer if self.patch_norm else None)
611 | num_patches = self.patch_embed.num_patches
612 | patches_resolution = self.patch_embed.patches_resolution
613 | self.patches_resolution = patches_resolution
614 |
615 | # absolute position embedding
616 | if self.ape:
617 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
618 | trunc_normal_(self.absolute_pos_embed, std=.02)
619 |
620 | self.pos_drop = nn.Dropout(p=drop_rate)
621 |
622 | # stochastic depth
623 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
624 |
625 | # build encoder and bottleneck layers
626 | self.layers = nn.ModuleList()
627 | for i_layer in range(self.num_layers):
628 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
629 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
630 | patches_resolution[1] // (2 ** i_layer)),
631 | depth=depths[i_layer],
632 | num_heads=num_heads[i_layer],
633 | window_size=window_size,
634 | mlp_ratio=self.mlp_ratio,
635 | qkv_bias=qkv_bias, qk_scale=qk_scale,
636 | drop=drop_rate, attn_drop=attn_drop_rate,
637 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
638 | norm_layer=norm_layer,
639 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
640 | use_checkpoint=use_checkpoint)
641 | self.layers.append(layer)
642 |
643 | # build decoder layers
644 | self.layers_up = nn.ModuleList()
645 | self.concat_back_dim = nn.ModuleList()
646 | for i_layer in range(self.num_layers):
647 | concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),
648 | int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity()
649 | if i_layer ==0 :
650 | layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
651 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
652 | else:
653 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
654 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
655 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
656 | depth=depths[(self.num_layers-1-i_layer)],
657 | num_heads=num_heads[(self.num_layers-1-i_layer)],
658 | window_size=window_size,
659 | mlp_ratio=self.mlp_ratio,
660 | qkv_bias=qkv_bias, qk_scale=qk_scale,
661 | drop=drop_rate, attn_drop=attn_drop_rate,
662 | drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],
663 | norm_layer=norm_layer,
664 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
665 | use_checkpoint=use_checkpoint)
666 | self.layers_up.append(layer_up)
667 | self.concat_back_dim.append(concat_linear)
668 |
669 | self.norm = norm_layer(self.num_features)
670 | self.norm_up= norm_layer(self.embed_dim)
671 | # self.norm_up= norm_layer(192)
672 |
673 | if self.final_upsample == "expand_first":
674 | # print("---final upsample expand_first---") # FIXME
675 | self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim)
676 | # self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False)
677 | self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=3, stride=2, padding=1, bias=False)
678 |
679 | self.output_2 = nn.Conv2d(in_channels=self.num_classes,out_channels=self.num_classes,kernel_size=3, stride=2, padding=1, bias=False)
680 | # self.output_3 = nn.Conv2d(in_channels=self.num_classes,out_channels=self.num_classes,kernel_size=3, dilation=2, padding=2, bias=False)
681 |
682 | self.apply(self._init_weights)
683 |
684 | def _init_weights(self, m):
685 | if isinstance(m, nn.Linear):
686 | trunc_normal_(m.weight, std=.02)
687 | if isinstance(m, nn.Linear) and m.bias is not None:
688 | nn.init.constant_(m.bias, 0)
689 | elif isinstance(m, nn.LayerNorm):
690 | nn.init.constant_(m.bias, 0)
691 | nn.init.constant_(m.weight, 1.0)
692 |
693 | @torch.jit.ignore
694 | def no_weight_decay(self):
695 | return {'absolute_pos_embed'}
696 |
697 | @torch.jit.ignore
698 | def no_weight_decay_keywords(self):
699 | return {'relative_position_bias_table'}
700 |
701 | #Encoder and Bottleneck
702 | def forward_features(self, x):
703 | x = self.patch_embed(x)
704 | if self.ape:
705 | x = x + self.absolute_pos_embed
706 | x = self.pos_drop(x)
707 | x_downsample = []
708 |
709 | for layer in self.layers:
710 | x_downsample.append(x)
711 | x = layer(x)
712 |
713 | x = self.norm(x) # B L C [1,49,768]
714 |
715 | return x, x_downsample
716 |
717 | # Decoder and skip connections
718 | ## FIXME: adjust the network structure so the pretrained parameters remain usable; skip some layers for real-time use
719 | def forward_up_features(self, x, x_downsample):
720 | for inx, layer_up in enumerate(self.layers_up):
721 | # print("swin decoder",self.layers_up)
722 | if inx == 0:
723 | x = layer_up(x)
724 | else:
725 | x = torch.cat([x,x_downsample[3-inx]],-1)
726 | x = self.concat_back_dim[inx](x)
727 | x = layer_up(x)
728 |
729 | x = self.norm_up(x) # B L C [1,3136,96]
730 |
731 | return x
732 |
733 | def up_x4(self, x):
734 | H, W = self.patches_resolution
735 |
736 | B, L, C = x.shape
737 | assert L == H*W, "input feature has wrong size"
738 |
739 | if self.final_upsample=="expand_first":
740 | x = self.up(x)
741 | x = x.view(B,4*H,4*W,-1)
742 | # x = x.view(B,H,W,-1)
743 | x = x.permute(0,3,1,2) #B,C,H,W
744 | x = self.output(x) # 1/2
745 | x = self.output_2(x) # 1/4
746 | # x = self.output_2(x) # 1/8 # FIXME: keep the size with a final dilated conv, or drop it entirely?
747 | # x = self.output_3(x) # 1/4 # FIXME: keep this consistent with the setting used during training
748 |
749 |
750 | return x
751 |
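# Shape sketch for up_x4 (assuming the img_size=224 / patch_size=4 defaults): the
# decoder output (B, 56*56, embed_dim) is expanded to (B, embed_dim, 224, 224),
# then the two stride-2 3x3 convolutions (self.output, self.output_2) reduce it
# to (B, num_classes, 112, 112) and finally (B, num_classes, 56, 56).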
752 | def forward(self, x):
753 | x, x_downsample = self.forward_features(x)
754 | x = self.forward_up_features(x,x_downsample)
755 | x = self.up_x4(x)
756 | # print("out shape", x.shape)
757 | # x = x.view(1,28,28,-1)
758 | # x = x.permute(0,3,2,1)
759 | # x = self.output(x)
760 | # print("output:", x.shape)
761 | # x = self.up_x4(x)
762 |
763 | return x
764 |
765 | def flops(self):
766 | flops = 0
767 | flops += self.patch_embed.flops()
768 | for i, layer in enumerate(self.layers):
769 | flops += layer.flops()
770 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
771 | flops += self.num_features * self.num_classes
772 | return flops
773 |
--------------------------------------------------------------------------------
/lib/swin_unet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import copy
7 | import logging
8 | import math
9 |
10 | from os.path import join as pjoin
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 |
16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17 | from torch.nn.modules.utils import _pair
18 | from scipy import ndimage
19 | from .swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | class SwinUnet(nn.Module):
25 | def __init__(self, img_size=224, num_classes=21843, zero_head=False, vis=False):
26 | super(SwinUnet, self).__init__()
27 | self.num_classes = num_classes
28 | self.zero_head = zero_head
29 |
30 | self.swin_unet = SwinTransformerSys(img_size=224,
31 | patch_size=4,
32 | in_chans=3,
33 | num_classes=512,
34 | embed_dim=96,
35 | depths=[2, 2, 2, 2],
36 | num_heads=[3, 6, 12, 24],
37 | window_size=7,
38 | mlp_ratio=4,
39 | qkv_bias=True,
40 | qk_scale=None,
41 | drop_rate=0,
42 | drop_path_rate=0.1,
43 | ape=True,
44 | patch_norm=True,
45 | use_checkpoint=True)
46 |
47 | def forward(self, x):
48 | if x.size()[1] == 1:
49 | x = x.repeat(1, 3, 1, 1)
50 | logits = self.swin_unet(x)
51 | return logits
52 |
53 | # def load_from(self, pretrained_path):
54 | # pretrained_path = pretrained_path
55 | # if pretrained_path is not None:
56 | # print("pretrained_path:{}".format(pretrained_path))
57 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58 | # pretrained_dict = torch.load(pretrained_path, map_location=device)
59 | # if "model" not in pretrained_dict:
60 | # print("---start load pretrained model by splitting---")
61 | # pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
62 | # for k in list(pretrained_dict.keys()):
63 | # if "output" in k:
64 | # print("delete key:{}".format(k))
65 | # del pretrained_dict[k]
66 | # msg = self.swin_unet.load_state_dict(pretrained_dict, strict=False)
67 | # # print(msg)
68 | # return
69 | # pretrained_dict = pretrained_dict['model']
70 | # print("---start load pretrained model of swin encoder---")
71 | #
72 | # model_dict = self.swin_unet.state_dict()
73 | # full_dict = copy.deepcopy(pretrained_dict)
74 | # for k, v in pretrained_dict.items():
75 | # if "layers." in k:
76 | # current_layer_num = 3 - int(k[7:8])
77 | # current_k = "layers_up." + str(current_layer_num) + k[8:]
78 | # full_dict.update({current_k: v})
79 | # for k in list(full_dict.keys()):
80 | # if k in model_dict:
81 | # if full_dict[k].shape != model_dict[k].shape:
82 | # print("delete:{};shape pretrain:{};shape model:{}".format(k, v.shape, model_dict[k].shape))
83 | # del full_dict[k]
84 | #
85 | # msg = self.swin_unet.load_state_dict(full_dict, strict=False)
86 | # # print(msg)
87 | # else:
88 | # print("none pretrain")
89 |
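# Minimal usage sketch (hypothetical input, not part of the training pipeline):
#   model = SwinUnet(img_size=224)
#   desc = model(torch.randn(1, 3, 224, 224))
# With the SwinTransformerSys settings above (num_classes=512, embed_dim=96 and
# two stride-2 output convolutions), desc should be a dense descriptor map of
# shape (1, 512, 56, 56); single-channel inputs are repeated to 3 channels first.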
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | import numpy as np
4 |
5 | import torch
6 |
7 | from lib.exceptions import EmptyTensorError
8 |
9 |
10 | def preprocess_image(image, preprocessing=None):
11 | image = image.astype(np.float32)
12 | image = np.transpose(image, [2, 0, 1])
13 | if preprocessing is None:
14 | pass
15 | elif preprocessing == 'caffe':
16 | # RGB -> BGR
17 | image = image[:: -1, :, :]
18 | # Zero-center by mean pixel
19 | mean = np.array([103.939, 116.779, 123.68])
20 | image = image - mean.reshape([3, 1, 1])
21 | elif preprocessing == 'torch':
22 | image /= 255.0
23 | mean = np.array([0.559, 0.573, 0.555])
24 | std = np.array([0.185, 0.172, 0.176])
25 | image = (image - mean.reshape([3, 1, 1])) / std.reshape([3, 1, 1])
26 | else:
27 | raise ValueError('Unknown preprocessing parameter.')
28 | return image
29 |
30 |
31 | def imshow_image(image, preprocessing=None):
32 | if preprocessing is None:
33 | pass
34 | elif preprocessing == 'caffe':
35 | mean = np.array([103.939, 116.779, 123.68])
36 | image = image + mean.reshape([3, 1, 1])
37 | # RGB -> BGR
38 | image = image[:: -1, :, :]
39 | elif preprocessing == 'torch':
40 | # mean = np.array([0.485, 0.456, 0.406])
41 | # std = np.array([0.229, 0.224, 0.225])
42 | mean = np.array([0.559, 0.573, 0.555])
43 | std = np.array([0.185, 0.172, 0.176])
44 | image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1])
45 | image *= 255.0
46 | else:
47 | raise ValueError('Unknown preprocessing parameter.')
48 | image = np.transpose(image, [1, 2, 0])
49 | image = np.round(image).astype(np.uint8)
50 | return image
51 |
52 |
53 | def grid_positions(h, w, device, matrix=False):
54 | lines = torch.arange(
55 | 0, h, device=device
56 | ).view(-1, 1).float().repeat(1, w)
57 | columns = torch.arange(
58 | 0, w, device=device
59 | ).view(1, -1).float().repeat(h, 1)
60 | if matrix:
61 | return torch.stack([lines, columns], dim=0)
62 | else:
63 | return torch.cat([lines.view(1, -1), columns.view(1, -1)], dim=0)
64 |
65 |
66 | def upscale_positions(pos, scaling_steps=0):
67 | for _ in range(scaling_steps):
68 | pos = pos * 2 + 0.5
69 | return pos
70 |
71 |
72 | def downscale_positions(pos, scaling_steps=0):
73 | for _ in range(scaling_steps):
74 | pos = (pos - 0.5) / 2
75 | return pos
76 |
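# Note: upscale_positions and downscale_positions are exact inverses; each
# scaling step maps a coordinate p to 2*p + 0.5 (or back), accounting for the
# half-pixel offset introduced each time the feature-map resolution is halved.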
77 |
78 | def interpolate_dense_features(pos, dense_features, return_corners=False):
79 | device = pos.device
80 |
81 | ids = torch.arange(0, pos.size(1), device=device)
82 |
83 | _, h, w = dense_features.size()
84 |
85 | i = pos[0, :]
86 | j = pos[1, :]
87 |
88 | # Valid corners
89 | i_top_left = torch.floor(i).long()
90 | j_top_left = torch.floor(j).long()
91 | valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)
92 |
93 | i_top_right = torch.floor(i).long()
94 | j_top_right = torch.ceil(j).long()
95 | valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)
96 |
97 | i_bottom_left = torch.ceil(i).long()
98 | j_bottom_left = torch.floor(j).long()
99 | valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)
100 |
101 | i_bottom_right = torch.ceil(i).long()
102 | j_bottom_right = torch.ceil(j).long()
103 | valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)
104 |
105 | valid_corners = torch.min(
106 | torch.min(valid_top_left, valid_top_right),
107 | torch.min(valid_bottom_left, valid_bottom_right)
108 | )
109 |
110 | i_top_left = i_top_left[valid_corners]
111 | j_top_left = j_top_left[valid_corners]
112 |
113 | i_top_right = i_top_right[valid_corners]
114 | j_top_right = j_top_right[valid_corners]
115 |
116 | i_bottom_left = i_bottom_left[valid_corners]
117 | j_bottom_left = j_bottom_left[valid_corners]
118 |
119 | i_bottom_right = i_bottom_right[valid_corners]
120 | j_bottom_right = j_bottom_right[valid_corners]
121 |
122 | ids = ids[valid_corners]
123 | if ids.size(0) == 0:
124 | raise EmptyTensorError
125 |
126 | # Interpolation
127 | i = i[ids]
128 | j = j[ids]
129 | dist_i_top_left = i - i_top_left.float()
130 | dist_j_top_left = j - j_top_left.float()
131 | w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
132 | w_top_right = (1 - dist_i_top_left) * dist_j_top_left
133 | w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
134 | w_bottom_right = dist_i_top_left * dist_j_top_left
135 |
136 | descriptors = (
137 | w_top_left * dense_features[:, i_top_left, j_top_left] +
138 | w_top_right * dense_features[:, i_top_right, j_top_right] +
139 | w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] +
140 | w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right]
141 | )
142 |
143 | pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)
144 |
145 | if not return_corners:
146 | return [descriptors, pos, ids]
147 | else:
148 | corners = torch.stack([
149 | torch.stack([i_top_left, j_top_left], dim=0),
150 | torch.stack([i_top_right, j_top_right], dim=0),
151 | torch.stack([i_bottom_left, j_bottom_left], dim=0),
152 | torch.stack([i_bottom_right, j_bottom_right], dim=0)
153 | ], dim=0)
154 | return [descriptors, pos, ids, corners]
155 |
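# interpolate_dense_features is plain bilinear interpolation: each descriptor is
# a weighted sum of the four surrounding feature-map cells with weights
# (1-di)*(1-dj), (1-di)*dj, di*(1-dj), di*dj for the fractional offsets (di, dj)
# from the top-left corner; positions whose corners fall outside the map are
# dropped, and EmptyTensorError is raised if none remain.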
156 |
157 | def savefig(filepath, fig=None, dpi=None):
158 | # TomNorway - https://stackoverflow.com/a/53516034
159 | if not fig:
160 | fig = plt.gcf()
161 |
162 | plt.subplots_adjust(0, 0, 1, 1, 0, 0)
163 | for ax in fig.axes:
164 | ax.axis('off')
165 | ax.margins(0, 0)
166 | ax.xaxis.set_major_locator(plt.NullLocator())
167 | ax.yaxis.set_major_locator(plt.NullLocator())
168 |
169 | fig.savefig(filepath, pad_inches=0, bbox_inches='tight', dpi=dpi)
170 |
--------------------------------------------------------------------------------
/match_localization.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | from os.path import exists
4 | import numpy as np
5 | import matchs
6 | from sklearn.neighbors import NearestNeighbors
7 | from tqdm.auto import tqdm
8 | import pandas as pd
9 | import time
10 | import argparse
11 | import cv2
12 | import configparser
13 |
14 | import warnings
15 | warnings.filterwarnings("ignore", category=Warning)
16 |
17 | def parse_gt_file(gtfile):
18 | print('Parsing ground truth data file...')
19 | gtdata = np.load(gtfile)
20 | return gtdata['utmQ'], gtdata['utmDb'], gtdata['posDistThr']
21 |
22 | def get_positives(utmQ, utmDb, posDistThr):
23 | # positives for evaluation are those within trivial threshold range
24 | # fit NN to find them, search by radius
25 | knn = NearestNeighbors(n_jobs=-1)
26 | knn.fit(utmDb)
27 | distances, positives = knn.radius_neighbors(utmQ, radius=posDistThr)
28 |
29 | return positives
30 |
31 | def compute_recall(query_idx_list, gt, predictions, numQ, n_values, recall_str=''):
32 | start2 = time.perf_counter()
33 | correct_at_n = np.zeros(len(n_values))
34 |
35 | for i, pred in enumerate(predictions):
36 | for j, n in enumerate(n_values):
37 | # if in top N then also in top NN, where NN > N
38 | if np.any(np.in1d(pred[:n], gt[query_idx_list[i]])):
39 | correct_at_n[j:] += 1  # index by position in n_values, not by query index
40 | break
41 | recall_at_n = correct_at_n / numQ
42 | all_recalls = {} # make dict for output
43 | for i, n in enumerate(n_values):
44 | all_recalls[n] = recall_at_n[i]
45 | tqdm.write("====> Recall {}@{}: {:.4f}".format(recall_str, n, recall_at_n[i]))
46 |
47 | recall_time = time.perf_counter()
48 | print('Compute recall time is %6.3f' % (recall_time - start2))
49 |
50 | return all_recalls
51 |
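# Recall@N here is the fraction of queries for which at least one of the top-N
# predicted references is a ground-truth positive (within posDistThr of the
# query), so the reported values are non-decreasing in N.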
52 | def write_recalls(opt, netvlad_recalls, cnnmatch_recalls, n_values, rec_file):
53 | with open(rec_file, 'w') as res_out:
54 | res_out.write(str(opt)+'\n')
55 | res_out.write("n_values: "+str(n_values)+'\n')
56 | for n in n_values:
57 | res_out.write("Recall {}@{}: {:.4f}\n".format('netvlad_match', n, netvlad_recalls[n]))
58 | res_out.write("Recall {}@{}: {:.4f}\n".format('cnn_match', n, cnnmatch_recalls[n]))
59 |
60 | def image_point2gps_point(keypoint_position: np.array, image_position: np.array, gsd, img_size):
61 | half_img_size = img_size/2
62 | point_position = np.array([image_position[0]+(keypoint_position[0]-half_img_size)*gsd, image_position[1]+(half_img_size - keypoint_position[1])*gsd])
63 | alt = np.zeros(keypoint_position.shape[1])
64 | point_position= np.vstack([*point_position,alt])
65 | return point_position
66 |
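# Sketch of the pixel-to-UTM conversion above (hypothetical numbers): with
# gsd = 0.167 m/px and img_size = 500, a keypoint at the image centre (250, 250)
# maps exactly to the reference image's tagged position, while a keypoint offset
# by 100 px maps to a point 100 * 0.167 = 16.7 m away along that axis; the
# altitude channel is filled with zeros.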
67 |
68 |
69 | def main():
70 | parser = argparse.ArgumentParser(description='Patch-NetVLAD-Feature-Match')
71 | parser.add_argument('--config_path', type=str, default='performance.ini',
72 | help='File name (with extension) to an ini file that stores most of the configuration data for cnn-matching')
73 | parser.add_argument('--ground_truth_path', type=str, default='/home/a409/users/huboni/Projects/code/Patch-NetVLAD/patchnetvlad/dataset_gt_files/mavic-xjtu_dist50.npz',
74 | help='ground truth file dist 50')
75 | parser.add_argument('--netvlad_result_path', type=str, default='/home/a409/users/huboni/Projects/code/Patch-NetVLAD/patchnetvlad/results/mavic-xjtu_2048_7_dist50/NetVLAD_predictions.txt',
76 | help='netvlad predictions result path')
77 | parser.add_argument('--reference_origin_path', type=str, default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic-xjtu/reference.csv',
78 | help='reference csv path')
79 | parser.add_argument('--query_origin_path', type=str, default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic-xjtu/query.csv',
80 | help='query csv path')
81 | parser.add_argument('--out_dir', type=str, default='/home/a409/users/huboni/Projects/code/cnn-matching/result/mavic-xjtu/',
82 | help='Dir to save recall')
83 | parser.add_argument('--out_1019_dir', type=str, default='/home/a409/users/huboni/Projects/code/cnn-matching/result/mavic-xjtu/',
84 | help='Dir to save recall')
85 | parser.add_argument('--out_file', type=str, default='cnn_mavic-xjtu_dist50_v50',
86 | help='file name prefix for the recall and CSV outputs')
87 | parser.add_argument('--fig_res_file', type=str, default='pnp_fig_res_npu_dist20.csv')
88 | parser.add_argument('--global_res_file', type=str, default='pnp_fig_res_npu_dist20.csv')
89 | parser.add_argument('--rerank_res_file', type=str, default='pnp_fig_res_npu_dist20.csv')
90 | parser.add_argument('--rerank1_res_file', type=str, default='pnp_fig_res_npu_dist20.csv')
91 | parser.add_argument('--pinpoint_res_file', type=str, default='pnp_fig_res_npu_dist20.csv')
92 |
93 |
94 | parser.add_argument('--model_type', type=str, default='cnn', help='feature extraction model type (default: cnn)')
95 | parser.add_argument('--img_size', type=int, default=500, help='input image size in pixels, 500 or 224')
96 | parser.add_argument('--checkpoint', type=str, default='/home/a409/users/huboni/Projects/code/cnn-matching/models/d2_tf.pth',
97 | help='path to the feature extractor checkpoint, e.g. ./models/d2_tf.pth')
98 |
99 | opt = parser.parse_args()
100 | print(opt)
101 |
102 | n_values = [1, 5, 10, 20, 50] # FIXME
103 | if opt.img_size == 500:
104 | # gsd = 300 / 500 # alto
105 | # gsd = 300 / 500 # xjtu
106 | gsd = 0.167 # npu
107 | print("gsd:", gsd)
108 | # k = np.array([[345, 0, 247.328], [0, 345, 245], [0, 0, 1]]) # ALTO
109 | # k = np.array([[350, 0, 250], [0, 350, 250], [0, 0, 1]]) # xjtu
110 | k = np.array([[475.48, 0, 250], [0, 475.48, 250], [0, 0, 1]]) # npu
111 | # k = np.array([[3051.6 / (3000/500), 0, 1500/(3000/500)], [0, 3051.6/(3000/500), 1500/(3000/500)], [0, 0, 1]]) # TerraTrack
112 | elif opt.img_size == 224:
113 | gsd = 300 / 500 * (500/224)
114 | print("gsd:", gsd)
115 | k = np.array([[345 / (500/224), 0, 247.328 / (500/224)], [0, 345 / (500/224), 245 / (500/224)], [0, 0, 1]])
116 | num_rec = (n_values[-1])
117 | with open(opt.netvlad_result_path, 'r') as file:
118 | lines = file.readlines()[2:] # the first two lines are header comments
119 | numQ = len(lines)//num_rec
120 | print("numQ:", numQ)
121 |
122 | query_idx_list_all = []
123 | refidx_list = []
124 | refidx_inliners_list = []
125 |
126 | start0 = time.perf_counter()
127 | # rerank by cnn match + flann + ransac
128 | qData = pd.read_csv(opt.query_origin_path)
129 | # qData = qData.sort_values(by='name', ascending=True)
130 | print("qData", qData)
131 | dbData = pd.read_csv(opt.reference_origin_path)
132 | # dbData = dbData.sort_values(by='name', ascending=True)
133 | # print("daData", dbData)
134 | for i in tqdm(range(len(lines)), desc='cnn_match_ing'):
135 |
136 | query_file = lines[i].split(',')[0].strip() # query files are stored in index order, so only the reference indices need to be linked
137 | query_name = os.path.basename(query_file)
138 | reference_file = lines[i].split(',')[1].strip() # reference_name
139 | # reference_name = os.path.join(reference_file.split('/')[-2], reference_file.split('/')[-1]) # alto
140 | reference_name = os.path.basename(reference_file) # terratrack FIXME
141 | # print("reference_name", reference_name)
142 | query_idx = qData[qData.name == query_name].index.to_list()[0]
143 | query_idx_list_all.append(query_idx)
144 | ref_idx = dbData[dbData.name == reference_name].index.to_list()[0]
145 | inliners, avg_dist, q_kps, r_kps = matchs.cnn_match(query_file, reference_file, opt.model_type, opt.checkpoint, opt.img_size)
146 | # print("-------current avg_list------:", avg_dist)
147 | # print("--------num inliners---------:", inliners)
148 | refidx_list.append(ref_idx)
149 | refidx_inliners_list.append([ref_idx, avg_dist, q_kps, r_kps])
150 | query_idx_list = query_idx_list_all[0: len(query_idx_list_all) : num_rec]
151 | print("-----------q idx list:", query_idx_list)
152 | match_time = time.perf_counter()
153 | print('CNN match time is %6.3f' % (match_time - start0))
154 |
155 | # build the original NetVLAD predictions (global retrieval, before reranking)
156 | refidx_list_split = [refidx_list[i:i + num_rec] for i in range(0, len(refidx_list), num_rec)]
157 | netvlad_predictions = np.array(refidx_list_split)
158 |
159 | # split the rematch results per query: [ref_idx, avg_dist, q_kps, r_kps]
160 | refidx_inliners_list_split =[refidx_inliners_list[i:i+num_rec] for i in range(0, len(refidx_inliners_list), num_rec)]
161 | # rerank each query's candidates [ref_idx, avg_dist, ...] by the average inlier distance
162 | start1 = time.perf_counter()
163 | cnnmatch_predictions = []
164 | pinpoint_utms = []
165 | name_l = []
166 | arr_ppl = []
167 | arr_r1l = []
168 | global_candidates_utm = []
169 | rerank_candidates_utm = []
170 | rerank_1_utm = []
171 | pinpoint_utm_list = []
172 | pass_count = 0
173 | fig_csv_file = os.path.join(opt.out_dir, opt.fig_res_file)
174 | global_csv_file = os.path.join(opt.out_1019_dir, opt.global_res_file)
175 | rerank_csv_file = os.path.join(opt.out_1019_dir, opt.rerank_res_file)
176 | rerank1_csv_file = os.path.join(opt.out_1019_dir, opt.rerank1_res_file)
177 | pinpoint_csv_file = os.path.join(opt.out_1019_dir, opt.pinpoint_res_file)
178 |
179 |
180 |
181 | pnp_select_res_dict = pd.DataFrame()
182 | for i, idx_inliners_list in enumerate(refidx_inliners_list_split):
183 | print("-----------------",i)
184 | reranked_rf_idx = np.array(sorted(idx_inliners_list, key=lambda x:x[1], reverse=False))[:, 0][:5] # sort by the second field (avg_dist, ascending) and keep the first field (ref_idx) of the top 5 # FIXME Terratrack
185 | cnnmatch_predictions.append(reranked_rf_idx)
186 | reranked_query_kps = np.array(sorted(idx_inliners_list, key=lambda x:x[1], reverse=False))[:, 2][:5] # FIXME
187 | reranked_ref_kps = np.array(sorted(idx_inliners_list, key=lambda x:x[1], reverse=False))[:, 3][:5] # FIXME
188 | query_utm = np.array(qData.loc[query_idx_list[i], ["easting", "northing"]])
189 | query_utm = query_utm[::-1] # terratrack swap easting northing
190 | print("query name:", qData.loc[query_idx_list[i], ["name"]])
191 |
192 | # print("query name:", qData.loc[query_id, ["name"]])
193 | ref_image_position_o = dbData.loc[reranked_rf_idx.tolist(), ["easting", "northing"]]
194 | # ref_image_position = np.array([ref_image_position_o["easting"], ref_image_position_o["northing"]]).transpose() # alto
195 | ref_image_position = np.array([ref_image_position_o["northing"], ref_image_position_o["easting"]]).transpose() # terratrack
196 |
197 | ref_kpts_position = []
198 | for j, image_position in enumerate(ref_image_position):
199 | ref_kpts_position.append(image_point2gps_point(np.array(reranked_ref_kps[j]).transpose(), image_position, gsd, opt.img_size))
200 | ref_kpts_position = np.hstack(ref_kpts_position).transpose() # keypoint UTM coordinates
201 | query_kps = [np.array(ii) for ii in reranked_query_kps]
202 | query_kps_position = np.vstack(query_kps)
203 |
204 | success, R_vec, t, inliers = cv2.solvePnPRansac(ref_kpts_position, query_kps_position, k, np.zeros(4),
205 | flags=cv2.SOLVEPNP_ITERATIVE, iterationsCount=5000,
206 | reprojectionError=10)
207 | # print("R_vec:", R_vec)
208 | print("success:", success)
209 | if not success:
210 | continue
211 | r_w2c, _ = cv2.Rodrigues(R_vec)
212 | t_w2c = t
213 | r_c2w = np.linalg.inv(r_w2c)
214 | t_c2w = -r_c2w @ t_w2c
215 | # print("---------t_c2w-------:", t_c2w)
216 | pinpoint_utm = t_c2w[:, 0][:2]
217 | name_l.append(qData.loc[query_idx_list[i], ["name"]])
218 | pinpoint_utms.append(pinpoint_utm)
219 | # print("pinpoint utm:", pinpoint_utm)
220 | pp_loss = np.linalg.norm(query_utm - pinpoint_utm)
221 | print("pp_loss", pp_loss)
222 | # if pp_loss>10:
223 | # pass_count += 1
224 | # continue
225 | arr_ppl.append(pp_loss)
226 | r1_utm = dbData.loc[reranked_rf_idx.tolist()[0], ["easting", "northing"]]
227 | print("r1 name:", dbData.loc[reranked_rf_idx.tolist()[0], ["name"]])
228 | r2_utm = dbData.loc[reranked_rf_idx.tolist()[1], ["easting", "northing"]]
229 | r3_utm = dbData.loc[reranked_rf_idx.tolist()[2], ["easting", "northing"]]
230 | r4_utm = dbData.loc[reranked_rf_idx.tolist()[3], ["easting", "northing"]]
231 | r5_utm = dbData.loc[reranked_rf_idx.tolist()[4], ["easting", "northing"]]
232 | # swap easting northing below
233 | r1_loss = np.linalg.norm(query_utm - np.array(r1_utm)[::-1])
234 | r2_loss = np.linalg.norm(query_utm - np.array(r2_utm)[::-1])
235 | r3_loss = np.linalg.norm(query_utm - np.array(r3_utm)[::-1])
236 | r4_loss = np.linalg.norm(query_utm - np.array(r4_utm)[::-1])
237 | r5_loss = np.linalg.norm(query_utm - np.array(r5_utm)[::-1])
238 |
239 | r_mean_loss = sum([r1_loss,r2_loss,r3_loss,r4_loss,r5_loss])/5
240 | arr_r1l.append(r1_loss)
241 | print("localization loss:", pp_loss)
242 | print("recall@1 loss:", r1_loss)
243 | # save csv for localization fig
244 | # if r_mean_loss-pp_loss>=50.0 and pp_loss<2.0:
245 | # if r1_loss-pp_loss>10.0:
246 | if True:
247 | global_candidates_utm.extend((dbData.loc[np.array(idx_inliners_list)[:,0][:5], ["easting", "northing"]]).to_numpy().tolist())
248 | rerank_candidates_utm.extend((dbData.loc[reranked_rf_idx.tolist()[:5], ["easting", "northing"]]).to_numpy().tolist())
249 | rerank_1_utm.append((dbData.loc[reranked_rf_idx.tolist()[0], ["easting", "northing"]]).to_numpy().tolist())
250 | pinpoint_utm_list.append(pinpoint_utm.tolist())
251 |
252 | fig_res = dict(query_name = (qData.loc[query_idx_list[i], ["name"]]).to_numpy(),
253 | global_recall5_idx = (np.array(idx_inliners_list)[:, 0][:5]), global_recall5_name = (dbData.loc[np.array(idx_inliners_list)[:, 0][:5],["name"]]).to_numpy(),
254 | global_recall5_utm = (dbData.loc[np.array(idx_inliners_list)[:,0][:5], ["easting", "northing"]]).to_numpy(),
255 | rerank_recall5_idx = (reranked_rf_idx.tolist()[:5]), rerank_recall5_name = (dbData.loc[reranked_rf_idx.tolist()[:5], ["name"]]).to_numpy(),
256 | rerank_recall5_utm = (dbData.loc[reranked_rf_idx.tolist()[:5], ["easting", "northing"]]).to_numpy(),
257 | rerank_recall1_idx = (reranked_rf_idx.tolist()[0]), rerank_recall1_name = (dbData.loc[reranked_rf_idx.tolist()[:1], ["name"]]).to_numpy(),
258 | rerank_recall1_utm = (dbData.loc[reranked_rf_idx.tolist()[0], ["easting", "northing"]]).to_numpy(),
259 | pinpoint_utm = pinpoint_utm,
260 | query_utm = query_utm, r1_loss = r1_loss, pp_loss=pp_loss)
261 | # print(fig_res)
262 | pnp_select_res_dict = pnp_select_res_dict._append(fig_res, ignore_index=True)
263 |
264 | pd.DataFrame(pnp_select_res_dict).to_csv(fig_csv_file, index=True, encoding='gbk',float_format='%.6f')
265 | pd.DataFrame(global_candidates_utm).to_csv(global_csv_file, index=True, encoding='gbk',float_format='%.6f')
266 | pd.DataFrame(rerank_candidates_utm).to_csv(rerank_csv_file, index=True, encoding='gbk',float_format='%.6f')
267 | pd.DataFrame(rerank_1_utm).to_csv(rerank1_csv_file, index=True, encoding='gbk',float_format='%.6f')
268 | pd.DataFrame(pinpoint_utm_list).to_csv(pinpoint_csv_file, index=True, encoding='gbk',float_format='%.6f')
269 |
270 |
271 |
272 | # print("global candidates list:", global_candidates_utm)
273 | # print("rerank candidates list:", rerank_candidates_utm)
274 | # print("rerank 1 utm list:", rerank_1_utm)
275 | # print("pinpoint utm list", pinpoint_utm_list)
276 |
277 | print("pass count:", pass_count)
278 | print("avg pinpoint loss:", np.mean(np.array(arr_ppl)))
279 | print("avg recall@1 loss:", np.mean(np.array(arr_r1l)))
280 | print("var pinpoint loss:", np.var(np.array(arr_ppl)))
281 | print("var recall@1 loss:", np.var(np.array(arr_r1l)))
282 | print("std pinpoint loss:", np.std(np.array(arr_ppl)))
283 | print("std recall@1 loss:", np.std(np.array(arr_r1l)))
284 | # save the pinpoint UTM results to CSV
285 | pinpoint_dict = {"easting":np.array(pinpoint_utms)[:,0],
286 | "northing":np.array(pinpoint_utms)[:,1],
287 | "name":name_l}
288 |
289 | out_csv_file = os.path.join(opt.out_file + '.csv')
290 | print("pinpoint utm save to %s" % out_csv_file)
291 | pd.DataFrame(pinpoint_dict).to_csv(os.path.join(opt.out_dir, 'pinpoint_utm', out_csv_file), index=False)
292 |
293 |
294 | cnnmatch_predictions = np.array(cnnmatch_predictions)
295 | cnn_pred_time = time.perf_counter()
296 | print('rerank match get cnn predict time is %6.3f' % (cnn_pred_time - start1))
297 |
298 | # get ground truth
299 | utmQ,utmDb,posDistThr = parse_gt_file(opt.ground_truth_path)
300 | gt = get_positives(utmQ, utmDb, posDistThr)
301 |
302 | netvlad_recalls = compute_recall(query_idx_list, gt, netvlad_predictions, numQ, n_values, 'netvlad_match')
303 | cnnmatch_recalls = compute_recall(query_idx_list, gt, cnnmatch_predictions, numQ, n_values, 'cnn_match')
304 |
305 | # out_recall_file = os.path.join(opt.out_dir, 'recalls', opt.checkpoint.split('/')[-3] + '_' + str(opt.img_size) + '.txt')
306 | out_recall_file = os.path.join(opt.out_dir, 'recalls', opt.out_file + '.txt')
307 |
308 | print('Writing recalls to', out_recall_file)
309 | write_recalls(opt, netvlad_recalls, cnnmatch_recalls, n_values, out_recall_file)
310 |
311 |
312 |
313 |
314 | if __name__ == "__main__":
315 | main()
316 |
317 |
--------------------------------------------------------------------------------
/matching.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import imageio
5 | import plotmatch
6 | from lib.cnn_feature import cnn_feature_extract
7 | import matplotlib.pyplot as plt
8 | import time
9 | from skimage import measure
10 | from skimage import transform
11 | import math
12 | import os
13 | from os.path import exists
14 | from os.path import join
15 |
16 |
17 | #time count
18 | start = time.perf_counter()
19 |
20 | _RESIDUAL_THRESHOLD = 20
21 |
22 | imgfile1 = '/home/a409/users/huboni/Paper/locally_global_match_pnp/fig/fig-matching/q000000.png'
23 | imgfile2 = '/home/a409/users/huboni/Paper/locally_global_match_pnp/fig/fig-matching/r000000.png'
24 |
25 | start = time.perf_counter()
26 |
27 | # read left image
28 | image1_o = cv2.imread(imgfile1)
29 | image1 = cv2.resize(image1_o, (224,224))
30 | image2_o = cv2.imread(imgfile2)
31 | image2 = cv2.resize(image2_o, (224,224))
32 | print('read image time is %6.3f' % (time.perf_counter() - start))
33 |
34 | start0 = time.perf_counter()
35 |
36 | kps_left, sco_left, des_left = cnn_feature_extract(image1_o, nfeatures = -1)
37 | kps_right, sco_right, des_right = cnn_feature_extract(image2_o, nfeatures = -1)
38 |
39 | print('Feature_extract time is %6.3f, left: %d, right: %d' % ((time.perf_counter() - start), len(kps_left), len(kps_right)))
40 | start = time.perf_counter()
41 |
42 | # FLANN feature matching
43 | FLANN_INDEX_KDTREE = 1
44 | index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
45 | search_params = dict(checks=40)
46 | flann = cv2.FlannBasedMatcher(index_params, search_params)
47 | matches = flann.knnMatch(des_left, des_right, k=2)
48 | matches_reverse = flann.knnMatch(des_right, des_left, k=2)
49 |
50 |
51 | goodMatch = []
52 | locations_1_to_use = []
53 | locations_2_to_use = []
54 |
55 | disdif_avg = 0
56 | for m, n in matches:
57 | disdif_avg += n.distance - m.distance
58 | disdif_avg = disdif_avg / len(matches)
59 |
60 | for m, n in matches:
61 | if n.distance > m.distance + disdif_avg and matches_reverse[m.trainIdx][1].distance > matches_reverse[m.trainIdx][0].distance+disdif_avg and matches_reverse[m.trainIdx][0].trainIdx == m.queryIdx:
62 | goodMatch.append(m)
63 | p2 = cv2.KeyPoint(kps_right[m.trainIdx][0], kps_right[m.trainIdx][1], 1)
64 | p1 = cv2.KeyPoint(kps_left[m.queryIdx][0], kps_left[m.queryIdx][1], 1)
65 | locations_1_to_use.append([p1.pt[0], p1.pt[1]])
66 | locations_2_to_use.append([p2.pt[0], p2.pt[1]])
67 | print('match num is %d' % len(goodMatch))
68 | locations_1_to_use = np.array(locations_1_to_use)
69 | locations_2_to_use = np.array(locations_2_to_use)
70 |
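# The filtering above is a mutual-consistency check with an adaptive ratio test:
# a putative match m is kept only if the gap between its second-best and best
# distances exceeds the average gap (disdif_avg) in both matching directions and
# the reverse nearest neighbour of m.trainIdx points back to m.queryIdx.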
71 | # Perform geometric verification using RANSAC.
72 | _, inliers = measure.ransac((locations_1_to_use, locations_2_to_use),
73 | transform.AffineTransform,
74 | min_samples=3,
75 | residual_threshold=_RESIDUAL_THRESHOLD,
76 | max_trials=1000)
77 |
78 | print('Found %d inliers' % sum(inliers))
79 |
80 | inlier_idxs = np.nonzero(inliers)[0]
81 | # final matching result (pairs of inlier indices)
82 | matches = np.column_stack((inlier_idxs, inlier_idxs))
83 | print('whole time is %6.3f' % (time.perf_counter() - start0))
84 |
85 | # save inlier pixel coordinates
86 | coordinate_out_dir = '/home/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/VPR_results/'
87 | coordinate_out_ref = []
88 | coordinate_out_query = []
89 |
90 | dist = 0
91 | for i in range(matches.shape[0]):
92 | idx1 = matches[i, 0]
93 | idx2 = matches[i, 1]
94 | # accumulate the pixel distance between the matched point pairs of the two images
95 | dist += math.sqrt(((locations_1_to_use[idx1,0]-locations_2_to_use[idx2,0])**2)+((locations_1_to_use[idx1,1]-locations_2_to_use[idx2,1])**2))
96 | avg_dist = dist/matches.shape[0]
97 | print("avg_dist:", avg_dist)
98 | # coordinate_out_query.append([locations_1_to_use[idx1, 0], locations_1_to_use[idx1, 1]])
99 | # coordinate_out_ref.append([locations_2_to_use[idx2, 0], locations_2_to_use[idx2, 1]])
100 |
101 | # optionally write the matched point coordinates of both images to a file
102 | # if not exists(coordinate_out_dir):
103 | # os.mkdir(coordinate_out_dir)
104 | # out_file = join(coordinate_out_dir, 'coordinate.txt')
105 | # print('Writing recalls to', out_file)
106 | # with open(out_file, 'a+') as res_out:
107 | # res_out.write(imgfile1 + '\n' + str(coordinate_out_query) + '\n')
108 | # res_out.write(imgfile2 + '\n' + str(coordinate_out_ref) + '\n')
109 |
110 | # Visualize correspondences, and save to file.
111 | # 1. draw the match lines
112 | plt.rcParams['savefig.dpi'] = 500  # DPI of the saved figure
113 | plt.rcParams['figure.dpi'] = 500  # on-screen figure resolution
114 | plt.rcParams['figure.figsize'] = (3.0, 2.0)  # figure size in inches
115 | _, ax = plt.subplots()
116 | plotmatch.plot_matches(
117 | ax,
118 | image1_o,
119 | image2_o,
120 | locations_1_to_use,
121 | locations_2_to_use,
122 | np.column_stack((inlier_idxs, inlier_idxs)),
123 | plot_matche_points = False,
124 | matchline = True,
125 | matchlinewidth = 0.1)
126 | ax.axis('off')
127 | ax.set_title('')
128 | # plt.show()
129 | plt.savefig('/home/a409/users/huboni/Paper/locally_global_match_pnp/fig/fig-matching/0.png')
--------------------------------------------------------------------------------
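The script above filters the FLANN matches with an adaptive distance-gap test (the average gap between best and second-best descriptor distances) combined with a mutual nearest-neighbour check before running RANSAC. A minimal, self-contained sketch of that filtering rule on random descriptors (illustration only, not repository data or the repository's own code):

import cv2
import numpy as np

rng = np.random.default_rng(0)
des_left = rng.random((200, 128), dtype=np.float32)   # stand-ins for CNN descriptors
des_right = rng.random((220, 128), dtype=np.float32)

flann = cv2.FlannBasedMatcher(dict(algorithm=1, trees=5), dict(checks=40))
matches = flann.knnMatch(des_left, des_right, k=2)
matches_reverse = flann.knnMatch(des_right, des_left, k=2)

# adaptive threshold: average gap between best and second-best distances
disdif_avg = sum(n.distance - m.distance for m, n in matches) / len(matches)

good = [m for m, n in matches
        if n.distance > m.distance + disdif_avg
        and matches_reverse[m.trainIdx][1].distance
        > matches_reverse[m.trainIdx][0].distance + disdif_avg
        and matches_reverse[m.trainIdx][0].trainIdx == m.queryIdx]  # mutual nearest neighbour
print('kept %d of %d matches' % (len(good), len(matches)))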
/performance.ini:
--------------------------------------------------------------------------------
1 | [path]
2 | input_predictions = '/home/a409/users/huboni/Projects/code/Patch-NetVLAD/patchnetvlad/results/mavic_npu_2048_7_dist20/NetVLAD_predictions.txt'
3 | ground_truth = '/home/a409/users/huboni/Projects/code/Patch-NetVLAD/patchnetvlad/dataset_gt_files/mavic_npu_dist20.npz'
4 | reference_csv = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/reference.csv'
5 | out_dir = '/home/a409/users/huboni/Projects/code/cnn-matching/result/TerraTrack_rs/'
6 |
7 | [feature_match]
8 | n_values_all = 1,5,10,20,50
9 |
--------------------------------------------------------------------------------
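For reference, one way this INI could be consumed from Python with the standard library. Note that configparser returns the quoted values verbatim, so the surrounding single quotes have to be stripped by the caller; the stripping below is an assumption about intended use, not code taken from this repository:

import configparser

config = configparser.ConfigParser()
config.read('performance.ini')

# configparser keeps the surrounding quotes, so strip them explicitly
paths = {key: value.strip("'") for key, value in config['path'].items()}
n_values = [int(n) for n in config['feature_match']['n_values_all'].split(',')]
print(paths['out_dir'], n_values)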
/plotmatch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def plot_matches(ax, image1, image2, keypoints1, keypoints2, matches,
5 | keypoints_color='r', matches_color=None, plot_matche_points=True, matchline = True, matchlinewidth = 0.5,
6 | alignment='horizontal'):
7 | """Plot matched features.
8 |
9 | Parameters
10 | ----------
11 | ax : matplotlib.axes.Axes
12 | Matches and image are drawn in this ax.
13 | image1 : (N, M [, 3]) array
14 | First grayscale or color image.
15 | image2 : (N, M [, 3]) array
16 | Second grayscale or color image.
17 | keypoints1 : (K1, 2) array
18 | First keypoint coordinates as ``(x, y)`` pixel positions.
19 | keypoints2 : (K2, 2) array
20 | Second keypoint coordinates as ``(x, y)`` pixel positions.
21 | matches : (Q, 2) array
22 | Indices of corresponding matches in first and second set of
23 | descriptors, where ``matches[:, 0]`` denote the indices in the first
24 | and ``matches[:, 1]`` the indices in the second set of descriptors.
25 | keypoints_color : matplotlib color, optional
26 | Color for keypoint locations.
27 | matches_color : matplotlib color, optional
28 | Color for lines which connect keypoint matches. By default the
29 | color is chosen randomly.
30 | plot_matche_points : bool, optional
31 | Whether to also draw the keypoint locations, not only the match lines.
32 | alignment : {'horizontal', 'vertical'}, optional
33 | Whether to show images side by side, ``'horizontal'``, or one above
34 | the other, ``'vertical'``.
35 |
36 | """
37 |
38 | #image1 = img_as_float(image1)
39 | #image2 = img_as_float(image2)
40 |
41 | new_shape1 = list(image1.shape)
42 | new_shape2 = list(image2.shape)
43 |
44 | if image1.shape[0] < image2.shape[0]:
45 | new_shape1[0] = image2.shape[0]
46 | elif image1.shape[0] > image2.shape[0]:
47 | new_shape2[0] = image1.shape[0]
48 |
49 | if image1.shape[1] < image2.shape[1]:
50 | new_shape1[1] = image2.shape[1]
51 | elif image1.shape[1] > image2.shape[1]:
52 | new_shape2[1] = image1.shape[1]
53 |
54 | if tuple(new_shape1) != image1.shape:
55 | #new_image1 = np.zeros(new_shape1, dtype=image1.dtype)
56 | new_image1 = np.full(new_shape1, 255)
57 | new_image1[:image1.shape[0], :image1.shape[1]] = image1
58 | image1 = new_image1
59 |
60 | if tuple(new_shape2) != image2.shape:
61 | #new_image2 = np.zeros(new_shape2, dtype=image2.dtype)
62 | new_image2 = np.full(new_shape2, 255)
63 | new_image2[:image2.shape[0], :image2.shape[1]] = image2
64 | image2 = new_image2
65 |
66 | offset = np.array(image1.shape)
67 | if alignment == 'horizontal':
68 | if image2.ndim == 3:
69 | blank = np.full((new_shape2[0], 10, 3), 255)
70 | if image1.ndim == 2 or image2.ndim == 2:
71 | blank = np.full((new_shape2[0], 10), 255)
72 | image = np.concatenate([image1, blank, image2], axis=1)
73 | offset[0] = 0
74 | offset[1] += 10
75 | elif alignment == 'vertical':
76 | if image2.ndim == 3 :
77 | blank = np.full((10, new_shape2[1], 3), 255)
78 | if image1.ndim == 2:
79 | blank = np.full((10, new_shape2[1]), 255)
80 | image = np.concatenate([image1, blank, image2], axis=0)
81 | offset[1] = 0
82 | offset[0] += 10
83 | else:
84 | mesg = ("plot_matches accepts either 'horizontal' or 'vertical' for "
85 | "alignment, but '{}' was given. See "
86 | "https://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.plot_matches " # noqa
87 | "for details.").format(alignment)
88 | raise ValueError(mesg)
89 |
90 |
91 | if plot_matche_points:
92 | ax.scatter(keypoints1[:, 0], keypoints1[:, 1],
93 | facecolors='none', edgecolors=keypoints_color, marker = '.')
94 | ax.scatter(keypoints2[:, 0] + offset[1], keypoints2[:, 1] + offset[0],
95 | facecolors='none', edgecolors=keypoints_color, marker = '.')
96 |
97 | ax.imshow(image, interpolation='nearest', cmap='gray')
98 | ax.axis((0, image1.shape[1] + offset[1], image1.shape[0] + offset[0], 0))
99 |
100 | if matchline == True:
101 | for i in range(matches.shape[0]):
102 | idx1 = matches[i, 0]
103 | idx2 = matches[i, 1]
104 |
105 | if matches_color is None:
106 | color = np.random.rand(3)
107 | else:
108 | color = matches_color
109 |
110 | ax.plot((keypoints1[idx1, 0], keypoints2[idx2, 0] + offset[1]),
111 | (keypoints1[idx1, 1], keypoints2[idx2, 1] + offset[0]),
112 | '-', color=color, linewidth=matchlinewidth, marker='+', markersize=4)
113 |
114 |
115 | def plot_matches2(ax, image1, image2, keypoints1, keypoints2,
116 | keypoints_color='r', matches_color=None, plot_matche_points=True, matchline = True, matchlinewidth = 0.5,
117 | alignment='horizontal'):
118 |
119 |
120 | new_shape1 = list(image1.shape)
121 | new_shape2 = list(image2.shape)
122 |
123 | if image1.shape[0] < image2.shape[0]:
124 | new_shape1[0] = image2.shape[0]
125 | elif image1.shape[0] > image2.shape[0]:
126 | new_shape2[0] = image1.shape[0]
127 |
128 | if image1.shape[1] < image2.shape[1]:
129 | new_shape1[1] = image2.shape[1]
130 | elif image1.shape[1] > image2.shape[1]:
131 | new_shape2[1] = image1.shape[1]
132 |
133 | if tuple(new_shape1) != image1.shape:
134 | #new_image1 = np.zeros(new_shape1, dtype=image1.dtype)
135 | new_image1 = np.full(new_shape1, 255)
136 | new_image1[:image1.shape[0], :image1.shape[1]] = image1
137 | image1 = new_image1
138 |
139 | if tuple(new_shape2) != image2.shape:
140 | #new_image2 = np.zeros(new_shape2, dtype=image2.dtype)
141 | new_image2 = np.full(new_shape2, 255)
142 | new_image2[:image2.shape[0], :image2.shape[1]] = image2
143 | image2 = new_image2
144 |
145 | offset = np.array(image1.shape)
146 | if alignment == 'horizontal':
147 | if image2.ndim == 3:
148 | blank = np.full((new_shape2[0], 10, 3), 255)
149 | if image1.ndim == 2 or image2.ndim == 2:
150 | blank = np.full((new_shape2[0], 10), 255)
151 | image = np.concatenate([image1, blank, image2], axis=1)
152 | offset[0] = 0
153 | offset[1] += 10
154 | elif alignment == 'vertical':
155 | if image2.ndim == 3 :
156 | blank = np.full((10, new_shape2[1], 3), 255)
157 | if image1.ndim == 2:
158 | blank = np.full((10, new_shape2[1]), 255)
159 | image = np.concatenate([image1, blank, image2], axis=0)
160 | offset[1] = 0
161 | offset[0] += 10
162 | else:
163 | mesg = ("plot_matches accepts either 'horizontal' or 'vertical' for "
164 | "alignment, but '{}' was given. See "
165 | "https://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.plot_matches " # noqa
166 | "for details.").format(alignment)
167 | raise ValueError(mesg)
168 |
169 |
170 | if plot_matche_points:
171 | ax.scatter(keypoints1[:, 0], keypoints1[:, 1],
172 | facecolors='none', edgecolors=keypoints_color, marker = '.')
173 | ax.scatter(keypoints2[:, 0] + offset[1], keypoints2[:, 1] + offset[0],
174 | facecolors='none', edgecolors=keypoints_color, marker = '.')
175 |
176 | ax.imshow(image, interpolation='nearest', cmap='gray')
177 | ax.axis((0, image1.shape[1] + offset[1], image1.shape[0] + offset[0], 0))
178 |
179 | if matchline == True:
180 | for i in range(keypoints1.shape[0]):
181 |
182 | if matches_color is None:
183 | color = np.random.rand(3)
184 | else:
185 | color = matches_color
186 |
187 | ax.plot((keypoints1[i, 0], keypoints2[i, 0] + offset[1]),
188 | (keypoints1[i, 1], keypoints2[i, 1] + offset[0]),
189 | '-', color=color, linewidth=matchlinewidth, marker='+', markersize=2)
190 |
191 |
192 |
193 |
--------------------------------------------------------------------------------
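A small, self-contained usage sketch for plot_matches with synthetic data (illustration only): keypoints are (x, y) pixel coordinates and matches pairs row indices of the two keypoint arrays, as in the matching script earlier.

import numpy as np
import matplotlib.pyplot as plt
import plotmatch

rng = np.random.default_rng(0)
img1 = (rng.random((120, 160)) * 255).astype(np.uint8)  # toy grayscale images
img2 = (rng.random((120, 160)) * 255).astype(np.uint8)
kps1 = np.array([[20.0, 30.0], [80.0, 50.0], [140.0, 100.0]])  # (x, y) coordinates
kps2 = kps1 + 5.0  # pretend the second image is slightly shifted
matches = np.column_stack((np.arange(3), np.arange(3)))  # index pairs

_, ax = plt.subplots()
plotmatch.plot_matches(ax, img1, img2, kps1, kps2, matches,
                       plot_matche_points=False, matchline=True,
                       matchlinewidth=0.5)
ax.axis('off')
plt.savefig('plot_matches_demo.png')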
/result/cv-match.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from matplotlib import pyplot as plt
4 |
5 | imgname1 = '/mnt/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/query_images/000000.png'
6 | imgname2 = '/mnt/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/reference_images/offset_0_None/000000.png'
7 |
8 | sift = cv2.SIFT_create()
9 |
10 | # FLANN parameter setup
11 | FLANN_INDEX_KDTREE = 0
12 | index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5)
13 | search_params = dict(checks=50)
14 | flann = cv2.FlannBasedMatcher(index_params,search_params)
15 |
16 | img1 = cv2.imread(imgname1)
17 | gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)  # convert to grayscale
18 | kp1, des1 = sift.detectAndCompute(img1,None)  # des1 holds the SIFT descriptors
19 |
20 | img2 = cv2.imread(imgname2)
21 | gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
22 | kp2, des2 = sift.detectAndCompute(img2,None)
23 |
24 | hmerge = np.hstack((gray1, gray2))  # horizontal concatenation
25 | cv2.imshow("./gray.png", hmerge)  # show the two grayscale images side by side
26 | cv2.waitKey(0)
27 |
28 | img3 = cv2.drawKeypoints(img1,kp1,img1,color=(255,0,255))
29 | img4 = cv2.drawKeypoints(img2,kp2,img2,color=(255,0,255))
30 |
31 | hmerge = np.hstack((img3, img4))  # horizontal concatenation
32 | cv2.imshow("./point.png", hmerge)  # show the keypoint visualizations side by side
33 | cv2.waitKey(0)
34 | matches = flann.knnMatch(des1,des2,k=2)
35 | matchesMask = [[0,0] for i in range(len(matches))]
36 |
37 | good = []
38 | for m,n in matches:
39 | if m.distance < 0.1*n.distance:
40 | good.append([m])
41 |
42 | img5 = cv2.drawMatchesKnn(img1,kp1,img2,kp2,good,None,flags=2)  # draw only the ratio-test matches
43 | cv2.imshow("./FLANN.png", img5)
44 | cv2.waitKey(0)
45 | cv2.destroyAllWindows()
--------------------------------------------------------------------------------
/result/plot_rs.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 |
5 |
6 |
7 | x_axis_data = [1,5,10,20,50,100,200]
8 | y_dist20_patch = [0.2043,0.3290,0.3545,0.3622,0.3697,0.3355,0.3314]
9 | y_dist20_swin = [0.2043,0.5629,0.5806,0.5987,0.6259,0.6021,0.5903]
10 | y_dist50_patch = [0.6200,0.8023,0.8266,0.8468,0.8504,0.8456,0.8325]
11 | y_dist50_swin = [0.6200,0.8634,0.9030,0.9231,0.9448,0.9293,0.9086]
12 |
13 |
14 | # plot the results
15 |
16 | plt.plot(x_axis_data, y_dist20_patch, 'go--', alpha=0.5, linewidth=1, label='Patch-NetVLAD-2048 dist20')
17 | plt.plot(x_axis_data, y_dist20_swin, 'ms-', alpha=0.5, linewidth=1, label='CurriculumLoc(ours) dist20')
18 | plt.plot(x_axis_data, y_dist50_patch, 'bo--', alpha=0.5, linewidth=1, label='Patch-NetVLAD-2048 dist50')
19 | plt.plot(x_axis_data, y_dist50_swin, 'rs-', alpha=0.5, linewidth=1, label='CurriculumLoc(ours) dist50')
20 | plt.xticks([1,5,10,20,50,100,200])
21 | plt.yticks([0,0.2,0.4,0.6,0.8,1.0])
22 |
23 | for a, b in zip(x_axis_data, y_dist20_patch):
24 | plt.text(a, b, str(b), ha='center', va='bottom', fontsize=8) # ha='center', va='top'
25 | for a, b1 in zip(x_axis_data, y_dist20_swin):
26 | plt.text(a, b1, str(b1), ha='center', va='bottom', fontsize=8)
27 | for a, b2 in zip(x_axis_data, y_dist50_patch):
28 | plt.text(a, b2, str(b2), ha='center', va='bottom', fontsize=8)
29 | for a, b3 in zip(x_axis_data, y_dist50_swin):
30 | plt.text(a, b3, str(b3), ha='center', va='bottom', fontsize=8)
31 |
32 |
33 | plt.legend()  # show the labels defined above
34 | plt.xlabel('Candidates number')
35 | plt.ylabel('rerank R@1')#accuracy
36 |
37 |
38 | plt.savefig("./plot_rs.png")
39 | plt.show()
40 |
41 |
--------------------------------------------------------------------------------
/terratrack_utils/compute_mean_std.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import matplotlib.pyplot as plt
4 | from matplotlib.pyplot import imread
5 | import numpy as np
6 | # from scipy.misc import imread
7 | filepath = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/query_images_512'  # dataset directory
8 | pathDir = os.listdir(filepath)
9 | R_channel = 0
10 | G_channel = 0
11 | B_channel = 0
12 | for idx in range(len(pathDir)):
13 | filename = pathDir[idx]
14 | img = imread(os.path.join(filepath, filename)) / 255.0
15 | R_channel = R_channel + np.sum(img[:, :, 0])
16 | G_channel = G_channel + np.sum(img[:, :, 1])
17 | B_channel = B_channel + np.sum(img[:, :, 2])
18 | num = len(pathDir) * 512 * 512  # (512, 512) is the size of every image; all images share the same size
19 | R_mean = R_channel / num
20 | G_mean = G_channel / num
21 | B_mean = B_channel / num
22 | R_channel = 0
23 | G_channel = 0
24 | B_channel = 0
25 | for idx in range(len(pathDir)):
26 | filename = pathDir[idx]
27 | img = imread(os.path.join(filepath, filename)) / 255.0
28 | R_channel = R_channel + np.sum((img[:, :, 0] - R_mean) ** 2)
29 | G_channel = G_channel + np.sum((img[:, :, 1] - G_mean) ** 2)
30 | B_channel = B_channel + np.sum((img[:, :, 2] - B_mean) ** 2)
31 | R_var = np.sqrt(R_channel / num)  # per-channel standard deviation (despite the "var" name)
32 | G_var = np.sqrt(G_channel / num)
33 | B_var = np.sqrt(B_channel / num)
34 | print("R_mean is %f, G_mean is %f, B_mean is %f" % (R_mean, G_mean, B_mean))
35 | print("R_var is %f, G_var is %f, B_var is %f" % (R_var, G_var, B_var))
36 |
--------------------------------------------------------------------------------
/terratrack_utils/depth_bin2array.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import warnings
4 | import os
5 | import h5py
6 | import argparse
7 |
8 |
9 | '''
10 | convert colmap depth map from .bin to .png or to .h5
11 | '''
12 | warnings.filterwarnings('ignore') # suppress the warnings raised when comparing nan with min_depth
13 |
14 | # camnum = 12
15 | # fB = 32504;
16 | min_depth_percentile = 2
17 | max_depth_percentile = 98
18 |
19 | parser = argparse.ArgumentParser(description='convert depth bin to array and save h5')
20 |
21 | parser.add_argument(
22 | '--depthmapsdir', type=str,
23 | help='path to the origin colmap output depth',
24 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/query_reference_SfM_model_500/dense/stereo/depth_maps'
25 | )
26 | parser.add_argument(
27 | '--output_h5_path', type=str,
28 | help='path to the save h5 depth',
29 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/query_reference_SfM_model_500/dense/stereo/depth_bin_h5'
30 | )
31 | args = parser.parse_args()
32 |
33 |
34 | if not os.path.exists(args.output_h5_path):
35 | os.mkdir(args.output_h5_path)
36 |
37 | def read_array(path):
38 | with open(path, "rb") as fid:
39 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
40 | usecols=(0, 1, 2), dtype=int)
41 | # print("width:", width)
42 | # print("height:", height)
43 | # print("channels:", channels)
44 | fid.seek(0)
45 | num_delimiter = 0
46 | byte = fid.read(1)
47 | while True:
48 | if byte == b"&":
49 | num_delimiter += 1
50 | if num_delimiter >= 3:
51 | break
52 | byte = fid.read(1)
53 | array = np.fromfile(fid, np.float32)
54 | # print(array.shape)
55 | array = array.reshape((width, height, channels), order="F")
56 | return np.transpose(array, (1, 0, 2)).squeeze()
57 |
58 | def bin2depth(inputdepth, output_h5_path):
59 | # depth_map = '0.png.geometric.bin'
60 | # print(depthdir)
61 | # if min_depth_percentile > max_depth_percentile:
62 | # raise ValueError("min_depth_percentile should be less than or equal "
63 | # "to the max_depth_perceintile.")
64 |
65 | # Read depth and normal maps corresponding to the same image.
66 | if not os.path.exists(inputdepth):
67 | raise FileNotFoundError("file not found: {}".format(inputdepth))
68 |
69 | # np.set_printoptions(threshold=np.inf)
70 |
71 | depth_map = read_array(inputdepth)
72 | # depth_map[depth_map<=0] = 0
73 | min_depth, max_depth = np.percentile(depth_map[depth_map>0], [min_depth_percentile, max_depth_percentile])
74 | depth_map[depth_map <= 0] = np.nan # set zeros and negatives to nan so they are not replaced by min_depth
75 | depth_map[depth_map < min_depth] = min_depth
76 | depth_map[depth_map > max_depth] = max_depth
77 | # depth_map[depth_map<=0] = 0
78 | depth_map = np.nan_to_num(depth_map)
79 | # print("input_bin_path:", inputdepth)
80 | depth_map_shape = depth_map.shape[0]*depth_map.shape[1]
81 | if depth_map_shape<100:
82 | print("depth map is unexpectedly small (fewer than 100 pixels):", depth_map_shape)
83 | print("depth name:", inputdepth)
84 | print(np.any(depth_map<0)) # check for negatives; the depth map may contain zeros, with values between 0 and 1
85 |
86 | # save depth as h5 FIXME
87 | h5_path = os.path.join(output_h5_path, '.'.join(os.path.basename(inputdepth).split('.')[:-1])+'.h5')
88 | with h5py.File(h5_path, 'w') as f:
89 | f.create_dataset('depth', data=depth_map)
90 |
91 | # bin 2 png
92 | # min_depth, max_depth = np.percentile(depth_map[depth_map>0], [min_depth_percentile, max_depth_percentile])
93 | # depth_map[depth_map <= 0] = np.nan # set zeros and negatives to nan so they are not replaced by min_depth
94 | # depth_map[depth_map < min_depth] = min_depth
95 | # depth_map[depth_map > max_depth] = max_depth
96 |
97 | # maxdisp = fB / min_depth;
98 | # mindisp = fB / max_depth;
99 | # depth_map = (fB/depth_map - mindisp) * 255 / (maxdisp - mindisp);
100 | # depth_map = np.nan_to_num(depth_map*255) # all nan values become 0
101 | # depth_map = depth_map.astype(int)
102 |
103 | # image = Image.fromarray(np.uint8(depth_map)).convert('L')
104 | # image = image.resize((500, 500), Image.ANTIALIAS) # ensure the output is resized to 500*500
105 | # print(image)
106 | # print(np.array(image))
107 | # ouputdepth = os.path.join(outputdir, '.'.join(os.path.basename(inputdepth).split('.')[:-1])+'.png')
108 | # print(ouputdepth)
109 | # image.save(ouputdepth)
110 |
111 |
112 | for depthbin in os.listdir(args.depthmapsdir):
113 | inputdepth = os.path.join(args.depthmapsdir, depthbin)
114 | if os.path.exists(inputdepth):
115 | bin2depth(inputdepth, args.output_h5_path)
--------------------------------------------------------------------------------
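read_array above assumes COLMAP's depth-map .bin layout: an ASCII header "width&height&channels&" followed by float32 values stored so that a Fortran-order reshape plus a transpose recovers the (height, width) map. A hedged round-trip sketch of that assumed layout on toy data, with the reader re-implemented inline so the script's argparse defaults are not touched:

import os
import tempfile
import numpy as np

depth = np.arange(12, dtype=np.float32).reshape(3, 4)  # toy (h=3, w=4) depth map

path = os.path.join(tempfile.mkdtemp(), 'toy.geometric.bin')
with open(path, 'wb') as f:
    f.write(b'%d&%d&%d&' % (depth.shape[1], depth.shape[0], 1))  # "w&h&c&" header
    f.write(depth.tobytes())  # row-major (h, w) payload

with open(path, 'rb') as f:
    header = b''
    while header.count(b'&') < 3:          # read the ASCII header byte by byte
        header += f.read(1)
    w, h, c = (int(v) for v in header[:-1].split(b'&'))
    data = np.fromfile(f, np.float32).reshape((w, h, c), order='F')
recovered = np.transpose(data, (1, 0, 2)).squeeze()
assert np.allclose(recovered, depth)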
/terratrack_utils/depth_png2h5.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import os
4 | from PIL import Image
5 |
6 | # convert depth from png to h5
7 | depth_png_path = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/reference_SfM_model_500/dense/stereo/depth_maps_png'
8 | output_h5_path = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/reference_SfM_model_500/dense/stereo/depth_bin_h5'
9 | if not os.path.exists(output_h5_path):
10 | os.mkdir(output_h5_path)
11 |
12 | for png_name in os.listdir(depth_png_path):
13 | print("png_name:", png_name)
14 | depth_png = Image.open(os.path.join(depth_png_path, png_name))
15 | print(depth_png)
16 | png_array = np.array(depth_png)
17 | print(png_array)
18 |
19 | # Save depth map as HDF5 file
20 | h5_path = os.path.join(output_h5_path, '.'.join(png_name.split('.')[:3])+'.h5')
21 | with h5py.File(h5_path, 'w') as f:
22 | f.create_dataset('depth', data=png_array)
23 |
24 | ## Read h5 depth
25 | # for h5_file in os.listdir(output_h5_path):
26 | # print("h5 name:", h5_file)
27 | # depth_h5 = h5py.File(os.path.join(output_h5_path, h5_file), 'r')
28 | # depth = depth_h5['depth']
29 | # print("h5_depth:",depth)
30 | # depth_array = np.array(depth)
31 | # print("depth_array:", depth_array)
32 | # depth_h5.close()
--------------------------------------------------------------------------------
/terratrack_utils/preprocess_terra.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import imagesize
4 |
5 | import numpy as np
6 |
7 | import os
8 |
9 | parser = argparse.ArgumentParser(description='MegaDepth preprocessing script')
10 |
11 | parser.add_argument(
12 | '--base_path', type=str,
13 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu',
14 | help='path to mavic_npu'
15 | )
16 |
17 | parser.add_argument(
18 | '--output_path', type=str,
19 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/process_output_query_ref_500',
20 | help='path to the output directory'
21 | )
22 |
23 |
24 | args = parser.parse_args()
25 |
26 | if not os.path.exists(args.output_path):
27 | os.mkdir(args.output_path)
28 |
29 | base_path = args.base_path
30 |
31 | # megadepth TODO
32 | base_depth_path = os.path.join(
33 | base_path, 'query_reference_SfM_model_500/dense/stereo'
34 | )
35 |
36 | base_undistorted_sfm_path = os.path.join(
37 | base_path, 'query_reference_Undistorted_SfM_500'
38 | )
39 |
40 | undistorted_sparse_path = os.path.join(
41 | base_undistorted_sfm_path, 'sparse-txt'
42 | )
43 |
44 | depths_path = os.path.join(
45 | base_depth_path, 'depth_bin_h5'
46 | )
47 |
48 |
49 | images_path = os.path.join(
50 | base_undistorted_sfm_path, 'images'
51 | )
52 |
53 |
54 | # Process cameras.txt
55 | with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f:
56 | raw = f.readlines()[3 :] # skip the header
57 |
58 | camera_intrinsics = {}
59 | for camera in raw:
60 | camera = camera.split(' ')
61 | camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]]
62 |
63 | # Process points3D.txt
64 | with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f:
65 | raw = f.readlines()[3 :] # skip the header
66 |
67 | points3D = {} # dict of 3D points: {point id: np.array([x, y, z])}
68 | for point3D in raw:
69 | point3D = point3D.split(' ')
70 | points3D[int(point3D[0])] = np.array([
71 | float(point3D[1]), float(point3D[2]), float(point3D[3])
72 | ])
73 |
74 | # Process images.txt
75 | with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f:
76 | raw = f.readlines()[4 :] # skip the header
77 |
78 | image_id_to_idx = {}
79 | image_names = []
80 | raw_pose = []
81 | camera = []
82 | points3D_id_to_2D = []
83 | n_points3D = []
84 | for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])):
85 | image = image.split(' ')
86 | points = points.split(' ')
87 |
88 | image_id_to_idx[int(image[0])] = idx
89 |
90 | image_name = image[-1].strip('\n')
91 | image_names.append(image_name)
92 | # print("image_name:", image_name)
93 |
94 | raw_pose.append([float(elem) for elem in image[1 : -2]]) # image pose[ QW, QX, QY, QZ, TX, TY, TZ]
95 | camera.append(int(image[-2]))
96 | current_points3D_id_to_2D = {}
97 | for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]):
98 | if int(point3D_id) == -1: # skip id -1: this 2D observation has no corresponding 3D point
99 | continue
100 | current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)]
101 | points3D_id_to_2D.append(current_points3D_id_to_2D) # 3D point id -> 2D coordinates in this image
102 | n_points3D.append(len(current_points3D_id_to_2D))
103 | n_images = len(image_names)
104 |
105 | # Image and depthmaps paths
106 | image_paths = []
107 | depth_paths = []
108 | for image_name in image_names:
109 | image_path = os.path.join(images_path, image_name)
110 |
111 | # Path to the depth file
112 | depth_path = os.path.join(
113 | depths_path, image_name + '.geometric.h5'
114 | )
115 | # print("depth_path:", depth_path)
116 |
117 | if os.path.exists(depth_path):
118 | # Check if depth map or background / foreground mask
119 | depth_paths.append(depth_path)
120 | print("depth path:", depth_path)
121 | image_paths.append(image_path)
122 | print("image path:", image_path)
123 | else:
124 | depth_paths.append(None)
125 | image_paths.append(None)
126 |
127 | # Camera configuration
128 | intrinsics = []
129 | poses = []
130 | principal_axis = []
131 | points3D_id_to_ndepth = []
132 | for idx, image_name in enumerate(image_names):
133 | if image_paths[idx] is None:
134 | intrinsics.append(None)
135 | poses.append(None)
136 | principal_axis.append([0, 0, 0])
137 | points3D_id_to_ndepth.append({})
138 | continue
139 | image_intrinsics = camera_intrinsics[camera[idx]]
140 | K = np.zeros([3, 3])
141 | K[0, 0] = image_intrinsics[2]
142 | K[0, 2] = image_intrinsics[4]
143 | K[1, 1] = image_intrinsics[3]
144 | K[1, 2] = image_intrinsics[5]
145 | K[2, 2] = 1
146 | intrinsics.append(K)
147 |
148 | image_pose = raw_pose[idx]
149 | qvec = image_pose[: 4]
150 | qvec = qvec / np.linalg.norm(qvec)
151 | w, x, y, z = qvec
152 | R = np.array([
153 | [
154 | 1 - 2 * y * y - 2 * z * z,
155 | 2 * x * y - 2 * z * w,
156 | 2 * x * z + 2 * y * w
157 | ],
158 | [
159 | 2 * x * y + 2 * z * w,
160 | 1 - 2 * x * x - 2 * z * z,
161 | 2 * y * z - 2 * x * w
162 | ],
163 | [
164 | 2 * x * z - 2 * y * w,
165 | 2 * y * z + 2 * x * w,
166 | 1 - 2 * x * x - 2 * y * y
167 | ]
168 | ])
169 | principal_axis.append(R[2, :])
170 | t = image_pose[4 : 7]
171 | # World-to-Camera pose
172 | current_pose = np.zeros([4, 4])
173 | current_pose[: 3, : 3] = R
174 | current_pose[: 3, 3] = t
175 | current_pose[3, 3] = 1
176 | # Camera-to-World pose
177 | # pose = np.zeros([4, 4])
178 | # pose[: 3, : 3] = np.transpose(R)
179 | # pose[: 3, 3] = -np.matmul(np.transpose(R), t)
180 | # pose[3, 3] = 1
181 | poses.append(current_pose)
182 |
183 | current_points3D_id_to_ndepth = {}
184 | for point3D_id in points3D_id_to_2D[idx].keys():
185 | p3d = points3D[point3D_id]
186 | current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1]))
187 | points3D_id_to_ndepth.append(current_points3D_id_to_ndepth)
188 | principal_axis = np.array(principal_axis)
189 | # pairwise angles between the cameras' principal (viewing) axes
190 | angles = np.rad2deg(np.arccos(
191 | np.clip(
192 | np.dot(principal_axis, np.transpose(principal_axis)),
193 | -1, 1
194 | )
195 | ))
196 |
197 | # Compute overlap score
198 | overlap_matrix = np.full([n_images, n_images], -1.)
199 | scale_ratio_matrix = np.full([n_images, n_images], -1.)
200 | for idx1 in range(n_images):
201 | if image_paths[idx1] is None or depth_paths[idx1] is None:
202 | continue
203 | for idx2 in range(idx1 + 1, n_images):
204 | if image_paths[idx2] is None or depth_paths[idx2] is None:
205 | continue
206 | # FIXME(original): does the order of the two id sets matter when intersecting with &?
207 | # (keys() & keys() is a set intersection, so it does not.) Ids of the 3D points observed in both images:
208 | matches = (
209 | points3D_id_to_2D[idx1].keys() &
210 | points3D_id_to_2D[idx2].keys()
211 | )
212 | min_num_points3D = min(
213 | len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2])
214 | ) # FIXME: candidate denominator for the overlap ratio between the two images; not actually used below
215 | overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1]) # min_num_points3D
216 | overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2]) # min_num_points3D
217 | if len(matches) == 0:
218 | continue
219 | points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1]
220 | points3D_id_to_ndepth2 = points3D_id_to_ndepth[idx2]
221 | nd1 = np.array([points3D_id_to_ndepth1[match] for match in matches])
222 | nd2 = np.array([points3D_id_to_ndepth2[match] for match in matches])
223 | # FIXME: the minimum scale ratio indicates whether the two views observe these points at a similar physical scale, which can be used for subsequent relative pose estimation
224 | min_scale_ratio = np.min(np.maximum(nd1 / nd2, nd2 / nd1))
225 | scale_ratio_matrix[idx1, idx2] = min_scale_ratio
226 | scale_ratio_matrix[idx2, idx1] = min_scale_ratio
227 | print(overlap_matrix)
228 | np.savez(
229 | os.path.join(args.output_path, os.path.basename(base_path)+'.npz'),
230 | image_paths=image_paths,
231 | depth_paths=depth_paths,
232 | intrinsics=intrinsics,
233 | poses=poses,
234 | overlap_matrix=overlap_matrix,
235 | scale_ratio_matrix=scale_ratio_matrix,
236 | angles=angles,
237 | n_points3D=n_points3D,
238 | points3D_id_to_2D=points3D_id_to_2D,
239 | points3D_id_to_ndepth=points3D_id_to_ndepth
240 | )
241 |
--------------------------------------------------------------------------------
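The overlap score computed above is the fraction of one image's observed 3D point ids that are also observed by the other image (the commented-out min_num_points3D denominator is not used). A toy worked example with made-up point ids:

# Toy example of the overlap score: shared 3D point ids over the ids seen in each image.
points3D_id_to_2D = [
    {1: [10.0, 12.0], 2: [40.0, 8.0], 3: [77.0, 30.0], 4: [5.0, 90.0]},  # image 0
    {2: [41.0, 9.0], 3: [75.0, 33.0], 9: [60.0, 60.0]},                  # image 1
]
matches = points3D_id_to_2D[0].keys() & points3D_id_to_2D[1].keys()  # {2, 3}
overlap_0_1 = len(matches) / len(points3D_id_to_2D[0])  # 2 / 4 = 0.5
overlap_1_0 = len(matches) / len(points3D_id_to_2D[1])  # 2 / 3
print(overlap_0_1, overlap_1_0)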
/terratrack_utils/train_scenes_500.txt:
--------------------------------------------------------------------------------
1 | mavic_npu
2 |
--------------------------------------------------------------------------------
/terratrack_utils/train_scenes_origin.txt:
--------------------------------------------------------------------------------
1 | mavic-river
2 | mavic-factory
3 | mavic-fengniao
4 | mavic-hongkong
5 | inspire1-rail-kfs
6 | phantom3-grass-kfs
7 | phantom3-centralPark-kfs
8 | phantom3-npu-kfs
9 | phantom3-freeway-kfs
10 | phantom3-village-kfs
11 | gopro-npu-kfs
12 | gopro-saplings-kfs
13 |
--------------------------------------------------------------------------------
/terratrack_utils/undistort_reconstructions_terra.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import imagesize
4 |
5 | import os
6 |
7 | import subprocess
8 |
9 | parser = argparse.ArgumentParser(description='MegaDepth Undistortion')
10 |
11 | parser.add_argument(
12 | '--colmap_path', type=str,
13 | default='/usr/bin',
14 | help='path to colmap executable'
15 | )
16 | parser.add_argument(
17 | '--base_path', type=str,
18 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu',
19 | help='path to mavic_npu'
20 | )
21 |
22 | args = parser.parse_args()
23 |
24 | sfm_path = os.path.join(
25 | args.base_path, 'query_reference_SfM_model_500'
26 | )
27 | base_depth_path = os.path.join(
28 | sfm_path, 'dense'
29 | )
30 | output_path = os.path.join(
31 | args.base_path, 'query_reference_Undistorted_SfM_500'
32 | )
33 |
34 | os.mkdir(output_path)
35 |
36 | image_path = os.path.join(
37 | base_depth_path, 'images'
38 | )
39 |
40 | # Find the maximum image size in scene.
41 | max_image_size = 0
42 | for image_name in os.listdir(image_path):
43 | max_image_size = max(
44 | max_image_size,
45 | max(imagesize.get(os.path.join(image_path, image_name)))
46 | )
47 |
48 | # Undistort the images and update the reconstruction.
49 | subprocess.call([
50 | os.path.join(args.colmap_path, 'colmap'), 'image_undistorter',
51 | '--image_path', os.path.join(args.base_path, 'query_reference_images_500'),
52 | '--input_path', os.path.join(sfm_path, 'sparse', '0'),
53 | '--output_path', output_path,
54 | '--max_image_size', str(max_image_size)
55 | ])
56 |
57 | # Transform the reconstruction to raw text format.
58 | sparse_txt_path = os.path.join(output_path, 'sparse-txt')
59 | os.mkdir(sparse_txt_path)
60 | subprocess.call([
61 | os.path.join(args.colmap_path, 'colmap'), 'model_converter',
62 | '--input_path', os.path.join(output_path, 'sparse'),
63 | '--output_path', sparse_txt_path,
64 | '--output_type', 'TXT'
65 | ])
--------------------------------------------------------------------------------
/terratrack_utils/valid_scenes_500.txt:
--------------------------------------------------------------------------------
1 | mavic_npu
2 |
--------------------------------------------------------------------------------
/terratrack_utils/valid_scenes_origin.txt:
--------------------------------------------------------------------------------
1 | mavic-xjtu
2 | phantom3-huangqi-kfs
3 |
--------------------------------------------------------------------------------
/train_terra.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 |
5 | import os
6 |
7 | import shutil
8 |
9 | import torch
10 | import torch.optim as optim
11 |
12 | from torch.utils.data import DataLoader
13 |
14 | from tqdm import tqdm
15 |
16 | import warnings
17 |
18 | from lib.dataset_terra import TerraDataset
19 | from lib.exceptions import NoGradientError
20 | from lib.loss import loss_function
21 | from lib.full_model.model_swin_unet_d2 import Swin_D2UNet
22 | from lib.full_model.model_unet import U2Net
23 | from torch.utils.tensorboard import SummaryWriter
24 | from datetime import datetime
25 |
26 | # CUDA
27 | use_cuda = torch.cuda.is_available()
28 | device = torch.device("cuda:0" if use_cuda else "cpu")
29 |
30 | # Seed
31 | torch.manual_seed(1)
32 | if use_cuda:
33 | torch.cuda.manual_seed(1)
34 | np.random.seed(1)
35 |
36 | # Argument parsing
37 | parser = argparse.ArgumentParser(description='Training script')
38 |
39 | parser.add_argument(
40 | '--dataset_path', type=str,
41 | help='path to the dataset',
42 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack'
43 | )
44 | parser.add_argument(
45 | '--scene_info_path', type=str,
46 | help='path to the processed scenes',
47 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/process_output_query_ref_500'
48 | )
49 |
50 | parser.add_argument(
51 | '--preprocessing', type=str, default='torch',
52 | help='image preprocessing (caffe or torch)'
53 | )
54 | parser.add_argument(
55 | '--model_file', type=str, default='models/d2_ots.pth',
56 | help='path to the full model'
57 | )
58 |
59 | parser.add_argument(
60 | '--num_epochs', type=int, default=100,
61 | help='number of training epochs'
62 | )
63 | # default = 1e-3
64 | parser.add_argument(
65 | '--lr', type=float, default=1e-3,
66 | help='initial learning rate'
67 | )
68 | parser.add_argument(
69 | '--batch_size', type=int, default=1,
70 | help='batch size'
71 | )
72 | parser.add_argument(
73 | '--num_workers', type=int, default=4,
74 | help='number of workers for data loading'
75 | )
76 |
77 | parser.add_argument(
78 | '--use_validation', dest='use_validation', action='store_true',
79 | help='use the validation split'
80 | )
81 | parser.set_defaults(use_validation=True)
82 |
83 | parser.add_argument(
84 | '--log_interval', type=int, default=2000,
85 | help='loss logging interval'
86 | )
87 | parser.add_argument(
88 | '--plot', dest='plot', action='store_true',
89 | help='plot training pairs'
90 | )
91 | parser.set_defaults(plot=True)
92 |
93 | parser.add_argument(
94 | '--checkpoint_directory', type=str, default='checkpoints',
95 | help='directory for training checkpoints'
96 | )
97 |
98 | parser.add_argument(
99 | '--net', type=str, default='vgg',
100 | help='backbone network: swin or unet'
101 | )
102 | args = parser.parse_args()
103 |
104 | print(args)
105 |
106 | # Create the folders for plotting if need be
107 | if args.plot:
108 | plot_path = 'train_vis_56_true'
109 | if os.path.isdir(plot_path):
110 | print('[Warning] Plotting directory already exists.')
111 | else:
112 | os.mkdir(plot_path)
113 |
114 | if args.net=='swin':
115 | model = Swin_D2UNet(
116 | # model_file=args.model_file,
117 | use_cuda=use_cuda
118 | )
119 | elif args.net=='unet':
120 | model = U2Net(
121 | model_file=args.model_file,
122 | use_cuda=use_cuda
123 | )
124 | else: raise ValueError('--net must be swin or unet')
125 | # Optimizer
126 | optimizer = optim.Adam(
127 | filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
128 | )
129 | # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.00001)
130 |
131 | # Dataset
132 | if args.use_validation:
133 | validation_dataset = TerraDataset(
134 | scene_list_path='terratrack_utils/valid_scenes_500.txt',
135 | scene_info_path=args.scene_info_path,
136 | base_path=args.dataset_path,
137 | train=False,
138 | preprocessing=args.preprocessing,
139 | pairs_per_scene=2
140 | )
141 | validation_dataloader = DataLoader(
142 | validation_dataset,
143 | batch_size=args.batch_size,
144 | num_workers=args.num_workers
145 | )
146 |
147 | training_dataset = TerraDataset(
148 | scene_list_path='terratrack_utils/train_scenes_500.txt',
149 | scene_info_path=args.scene_info_path,
150 | base_path=args.dataset_path,
151 | preprocessing=args.preprocessing
152 | )
153 | training_dataloader = DataLoader(
154 | training_dataset,
155 | batch_size=args.batch_size,
156 | num_workers=args.num_workers
157 | )
158 |
159 |
160 | # Define epoch function
161 | def process_epoch(
162 | epoch_idx,
163 | model, loss_function, optimizer, dataloader, device,
164 | log_file, args, writer, train=True
165 | ):
166 | epoch_losses = []
167 |
168 | torch.set_grad_enabled(train)
169 |
170 | progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
171 | # max_iterations = 20*len(progress_bar) # 20 is epoch
172 | nBatches = len(progress_bar)
173 | for batch_idx, batch in progress_bar:
174 | if train:
175 | optimizer.zero_grad()
176 |
177 | batch['train'] = train
178 | batch['epoch_idx'] = epoch_idx
179 | batch['batch_idx'] = batch_idx
180 | batch['batch_size'] = args.batch_size
181 | batch['preprocessing'] = args.preprocessing
182 | batch['log_interval'] = args.log_interval
183 |
184 | try:
185 | loss = loss_function(model, batch, device, plot=args.plot)
186 | except NoGradientError:
187 | continue
188 |
189 | current_loss = loss.data.cpu().numpy()[0]
190 | if train:
191 | writer.add_scalar('Train/CurrentBatchLoss', current_loss, (epoch_idx - 1) * nBatches + batch_idx)
192 | else:
193 | writer.add_scalar('Valid/CurrentBatchLoss', current_loss, (epoch_idx - 1) * nBatches + batch_idx)
194 |
195 | epoch_losses.append(current_loss)
196 |
197 | progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses)))
198 |
199 | if batch_idx % args.log_interval == 0:
200 | log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % (
201 | 'train' if train else 'valid',
202 | epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses)
203 | ))
204 |
205 | if train:
206 | loss.backward()
207 | optimizer.step()
208 | # lr_ = args.lr * (1.0 - batch_idx*epoch_idx / max_iterations) ** 0.9
209 | # for param_group in optimizer.param_groups:
210 | # param_group['lr'] = lr_
211 |
212 |
213 | log_file.write('[%s] epoch %d - avg_loss: %f\n' % (
214 | 'train' if train else 'valid',
215 | epoch_idx,
216 | np.mean(epoch_losses)
217 | ))
218 | if train:
219 | writer.add_scalar('Train/AvgLoss', np.mean(epoch_losses), epoch_idx)
220 | else:
221 | writer.add_scalar('Valid/AvgLoss', np.mean(epoch_losses), epoch_idx)
222 |
223 | log_file.flush()
224 |
225 | return np.mean(epoch_losses)
226 |
227 | writer = SummaryWriter(log_dir=os.path.join(args.checkpoint_directory, datetime.now().strftime('%b%d_%H-%M-%S')+'_'+args.net))
228 | # Create the checkpoint directory
229 | logdir = writer.file_writer.get_logdir()
230 | save_checkpoint_path = os.path.join(logdir, 'checkpoints')
231 |
232 | if os.path.isdir(save_checkpoint_path):
233 | print('[Warning] Checkpoint directory already exists.')
234 | else:
235 | os.mkdir(save_checkpoint_path)
236 |
237 | log_file = os.path.join(logdir, 'log.txt')
238 | # Open the log file for writing
239 | if os.path.exists(log_file):
240 | print('[Warning] Log file already exists.')
241 | log_file = open(log_file, 'a+')
242 | log_file.write("args:" + str(args))
243 |
244 | # Initialize the history
245 | train_loss_history = []
246 | validation_loss_history = []
247 | if args.use_validation:
248 | validation_dataset.build_dataset()
249 | min_validation_loss = process_epoch(
250 | 0,
251 | model, loss_function, optimizer, validation_dataloader, device,
252 | log_file, args, writer,
253 | train=False
254 | )
255 |
256 | # Start the training
257 | for epoch_idx in range(1, args.num_epochs + 1):
258 | # Process epoch
259 | print("epoch :", epoch_idx)
260 | training_dataset.build_dataset()
261 | train_loss_history.append(
262 | process_epoch(
263 | epoch_idx,
264 | model, loss_function, optimizer, training_dataloader, device,
265 | log_file, args, writer
266 | )
267 | )
268 |
269 | if args.use_validation:
270 | validation_loss_history.append(
271 | process_epoch(
272 | epoch_idx,
273 | model, loss_function, optimizer, validation_dataloader, device,
274 | log_file, args, writer,
275 | train=False
276 | )
277 | )
278 |
279 | # Save the current checkpoint
280 | checkpoint_path = os.path.join(
281 | save_checkpoint_path,
282 | 'checkpoint.pth'
283 | )
284 | checkpoint = {
285 | 'args': args,
286 | 'epoch_idx': epoch_idx,
287 | 'model': model.state_dict(),
288 | 'optimizer': optimizer.state_dict(),
289 | 'train_loss_history': train_loss_history,
290 | 'validation_loss_history': validation_loss_history
291 | }
292 | torch.save(checkpoint, checkpoint_path)
293 |
294 | if (
295 | args.use_validation and
296 | validation_loss_history[-1] < min_validation_loss
297 | ):
298 | min_validation_loss = validation_loss_history[-1]
299 | best_checkpoint_path = os.path.join(
300 | save_checkpoint_path,
301 | 'best.pth' )
302 | shutil.copy(checkpoint_path, best_checkpoint_path)
303 |
304 | # Close the log file
305 | log_file.close()
306 | writer.close()
307 |
--------------------------------------------------------------------------------