├── .idea ├── .gitignore ├── CurriculumLoc.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── ProcessData ├── gen_gt_files.py ├── gen_images_file.py └── rgb_norm.py ├── README.md ├── __pycache__ ├── cnnmatchs.cpython-310.pyc ├── cnnmatchs.cpython-38.pyc ├── plotmatch.cpython-310.pyc └── plotmatch.cpython-38.pyc ├── environment.yaml ├── fig ├── dataset.png └── outline.png ├── lib ├── __pycache__ │ ├── cnn_feature.cpython-310.pyc │ ├── cnn_feature.cpython-38.pyc │ ├── dataset.cpython-38.pyc │ ├── dataset_terra.cpython-310.pyc │ ├── dataset_terra.cpython-38.pyc │ ├── exceptions.cpython-310.pyc │ ├── exceptions.cpython-38.pyc │ ├── loss.cpython-310.pyc │ ├── loss.cpython-38.pyc │ ├── model.cpython-38.pyc │ ├── model_d2.cpython-38.pyc │ ├── model_swin.cpython-310.pyc │ ├── model_swin.cpython-38.pyc │ ├── model_swin_unet_d2.cpython-38.pyc │ ├── model_test.cpython-310.pyc │ ├── model_test.cpython-38.pyc │ ├── pyramid.cpython-310.pyc │ ├── pyramid.cpython-38.pyc │ ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc │ ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc │ ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc.139907302345344 │ ├── swin_unet.cpython-310.pyc │ ├── swin_unet.cpython-38.pyc │ ├── utils.cpython-310.pyc │ └── utils.cpython-38.pyc ├── backbone_model │ ├── __pycache__ │ │ ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc │ │ ├── swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc │ │ ├── swin_unet.cpython-310.pyc │ │ ├── swin_unet.cpython-38.pyc │ │ └── unet.cpython-38.pyc │ ├── erfnet.py │ ├── swin_transformer_unet_skip_expand_decoder_sys.py │ ├── swin_unet.py │ └── unet.py ├── dataset_terra.py ├── exceptions.py ├── full_model │ ├── __pycache__ │ │ ├── model_d2.cpython-310.pyc │ │ ├── model_d2.cpython-38.pyc │ │ ├── model_swin_unet_d2.cpython-310.pyc │ │ ├── model_swin_unet_d2.cpython-38.pyc │ │ └── model_unet.cpython-38.pyc │ ├── model_erf.py │ ├── model_swin_unet_d2.py │ ├── model_unet.py │ └── model_vit.py ├── model_swin.py ├── model_test.py ├── swin_transformer_unet_skip_expand_decoder_sys.py ├── swin_unet.py └── utils.py ├── match_localization.py ├── matching.py ├── performance.ini ├── plotmatch.py ├── result ├── cv-match.py └── plot_rs.py ├── terratrack_utils ├── compute_mean_std.py ├── depth_bin2array.py ├── depth_png2h5.py ├── preprocess_terra.py ├── train_scenes_500.txt ├── train_scenes_origin.txt ├── undistort_reconstructions_terra.py ├── valid_scenes_500.txt └── valid_scenes_origin.txt └── train_terra.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/CurriculumLoc.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /ProcessData/gen_gt_files.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | from os.path import join 4 | import pandas as pd 5 | 6 | # data_q = [] 7 | # data_db = [] 8 | # 9 | path = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/' 10 | # 11 | # with open(join(path, 'reference.csv'), 'r') as csv_file_db: 12 | # csv_data = csv.DictReader(csv_file_db) 13 | # for row in csv_data: 14 | # data_db.append(row) 15 | # 16 | # with open(join(path, 'query.csv'), 'r') as csv_file_q: 17 | # csv_data_q = csv.DictReader(csv_file_q) 18 | # for row in csv_data_q: 19 | # data_q.append(row) 20 | # 21 | # numQ = len(data_q) 22 | # numDb = len(data_db) 23 | # 24 | # utmDb = [[float(f["easting"]), float(f["northing"])] for f in data_db] 25 | # utmQ = [[float(f["easting"]), float(f["northing"])] for f in data_q] 26 | posDistThr = 50 27 | qData = pd.read_csv(join(path, 'query.csv')) 28 | qData = qData.sort_values(by='name', ascending=True) 29 | dbData = pd.read_csv(join(path, 'reference.csv')) 30 | dbData = dbData.sort_values(by='name', ascending=True) 31 | utmQ = qData[['easting', 'northing']].values.reshape(-1, 2) 32 | utmDb = dbData[['easting', 'northing']].values.reshape(-1, 2) 33 | 34 | np.savez('../patchnetvlad/dataset_gt_files/mavic_npu_dist50', utmQ=utmQ, utmDb=utmDb, posDistThr=posDistThr) 35 | 36 | print("Save Done!") 37 | -------------------------------------------------------------------------------- /ProcessData/gen_images_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | parser = argparse.ArgumentParser(description='origin data path') 6 | parser.add_argument('--data_path', type=str, default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/reference_images_500', help='origin images path') 7 | parser.add_argument('--save_path', type=str, default='../patchnetvlad/dataset_imagenames/mavic_npu_reference_imageNames_index.txt', 8 | help='save images txt path') 9 | 10 | 11 | opt = parser.parse_args() 12 | print(opt) 13 | 14 | 15 | for dir in os.listdir(opt.data_path): 16 | sub_data_path = os.path.join(opt.data_path, dir) 17 | if os.path.isdir(sub_data_path): 18 | for f in os.listdir(sub_data_path): 19 | image_name = os.path.join(sub_data_path, f) 20 | image_name_w = os.path.join(*(image_name.split('/')[-3:])) 21 | with open(opt.save_path, 'a') as txt: 22 | txt.write(image_name_w) 23 | txt.write('\n') 24 | else: 25 | image_name = os.path.join(opt.data_path, dir) 26 | # image_name_w = os.path.join(*(image_name.split('/')[-2:])) 27 | with open(opt.save_path, 'a') as txt: 28 | # txt.write(image_name_w) 29 | txt.write(image_name) 30 | txt.write('\n') 31 | -------------------------------------------------------------------------------- /ProcessData/rgb_norm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def compute(img, min_percentile, max_percentile): 6 | """计算分位点,目的是去掉图1的直方图两头的异常情况""" 7 | 8 | 9 | max_percentile_pixel = np.percentile(img, 
max_percentile) 10 | min_percentile_pixel = np.percentile(img, min_percentile) 11 | 12 | return max_percentile_pixel, min_percentile_pixel 13 | 14 | 15 | def aug(src): 16 | """Enhance the brightness of an image.""" 17 | if get_lightness(src) > 130: 18 | print("Image brightness is sufficient, no enhancement applied") 19 | # First compute the percentiles so that the few outlier values at both ends of the histogram can be removed; the percentiles are configurable. 20 | # For example, the red histogram in Fig. 1 has values over the whole 0-255 range, but most pixel values actually lie within 0-20. 21 | 22 | 23 | max_percentile_pixel, min_percentile_pixel = compute(src, 1, 99) 24 | 25 | # Clip values outside the percentile range 26 | src[src >= max_percentile_pixel] = max_percentile_pixel 27 | src[src <= min_percentile_pixel] = min_percentile_pixel 28 | 29 | # Stretch the percentile range; 255*0.1 and 255*0.9 are used instead of 0 and 255 because pixel values could otherwise overflow. 30 | out = np.zeros(src.shape, src.dtype) 31 | cv2.normalize(src, out, 255 * 0.1, 255 * 0.9, cv2.NORM_MINMAX) 32 | 33 | return out 34 | 35 | 36 | def get_lightness(src): 37 | # Compute brightness as the mean of the V channel in HSV space 38 | hsv_image = cv2.cvtColor(src, cv2.COLOR_BGR2HSV) 39 | lightness = hsv_image[:, :, 2].mean() 40 | 41 | return lightness 42 | 43 | 44 | 45 | img = cv2.imread("/home/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/query_images/000000.png", cv2.IMREAD_UNCHANGED) 46 | img = aug(img) 47 | print("ppppp") 48 | cv2.imwrite('/home/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/output/000000.png', img) 49 | 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CurriculumLoc for Visual Geo-localization 2 | 3 | 4 | 5 | ## Introduction 6 | CurriculumLoc is a PyTorch implementation of our paper ["CurriculumLoc: Enhancing Cross-Domain Geolocalization through Multi-Stage Refinement"](https://arxiv.org/abs/2311.11604). If you use this code for your research, please cite our paper. For additional questions, contact us via huboni@mail.nwpu.edu.cn or huboni7@gmail.com. 7 | 8 | 9 | ## Installation 10 | 11 | We test this repo with Python 3.10, PyTorch 1.12.1, and CUDA 11.3; however, it should be runnable with other recent PyTorch versions. You can create the conda environment from the provided environment.yaml: 12 | 13 | ```shell 14 | conda env create -f environment.yaml 15 | ``` 16 | 17 | 18 | ## Preparation 19 | 20 | We test our models on two datasets. One is [ALTO](https://github.com/MetaSLAM/ALTO), which can be downloaded [here](https://github.com/MetaSLAM/ALTO). The other is our TerraTrack dataset, which is currently being prepared for release. Both datasets contain challenging environmental variations, as shown in the table below.
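For TerraTrack-style data (a `query.csv`/`reference.csv` pair with `easting`/`northing` columns plus the image folders), the helper scripts under `ProcessData/` build the image-name lists and the ground-truth file used at evaluation time. A minimal sketch is shown below; the `/path/to/...` locations are placeholders, and the defaults inside the scripts point to our own machines, so adjust them to your local layout before running.

```shell
# Image-name index for a query or reference image folder
python ProcessData/gen_images_file.py \
    --data_path /path/to/TerraTrack/mavic_npu/reference_images_500 \
    --save_path /path/to/dataset_imagenames/mavic_npu_reference_imageNames_index.txt

# Ground-truth file (query/reference UTM coordinates, 50 m positive distance threshold);
# the csv paths are hard-coded inside gen_gt_files.py, so edit them there before running
python ProcessData/gen_gt_files.py
```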
21 | 22 | 23 | 24 | 25 | ## Training 26 | 27 | ``` 28 | python train_terra.py 29 | ``` 30 | 31 | 32 | 33 | ## Testing 34 | 35 | ``` 36 | python match_localization.py 37 | ``` 38 | 39 | 40 | 41 | ## Citation 42 | 43 | If you're using CurculumLoc in your research or applications, please cite using this BibTeX: 44 | 45 | ```bibtex 46 | @misc{hu2023curriculumloc, 47 | title={CurriculumLoc: Enhancing Cross-Domain Geolocalization through Multi-Stage Refinement}, 48 | author={Boni Hu and Lin Chen and Runjian Chen and Shuhui Bu and Pengcheng Han and Haowei Li}, 49 | year={2023}, 50 | eprint={2311.11604}, 51 | archivePrefix={arXiv}, 52 | primaryClass={cs.CV} 53 | } 54 | ``` 55 | 56 | -------------------------------------------------------------------------------- /__pycache__/cnnmatchs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/__pycache__/cnnmatchs.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/cnnmatchs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/__pycache__/cnnmatchs.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/plotmatch.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/__pycache__/plotmatch.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/plotmatch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/__pycache__/plotmatch.cpython-38.pyc -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: cnn-matching 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blosc=1.21.3=h6a678d5_0 9 | - boost-cpp=1.73.0=h7f8727e_12 10 | - bzip2=1.0.8=h7b6447c_0 11 | - c-ares=1.18.1=h7f8727e_0 12 | - ca-certificates=2023.01.10=h06a4308_0 13 | - cairo=1.16.0=hb05425b_3 14 | - cfitsio=3.470=hf0d0db6_6 15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 16 | - cudatoolkit=11.3.1=h2bc3f7f_2 17 | - curl=7.87.0=h5eee18b_0 18 | - expat=2.4.9=h6a678d5_0 19 | - ffmpeg=4.3=hf484d3e_0 20 | - flit-core=3.6.0=pyhd3eb1b0_0 21 | - fontconfig=2.14.1=h52c9d5c_1 22 | - freetype=2.12.1=h4a9f257_0 23 | - freexl=1.0.6=h27cfd23_0 24 | - geos=3.8.0=he6710b0_0 25 | - geotiff=1.7.0=hd69d5b1_0 26 | - giflib=5.2.1=h5eee18b_1 27 | - glib=2.69.1=he621ea3_2 28 | - gmp=6.2.1=h295c915_3 29 | - gnutls=3.6.15=he1e5248_0 30 | - hdf4=4.2.13=h3ca952b_2 31 | - hdf5=1.10.6=hb1b8bf9_0 32 | - icu=58.2=he6710b0_3 33 | - intel-openmp=2022.1.0=h9e868ea_3769 34 | - jpeg=9e=h7f8727e_0 35 | - json-c=0.16=h5eee18b_0 36 | - kealib=1.5.0=hd940352_0 37 | - krb5=1.19.4=h568e23c_0 38 | - lame=3.100=h7b6447c_0 39 | - lcms2=2.12=h3be6417_0 40 | - ld_impl_linux-64=2.38=h1181459_1 41 | - lerc=3.0=h295c915_0 42 | - libboost=1.73.0=h28710b8_12 43 | - libcurl=7.87.0=h91b91d3_0 44 | - libdeflate=1.8=h7f8727e_5 45 | 
- libedit=3.1.20221030=h5eee18b_0 46 | - libev=4.33=h7f8727e_1 47 | - libffi=3.4.2=h6a678d5_6 48 | - libgcc-ng=11.2.0=h1234567_1 49 | - libgdal=3.6.0=hc0e11bb_0 50 | - libgfortran-ng=7.5.0=ha8ba4b0_17 51 | - libgfortran4=7.5.0=ha8ba4b0_17 52 | - libgomp=11.2.0=h1234567_1 53 | - libiconv=1.16=h7f8727e_2 54 | - libidn2=2.3.2=h7f8727e_0 55 | - libkml=1.3.0=h096b73e_6 56 | - libnetcdf=4.8.1=h8322cc2_2 57 | - libnghttp2=1.46.0=hce63b2e_0 58 | - libpng=1.6.37=hbc83047_0 59 | - libpq=12.9=h16c4e8d_3 60 | - libspatialite=4.3.0a=h71b31bf_21 61 | - libssh2=1.10.0=h8f2d780_0 62 | - libstdcxx-ng=11.2.0=h1234567_1 63 | - libtasn1=4.16.0=h27cfd23_0 64 | - libtiff=4.5.0=h6a678d5_1 65 | - libunistring=0.9.10=h27cfd23_0 66 | - libuuid=1.41.5=h5eee18b_0 67 | - libwebp=1.2.4=h11a3e52_0 68 | - libwebp-base=1.2.4=h5eee18b_0 69 | - libxcb=1.15=h7f8727e_0 70 | - libxml2=2.9.14 71 | - libzip=1.8.0=h5cef20c_0 72 | - lz4-c=1.9.4=h6a678d5_0 73 | - mkl=2020.2=256 74 | - ncurses=6.4=h6a678d5_0 75 | - nettle=3.7.3=hbbd107a_1 76 | - nspr=4.33=h295c915_0 77 | - nss=3.74=h0370c37_0 78 | - openh264=2.1.1=h4ff587b_0 79 | - openjpeg=2.4.0=h3ad879b_0 80 | - openssl=1.1.1t=h7f8727e_0 81 | - pcre=8.45=h295c915_0 82 | - pcre2=10.37=he7ceb23_1 83 | - pixman=0.40.0=h7f8727e_1 84 | - poppler=22.12.0=h381b16e_0 85 | - poppler-data=0.4.11=h06a4308_1 86 | - postgresql=12.9=h16c4e8d_3 87 | - proj=6.2.1=h05a3930_0 88 | - pycparser=2.21=pyhd3eb1b0_0 89 | - pyopenssl=22.0.0=pyhd3eb1b0_0 90 | - python=3.10.9=h7a1cb2a_0 91 | - pytorch=1.12.1=py3.10_cuda11.3_cudnn8.3.2_0 92 | - pytorch-mutex=1.0=cuda 93 | - qhull=2020.2=hdb19cb5_2 94 | - readline=8.2=h5eee18b_0 95 | - sqlite=3.40.1=h5082296_0 96 | - tiledb=2.3.3=h1132f93_2 97 | - tk=8.6.12=h1ccaba5_0 98 | - typing_extensions=4.4.0=py310h06a4308_0 99 | - tzdata=2022g=h04d1e81_0 100 | - xerces-c=3.2.4=h94c2ce2_0 101 | - xz=5.2.10=h5eee18b_1 102 | - zlib=1.2.13=h5eee18b_0 103 | - zstd=1.5.2=ha4553b6_0 104 | - pip: 105 | - brotlipy==0.7.0 106 | - certifi==2022.12.7 107 | - cffi==1.15.1 108 | - contourpy==1.0.7 109 | - cryptography==38.0.4 110 | - cycler==0.11.0 111 | - einops==0.7.0 112 | - filelock==3.12.4 113 | - fonttools==4.38.0 114 | - fsspec==2023.9.2 115 | - gdal==3.6.0 116 | - huggingface-hub==0.17.3 117 | - idna==3.4 118 | - imageio==2.25.1 119 | - joblib==1.2.0 120 | - kiwisolver==1.4.4 121 | - matplotlib==3.7.0 122 | - networkx==3.0 123 | - numpy==1.22.3 124 | - opencv-python==4.7.0.68 125 | - packaging==23.0 126 | - pandas==1.5.3 127 | - pillow==9.3.0 128 | - pip==22.3.1 129 | - pyparsing==3.0.9 130 | - pysocks==1.7.1 131 | - python-dateutil==2.8.2 132 | - pytz==2022.7.1 133 | - pywavelets==1.4.1 134 | - pyyaml==6.0.1 135 | - requests==2.28.1 136 | - safetensors==0.4.0 137 | - scikit-image==0.19.3 138 | - scikit-learn==1.2.1 139 | - scipy==1.10.0 140 | - six==1.16.0 141 | - threadpoolctl==3.1.0 142 | - tifffile==2023.2.3 143 | - timm==0.9.7 144 | - torch==1.12.1 145 | - torchaudio==0.12.1 146 | - torchvision==0.13.1 147 | - tqdm==4.64.1 148 | - typing-extensions==4.4.0 149 | - urllib3==1.26.14 150 | prefix: /home/a409_home/anaconda3/envs/cnn-matching 151 | -------------------------------------------------------------------------------- /fig/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/fig/dataset.png -------------------------------------------------------------------------------- /fig/outline.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/fig/outline.png -------------------------------------------------------------------------------- /lib/__pycache__/cnn_feature.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/cnn_feature.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/cnn_feature.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/cnn_feature.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/dataset_terra.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/dataset_terra.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/dataset_terra.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/dataset_terra.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/exceptions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/exceptions.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/exceptions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/exceptions.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/loss.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model_d2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_d2.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model_swin.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_swin.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model_swin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_swin.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model_swin_unet_d2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_swin_unet_d2.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model_test.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_test.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model_test.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/model_test.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/pyramid.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/pyramid.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/pyramid.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/pyramid.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc.139907302345344: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc.139907302345344 -------------------------------------------------------------------------------- /lib/__pycache__/swin_unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_unet.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/swin_unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/swin_unet.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /lib/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /lib/backbone_model/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc -------------------------------------------------------------------------------- /lib/backbone_model/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-38.pyc -------------------------------------------------------------------------------- /lib/backbone_model/__pycache__/swin_unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/swin_unet.cpython-310.pyc -------------------------------------------------------------------------------- /lib/backbone_model/__pycache__/swin_unet.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/swin_unet.cpython-38.pyc -------------------------------------------------------------------------------- /lib/backbone_model/__pycache__/unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/backbone_model/__pycache__/unet.cpython-38.pyc -------------------------------------------------------------------------------- /lib/backbone_model/erfnet.py: -------------------------------------------------------------------------------- 1 | # ERFNet full model definition for Pytorch 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | 12 | class DownsamplerBlock (nn.Module): 13 | def __init__(self, ninput, noutput): 14 | super().__init__() 15 | 16 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) 17 | self.pool = nn.MaxPool2d(2, stride=2) 18 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 19 | 20 | def forward(self, input): 21 | output = torch.cat([self.conv(input), self.pool(input)], 1) 22 | output = self.bn(output) 23 | return F.relu(output) 24 | 25 | 26 | class non_bottleneck_1d (nn.Module): 27 | def __init__(self, chann, dropprob, dilated): 28 | super().__init__() 29 | 30 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True) 31 | 32 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True) 33 | 34 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) 35 | 36 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=( 37 | 1*dilated, 0), bias=True, dilation=(dilated, 1)) 38 | 39 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=( 40 | 0, 1*dilated), bias=True, dilation=(1, dilated)) 41 | 42 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) 43 | 44 | self.dropout = nn.Dropout2d(dropprob) 45 | 46 | def forward(self, input): 47 | 48 | output = self.conv3x1_1(input) 49 | output = F.relu(output) 50 | output = self.conv1x3_1(output) 51 | output = self.bn1(output) 52 | output = F.relu(output) 53 | 54 | output = self.conv3x1_2(output) 55 | output = F.relu(output) 56 | output = self.conv1x3_2(output) 57 | output = self.bn2(output) 58 | 59 | if (self.dropout.p != 0): 60 | output = self.dropout(output) 61 | 62 | return F.relu(output+input) # +input = identity (residual connection) 63 | 64 | 65 | class Encoder(nn.Module): 66 | def __init__(self, num_classes): 67 | super().__init__() 68 | self.initial_block = DownsamplerBlock(3, 16) 69 | 70 | self.layers = nn.ModuleList() 71 | 72 | self.layers.append(DownsamplerBlock(16, 64)) 73 | 74 | for x in range(0, 5): # 5 times 75 | self.layers.append(non_bottleneck_1d(64, 0.03, 1)) 76 | 77 | self.layers.append(DownsamplerBlock(64, 128)) 78 | 79 | for x in range(0, 2): # 2 times 80 | self.layers.append(non_bottleneck_1d(128, 0.3, 2)) 81 | self.layers.append(non_bottleneck_1d(128, 0.3, 4)) 82 | self.layers.append(non_bottleneck_1d(128, 0.3, 8)) 83 | self.layers.append(non_bottleneck_1d(128, 0.3, 16)) 84 | 85 | # Only in encoder mode: 86 | # self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 87 | 88 | def forward(self, input): 89 | output = self.initial_block(input) 90 | for layer in self.layers: 91 | output = 
layer(output) 92 | # print(output.size()) 93 | 94 | return output 95 | 96 | 97 | class UpsamplerBlock (nn.Module): 98 | def __init__(self, ninput, noutput): 99 | super().__init__() 100 | self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2, 101 | padding=1, output_padding=1, bias=True) 102 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 103 | 104 | def forward(self, input): 105 | output = self.conv(input) 106 | output = self.bn(output) 107 | return F.relu(output) 108 | 109 | 110 | class Decoder (nn.Module): 111 | def __init__(self, num_classes): 112 | super().__init__() 113 | 114 | self.layers = nn.ModuleList() 115 | 116 | self.layers.append(UpsamplerBlock(128, 64)) 117 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 118 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 119 | 120 | self.layers.append(UpsamplerBlock(64, 16)) 121 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 122 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 123 | 124 | self.output_conv = nn.ConvTranspose2d( 125 | 16, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True) 126 | 127 | def forward(self, input): 128 | output = input 129 | 130 | for layer in self.layers: 131 | output = layer(output) 132 | 133 | output = self.output_conv(output) 134 | 135 | return output 136 | 137 | # ERFNet 138 | 139 | 140 | class Net(nn.Module): 141 | def __init__(self, num_classes): # use encoder to pass pretrained encoder 142 | super().__init__() 143 | 144 | self.encoder = Encoder(num_classes) 145 | self.decoder = Decoder(num_classes) 146 | 147 | def forward(self, input): 148 | output = self.encoder(input) 149 | return self.decoder.forward(output) 150 | -------------------------------------------------------------------------------- /lib/backbone_model/swin_transformer_unet_skip_expand_decoder_sys.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | from einops import rearrange 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | 7 | 8 | class Mlp(nn.Module): 9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 10 | super().__init__() 11 | out_features = out_features or in_features 12 | hidden_features = hidden_features or in_features 13 | self.fc1 = nn.Linear(in_features, hidden_features) 14 | self.act = act_layer() 15 | self.fc2 = nn.Linear(hidden_features, out_features) 16 | self.drop = nn.Dropout(drop) 17 | 18 | def forward(self, x): 19 | x = self.fc1(x) 20 | x = self.act(x) 21 | x = self.drop(x) 22 | x = self.fc2(x) 23 | x = self.drop(x) 24 | return x 25 | 26 | 27 | def window_partition(x, window_size): 28 | """ 29 | Args: 30 | x: (B, H, W, C) 31 | window_size (int): window size 32 | 33 | Returns: 34 | windows: (num_windows*B, window_size, window_size, C) 35 | """ 36 | B, H, W, C = x.shape 37 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 38 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 39 | return windows 40 | 41 | 42 | def window_reverse(windows, window_size, H, W): 43 | """ 44 | Args: 45 | windows: (num_windows*B, window_size, window_size, C) 46 | window_size (int): Window size 47 | H (int): Height of image 48 | W (int): Width of image 49 | 50 | Returns: 51 | x: (B, H, W, C) 52 | """ 53 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 54 | x = windows.view(B, H // window_size, W // window_size, window_size, 
window_size, -1) 55 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 56 | return x 57 | 58 | 59 | class WindowAttention(nn.Module): 60 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 61 | It supports both of shifted and non-shifted window. 62 | 63 | Args: 64 | dim (int): Number of input channels. 65 | window_size (tuple[int]): The height and width of the window. 66 | num_heads (int): Number of attention heads. 67 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 68 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 69 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 70 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 71 | """ 72 | 73 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 74 | 75 | super().__init__() 76 | self.dim = dim 77 | self.window_size = window_size # Wh, Ww 78 | self.num_heads = num_heads 79 | head_dim = dim // num_heads 80 | self.scale = qk_scale or head_dim ** -0.5 81 | 82 | # define a parameter table of relative position bias 83 | self.relative_position_bias_table = nn.Parameter( 84 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 85 | 86 | # get pair-wise relative position index for each token inside the window 87 | coords_h = torch.arange(self.window_size[0]) 88 | coords_w = torch.arange(self.window_size[1]) 89 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 90 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 91 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 92 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 93 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 94 | relative_coords[:, :, 1] += self.window_size[1] - 1 95 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 96 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 97 | self.register_buffer("relative_position_index", relative_position_index) 98 | 99 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 100 | self.attn_drop = nn.Dropout(attn_drop) 101 | self.proj = nn.Linear(dim, dim) 102 | self.proj_drop = nn.Dropout(proj_drop) 103 | 104 | trunc_normal_(self.relative_position_bias_table, std=.02) 105 | self.softmax = nn.Softmax(dim=-1) 106 | 107 | def forward(self, x, mask=None): 108 | """ 109 | Args: 110 | x: input features with shape of (num_windows*B, N, C) 111 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 112 | """ 113 | B_, N, C = x.shape 114 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 116 | 117 | q = q * self.scale 118 | attn = (q @ k.transpose(-2, -1)) 119 | 120 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 121 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 122 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 123 | attn = attn + relative_position_bias.unsqueeze(0) 124 | 125 | if mask is not None: 126 | nW = mask.shape[0] 127 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + 
mask.unsqueeze(1).unsqueeze(0) 128 | attn = attn.view(-1, self.num_heads, N, N) 129 | attn = self.softmax(attn) 130 | else: 131 | attn = self.softmax(attn) 132 | 133 | attn = self.attn_drop(attn) 134 | 135 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 136 | x = self.proj(x) 137 | x = self.proj_drop(x) 138 | return x 139 | 140 | def extra_repr(self) -> str: 141 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 142 | 143 | def flops(self, N): 144 | # calculate flops for 1 window with token length of N 145 | flops = 0 146 | # qkv = self.qkv(x) 147 | flops += N * self.dim * 3 * self.dim 148 | # attn = (q @ k.transpose(-2, -1)) 149 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 150 | # x = (attn @ v) 151 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 152 | # x = self.proj(x) 153 | flops += N * self.dim * self.dim 154 | return flops 155 | 156 | 157 | class SwinTransformerBlock(nn.Module): 158 | r""" Swin Transformer Block. 159 | 160 | Args: 161 | dim (int): Number of input channels. 162 | input_resolution (tuple[int]): Input resulotion. 163 | num_heads (int): Number of attention heads. 164 | window_size (int): Window size. 165 | shift_size (int): Shift size for SW-MSA. 166 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 167 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 168 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 169 | drop (float, optional): Dropout rate. Default: 0.0 170 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 171 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 172 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 173 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 174 | """ 175 | 176 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 177 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 178 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 179 | super().__init__() 180 | self.dim = dim 181 | self.input_resolution = input_resolution 182 | self.num_heads = num_heads 183 | self.window_size = window_size 184 | self.shift_size = shift_size 185 | self.mlp_ratio = mlp_ratio 186 | if min(self.input_resolution) <= self.window_size: 187 | # if window size is larger than input resolution, we don't partition windows 188 | self.shift_size = 0 189 | self.window_size = min(self.input_resolution) 190 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 191 | 192 | self.norm1 = norm_layer(dim) 193 | self.attn = WindowAttention( 194 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 195 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 196 | 197 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 198 | self.norm2 = norm_layer(dim) 199 | mlp_hidden_dim = int(dim * mlp_ratio) 200 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 201 | 202 | if self.shift_size > 0: 203 | # calculate attention mask for SW-MSA 204 | H, W = self.input_resolution 205 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 206 | h_slices = (slice(0, -self.window_size), 207 | slice(-self.window_size, -self.shift_size), 208 | slice(-self.shift_size, None)) 209 | w_slices = (slice(0, -self.window_size), 210 | slice(-self.window_size, -self.shift_size), 211 | slice(-self.shift_size, None)) 212 | cnt = 0 213 | for h in h_slices: 214 | for w in w_slices: 215 | img_mask[:, h, w, :] = cnt 216 | cnt += 1 217 | 218 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 219 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 220 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 221 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 222 | else: 223 | attn_mask = None 224 | 225 | self.register_buffer("attn_mask", attn_mask) 226 | 227 | def forward(self, x): 228 | H, W = self.input_resolution 229 | B, L, C = x.shape 230 | assert L == H * W, "input feature has wrong size" 231 | 232 | shortcut = x 233 | x = self.norm1(x) 234 | x = x.view(B, H, W, C) 235 | 236 | # cyclic shift 237 | if self.shift_size > 0: 238 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 239 | else: 240 | shifted_x = x 241 | 242 | # partition windows 243 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 244 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 245 | 246 | # W-MSA/SW-MSA 247 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 248 | 249 | # merge windows 250 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 251 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 252 | 253 | # reverse cyclic shift 254 | if self.shift_size > 0: 255 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 256 | else: 257 | x = shifted_x 258 | x = x.view(B, H * W, C) 259 | 260 | # FFN 261 | x = shortcut + self.drop_path(x) 262 | x = x + self.drop_path(self.mlp(self.norm2(x))) 263 | 264 | return x 265 | 266 | def extra_repr(self) -> str: 267 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 268 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 269 | 270 | def flops(self): 271 | flops = 0 272 | H, W = self.input_resolution 273 | # norm1 274 | flops += self.dim * H * W 275 | # W-MSA/SW-MSA 276 | nW = H * W / self.window_size / self.window_size 277 | flops += nW * self.attn.flops(self.window_size * self.window_size) 278 | # mlp 279 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 280 | # norm2 281 | flops += self.dim * H * W 282 | return flops 283 | 284 | 285 | class PatchMerging(nn.Module): 286 | r""" Patch Merging Layer. 287 | 288 | Args: 289 | input_resolution (tuple[int]): Resolution of input feature. 290 | dim (int): Number of input channels. 291 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 292 | """ 293 | 294 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 295 | super().__init__() 296 | self.input_resolution = input_resolution 297 | self.dim = dim 298 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 299 | self.norm = norm_layer(4 * dim) 300 | 301 | def forward(self, x): 302 | """ 303 | x: B, H*W, C 304 | """ 305 | H, W = self.input_resolution 306 | B, L, C = x.shape 307 | assert L == H * W, "input feature has wrong size" 308 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 309 | 310 | x = x.view(B, H, W, C) 311 | 312 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 313 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 314 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 315 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 316 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 317 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 318 | 319 | x = self.norm(x) 320 | x = self.reduction(x) 321 | 322 | return x 323 | 324 | def extra_repr(self) -> str: 325 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 326 | 327 | def flops(self): 328 | H, W = self.input_resolution 329 | flops = H * W * self.dim 330 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 331 | return flops 332 | 333 | class PatchExpand(nn.Module): 334 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 335 | super().__init__() 336 | self.input_resolution = input_resolution 337 | self.dim = dim 338 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() 339 | self.norm = norm_layer(dim // dim_scale) 340 | 341 | def forward(self, x): 342 | """ 343 | x: B, H*W, C 344 | """ 345 | H, W = self.input_resolution 346 | x = self.expand(x) 347 | B, L, C = x.shape 348 | assert L == H * W, "input feature has wrong size" 349 | 350 | x = x.view(B, H, W, C) 351 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) 352 | x = x.view(B,-1,C//4) 353 | x= self.norm(x) 354 | 355 | return x 356 | 357 | class FinalPatchExpand_X4(nn.Module): 358 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 359 | super().__init__() 360 | self.input_resolution = input_resolution 361 | self.dim = dim 362 | self.dim_scale = dim_scale 363 | self.expand = nn.Linear(dim, 16*dim, bias=False) 364 | self.output_dim = dim 365 | self.norm = norm_layer(self.output_dim) 366 | 367 | def forward(self, x): 368 | """ 369 | x: B, H*W, C 370 | """ 371 | H, W = self.input_resolution 372 | x = self.expand(x) 373 | B, L, C = x.shape 374 | assert L == H * W, "input feature has wrong size" 375 | 376 | x = x.view(B, H, W, C) 377 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) 378 | x = x.view(B,-1,self.output_dim) 379 | x= self.norm(x) 380 | 381 | return x 382 | 383 | class BasicLayer(nn.Module): 384 | """ A basic Swin Transformer layer for one stage. 385 | 386 | Args: 387 | dim (int): Number of input channels. 388 | input_resolution (tuple[int]): Input resolution. 389 | depth (int): Number of blocks. 390 | num_heads (int): Number of attention heads. 391 | window_size (int): Local window size. 392 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 393 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 394 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 395 | drop (float, optional): Dropout rate. 
Default: 0.0 396 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 397 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 398 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 399 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 400 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 401 | """ 402 | 403 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 404 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 405 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 406 | 407 | super().__init__() 408 | self.dim = dim 409 | self.input_resolution = input_resolution 410 | self.depth = depth 411 | self.use_checkpoint = use_checkpoint 412 | 413 | # build blocks 414 | self.blocks = nn.ModuleList([ 415 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 416 | num_heads=num_heads, window_size=window_size, 417 | shift_size=0 if (i % 2 == 0) else window_size // 2, 418 | mlp_ratio=mlp_ratio, 419 | qkv_bias=qkv_bias, qk_scale=qk_scale, 420 | drop=drop, attn_drop=attn_drop, 421 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 422 | norm_layer=norm_layer) 423 | for i in range(depth)]) 424 | 425 | # patch merging layer 426 | if downsample is not None: 427 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 428 | else: 429 | self.downsample = None 430 | 431 | def forward(self, x): 432 | for blk in self.blocks: 433 | if self.use_checkpoint: 434 | x = checkpoint.checkpoint(blk, x) 435 | else: 436 | x = blk(x) 437 | if self.downsample is not None: 438 | x = self.downsample(x) 439 | return x 440 | 441 | def extra_repr(self) -> str: 442 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 443 | 444 | def flops(self): 445 | flops = 0 446 | for blk in self.blocks: 447 | flops += blk.flops() 448 | if self.downsample is not None: 449 | flops += self.downsample.flops() 450 | return flops 451 | 452 | class BasicLayer_up(nn.Module): 453 | """ A basic Swin Transformer layer for one stage. 454 | 455 | Args: 456 | dim (int): Number of input channels. 457 | input_resolution (tuple[int]): Input resolution. 458 | depth (int): Number of blocks. 459 | num_heads (int): Number of attention heads. 460 | window_size (int): Local window size. 461 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 462 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 463 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 464 | drop (float, optional): Dropout rate. Default: 0.0 465 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 466 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 467 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 468 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 469 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
470 | """ 471 | 472 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 473 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 474 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 475 | 476 | super().__init__() 477 | self.dim = dim 478 | self.input_resolution = input_resolution 479 | self.depth = depth 480 | self.use_checkpoint = use_checkpoint 481 | 482 | # build blocks 483 | self.blocks = nn.ModuleList([ 484 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 485 | num_heads=num_heads, window_size=window_size, 486 | shift_size=0 if (i % 2 == 0) else window_size // 2, 487 | mlp_ratio=mlp_ratio, 488 | qkv_bias=qkv_bias, qk_scale=qk_scale, 489 | drop=drop, attn_drop=attn_drop, 490 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 491 | norm_layer=norm_layer) 492 | for i in range(depth)]) 493 | 494 | # patch merging layer 495 | if upsample is not None: 496 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) 497 | else: 498 | self.upsample = None 499 | 500 | def forward(self, x): 501 | for blk in self.blocks: 502 | if self.use_checkpoint: 503 | x = checkpoint.checkpoint(blk, x) 504 | else: 505 | x = blk(x) 506 | if self.upsample is not None: 507 | x = self.upsample(x) 508 | return x 509 | 510 | class PatchEmbed(nn.Module): 511 | r""" Image to Patch Embedding 512 | 513 | Args: 514 | img_size (int): Image size. Default: 224. 515 | patch_size (int): Patch token size. Default: 4. 516 | in_chans (int): Number of input image channels. Default: 3. 517 | embed_dim (int): Number of linear projection output channels. Default: 96. 518 | norm_layer (nn.Module, optional): Normalization layer. Default: None 519 | """ 520 | 521 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 522 | super().__init__() 523 | img_size = to_2tuple(img_size) 524 | patch_size = to_2tuple(patch_size) 525 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 526 | self.img_size = img_size 527 | self.patch_size = patch_size 528 | self.patches_resolution = patches_resolution 529 | self.num_patches = patches_resolution[0] * patches_resolution[1] 530 | 531 | self.in_chans = in_chans 532 | self.embed_dim = embed_dim 533 | 534 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 535 | if norm_layer is not None: 536 | self.norm = norm_layer(embed_dim) 537 | else: 538 | self.norm = None 539 | 540 | def forward(self, x): 541 | B, C, H, W = x.shape 542 | # FIXME look at relaxing size constraints 543 | assert H == self.img_size[0] and W == self.img_size[1], \ 544 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 545 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 546 | if self.norm is not None: 547 | x = self.norm(x) 548 | return x 549 | 550 | def flops(self): 551 | Ho, Wo = self.patches_resolution 552 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 553 | if self.norm is not None: 554 | flops += Ho * Wo * self.embed_dim 555 | return flops 556 | 557 | 558 | class SwinTransformerSys(nn.Module): 559 | r""" Swin Transformer 560 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 561 | https://arxiv.org/pdf/2103.14030 562 | 563 | Args: 564 | img_size (int | tuple(int)): Input image size. Default 224 565 | patch_size (int | tuple(int)): Patch size. 
Default: 4 566 | in_chans (int): Number of input image channels. Default: 3 567 | num_classes (int): Number of classes for classification head. Default: 1000 568 | embed_dim (int): Patch embedding dimension. Default: 96 569 | depths (tuple(int)): Depth of each Swin Transformer layer. 570 | num_heads (tuple(int)): Number of attention heads in different layers. 571 | window_size (int): Window size. Default: 7 572 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 573 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 574 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 575 | drop_rate (float): Dropout rate. Default: 0 576 | attn_drop_rate (float): Attention dropout rate. Default: 0 577 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 578 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 579 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 580 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 581 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 582 | """ 583 | 584 | # FIXME embed_dim=96 check一下是512还是96 585 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 586 | embed_dim=512, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], 587 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 588 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 589 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 590 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 591 | super().__init__() 592 | 593 | print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, 594 | depths_decoder,drop_path_rate,num_classes)) 595 | 596 | self.num_classes = num_classes 597 | self.num_layers = len(depths) 598 | self.embed_dim = embed_dim 599 | self.ape = ape 600 | self.patch_norm = patch_norm 601 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 602 | self.num_features_up = int(embed_dim * 2) 603 | self.mlp_ratio = mlp_ratio 604 | self.final_upsample = final_upsample 605 | 606 | # split image into non-overlapping patches 607 | self.patch_embed = PatchEmbed( 608 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 609 | norm_layer=norm_layer if self.patch_norm else None) 610 | num_patches = self.patch_embed.num_patches 611 | patches_resolution = self.patch_embed.patches_resolution 612 | self.patches_resolution = patches_resolution 613 | 614 | # absolute position embedding 615 | if self.ape: 616 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 617 | trunc_normal_(self.absolute_pos_embed, std=.02) 618 | 619 | self.pos_drop = nn.Dropout(p=drop_rate) 620 | 621 | # stochastic depth 622 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 623 | 624 | # build encoder and bottleneck layers 625 | self.layers = nn.ModuleList() 626 | for i_layer in range(self.num_layers): 627 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 628 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 629 | patches_resolution[1] // (2 ** i_layer)), 630 | depth=depths[i_layer], 631 | num_heads=num_heads[i_layer], 632 | window_size=window_size, 633 | mlp_ratio=self.mlp_ratio, 634 | qkv_bias=qkv_bias, qk_scale=qk_scale, 635 | drop=drop_rate, 
attn_drop=attn_drop_rate, 636 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 637 | norm_layer=norm_layer, 638 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 639 | use_checkpoint=use_checkpoint) 640 | self.layers.append(layer) 641 | 642 | # build decoder layers 643 | self.layers_up = nn.ModuleList() 644 | self.concat_back_dim = nn.ModuleList() 645 | for i_layer in range(self.num_layers): 646 | concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), 647 | int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() 648 | if i_layer ==0 : 649 | layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 650 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) 651 | else: 652 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), 653 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 654 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), 655 | depth=depths[(self.num_layers-1-i_layer)], 656 | num_heads=num_heads[(self.num_layers-1-i_layer)], 657 | window_size=window_size, 658 | mlp_ratio=self.mlp_ratio, 659 | qkv_bias=qkv_bias, qk_scale=qk_scale, 660 | drop=drop_rate, attn_drop=attn_drop_rate, 661 | drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], 662 | norm_layer=norm_layer, 663 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, 664 | use_checkpoint=use_checkpoint) 665 | self.layers_up.append(layer_up) 666 | self.concat_back_dim.append(concat_linear) 667 | 668 | self.norm = norm_layer(self.num_features) 669 | self.norm_up= norm_layer(self.embed_dim) 670 | # self.norm_up= norm_layer(192) 671 | 672 | if self.final_upsample == "expand_first": 673 | print("---final upsample expand_first---") 674 | self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim) 675 | # self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False) 676 | self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=3, stride=2, padding=1, bias=False) 677 | 678 | self.output_2 = nn.Conv2d(in_channels=self.num_classes,out_channels=self.num_classes,kernel_size=3, stride=2, padding=1, bias=False) 679 | 680 | self.apply(self._init_weights) 681 | 682 | def _init_weights(self, m): 683 | if isinstance(m, nn.Linear): 684 | trunc_normal_(m.weight, std=.02) 685 | if isinstance(m, nn.Linear) and m.bias is not None: 686 | nn.init.constant_(m.bias, 0) 687 | elif isinstance(m, nn.LayerNorm): 688 | nn.init.constant_(m.bias, 0) 689 | nn.init.constant_(m.weight, 1.0) 690 | 691 | @torch.jit.ignore 692 | def no_weight_decay(self): 693 | return {'absolute_pos_embed'} 694 | 695 | @torch.jit.ignore 696 | def no_weight_decay_keywords(self): 697 | return {'relative_position_bias_table'} 698 | 699 | #Encoder and Bottleneck 700 | def forward_features(self, x): 701 | x = self.patch_embed(x) 702 | if self.ape: 703 | x = x + self.absolute_pos_embed 704 | x = self.pos_drop(x) 705 | x_downsample = [] 706 | 707 | for layer in self.layers: 708 | x_downsample.append(x) 709 | x = layer(x) 710 | 711 | x = self.norm(x) # B L C [1,49,768] 712 | 713 | return x, x_downsample 714 | 715 | #Dencoder and Skip connection 716 | ## FIXME 修改网络结构 保证预训练的参数还可以用 跳过一些层实时 717 | def 
forward_up_features(self, x, x_downsample): 718 | for inx, layer_up in enumerate(self.layers_up): 719 | # print("swin decoder",self.layers_up) 720 | if inx == 0: 721 | x = layer_up(x) 722 | else: 723 | x = torch.cat([x,x_downsample[3-inx]],-1) 724 | x = self.concat_back_dim[inx](x) 725 | x = layer_up(x) 726 | 727 | x = self.norm_up(x) # B L C [1,3136,96] 728 | 729 | return x 730 | 731 | def up_x4(self, x): 732 | H, W = self.patches_resolution 733 | 734 | B, L, C = x.shape 735 | assert L == H*W, "input features has wrong size" 736 | 737 | if self.final_upsample=="expand_first": 738 | x = self.up(x) 739 | x = x.view(B,4*H,4*W,-1) 740 | x = x.permute(0,3,1,2) #B,C,H,W 741 | x = self.output(x) # 1/2 742 | x = self.output_2(x) # 1/4 743 | x = self.output_2(x) # 1/8 744 | 745 | ## FIXME without patch expanding 746 | # x = x.view(B, H, W, -1) # 1/4 747 | # x = x.permute(0,3,1,2) #B,C,H,W 748 | # x = self.output(x) # 1/8 749 | 750 | 751 | return x 752 | 753 | def forward(self, x): 754 | x, x_downsample = self.forward_features(x) 755 | x = self.forward_up_features(x,x_downsample) 756 | x = self.up_x4(x) 757 | # print("out shape", x.shape) 758 | # x = x.view(1,28,28,-1) 759 | # x = x.permute(0,3,2,1) 760 | # x = self.output(x) 761 | # print("output:", x.shape) 762 | # x = self.up_x4(x) 763 | 764 | return x 765 | 766 | def flops(self): 767 | flops = 0 768 | flops += self.patch_embed.flops() 769 | for i, layer in enumerate(self.layers): 770 | flops += layer.flops() 771 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 772 | flops += self.num_features * self.num_classes 773 | return flops 774 | -------------------------------------------------------------------------------- /lib/backbone_model/swin_unet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | 10 | from os.path import join as pjoin 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | 16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 17 | from torch.nn.modules.utils import _pair 18 | from scipy import ndimage 19 | from .swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class SwinUnet(nn.Module): 25 | def __init__(self, pretrained_path, img_size=224, num_classes=21843, zero_head=False, vis=False): 26 | super(SwinUnet, self).__init__() 27 | self.num_classes = num_classes 28 | self.zero_head = zero_head 29 | 30 | self.swin_unet = SwinTransformerSys(img_size=224, 31 | patch_size=4, 32 | in_chans=3, 33 | num_classes=512, 34 | embed_dim=96, 35 | depths=[2, 2, 2, 2], 36 | num_heads=[3, 6, 12, 24], 37 | window_size=7, 38 | mlp_ratio=4, 39 | qkv_bias=True, 40 | qk_scale=None, 41 | drop_rate=0, 42 | drop_path_rate=0.1, 43 | ape=True, 44 | patch_norm=True, 45 | use_checkpoint=True) 46 | 47 | def forward(self, x): 48 | if x.size()[1] == 1: 49 | x = x.repeat(1, 3, 1, 1) 50 | logits = self.swin_unet(x) 51 | return logits 52 | 53 | def load_from(self, pretrained_path): 54 | pretrained_path = pretrained_path 55 | if pretrained_path is not None: 56 | print("pretrained_path:{}".format(pretrained_path)) 57 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | pretrained_dict = torch.load(pretrained_path, 
map_location=device) 59 | if "model" not in pretrained_dict: 60 | print("---start load pretrained modle by splitting---") 61 | pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()} 62 | for k in list(pretrained_dict.keys()): 63 | if "output" in k: 64 | print("delete key:{}".format(k)) 65 | del pretrained_dict[k] 66 | msg = self.swin_unet.load_state_dict(pretrained_dict, strict=False) 67 | # print(msg) 68 | return 69 | pretrained_dict = pretrained_dict['model'] 70 | print("---start load pretrained modle of swin encoder---") 71 | 72 | model_dict = self.swin_unet.state_dict() 73 | full_dict = copy.deepcopy(pretrained_dict) 74 | for k, v in pretrained_dict.items(): 75 | if "layers." in k: 76 | current_layer_num = 3 - int(k[7:8]) 77 | current_k = "layers_up." + str(current_layer_num) + k[8:] 78 | full_dict.update({current_k: v}) 79 | for k in list(full_dict.keys()): 80 | if k in model_dict: 81 | if full_dict[k].shape != model_dict[k].shape: 82 | print("delete:{};shape pretrain:{};shape model:{}".format(k, v.shape, model_dict[k].shape)) 83 | del full_dict[k] 84 | 85 | msg = self.swin_unet.load_state_dict(full_dict, strict=False) 86 | # print(msg) 87 | else: 88 | print("none pretrain") 89 | -------------------------------------------------------------------------------- /lib/backbone_model/unet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Baseline unet: Test baseline without any operations 3 | ''' 4 | 5 | import torch.nn as nn 6 | import torch 7 | 8 | class conv_block_nested(nn.Module): 9 | def __init__(self, in_ch, mid_ch, out_ch): 10 | super(conv_block_nested, self).__init__() 11 | self.activation = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True) 13 | self.bn1 = nn.BatchNorm2d(mid_ch) 14 | self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True) 15 | self.bn2 = nn.BatchNorm2d(out_ch) 16 | 17 | def forward(self, x): 18 | x = self.conv1(x) 19 | identity = x 20 | x = self.bn1(x) 21 | x = self.activation(x) 22 | 23 | x = self.conv2(x) 24 | x = self.bn2(x) 25 | output = self.activation(x + identity) 26 | return output 27 | 28 | 29 | class up(nn.Module): 30 | def __init__(self, in_ch, bilinear=False): 31 | super(up, self).__init__() 32 | 33 | if bilinear: 34 | self.up = nn.Upsample(scale_factor=2, 35 | mode='bilinear', 36 | align_corners=True) 37 | else: 38 | self.up = nn.ConvTranspose2d(in_ch, in_ch, 2, stride=2) 39 | 40 | def forward(self, x): 41 | x = self.up(x) 42 | return x 43 | 44 | class ChannelAttention(nn.Module): 45 | def __init__(self, in_channels, ratio = 16): 46 | super(ChannelAttention, self).__init__() 47 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 48 | self.max_pool = nn.AdaptiveMaxPool2d(1) 49 | self.fc1 = nn.Conv2d(in_channels,in_channels//ratio,1,bias=False) 50 | self.relu1 = nn.ReLU() 51 | self.fc2 = nn.Conv2d(in_channels//ratio, in_channels,1,bias=False) 52 | self.sigmod = nn.Sigmoid() 53 | def forward(self,x): 54 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 55 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 56 | out = avg_out + max_out 57 | return self.sigmod(out) 58 | 59 | 60 | n1 = 32 61 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] 62 | class Encoder(nn.Module): 63 | def __init__(self): 64 | super().__init__() 65 | in_ch = 3 66 | self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0]) 67 | self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1]) 68 | self.conv2_0 = 
conv_block_nested(filters[1], filters[2], filters[2]) 69 | self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3]) 70 | self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4]) 71 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 72 | 73 | def forward(self, input): 74 | x0_0 = self.conv0_0(input) 75 | x1_0 = self.conv1_0(self.pool(x0_0)) 76 | x2_0 = self.conv2_0(self.pool(x1_0)) 77 | x3_0 = self.conv3_0(self.pool(x2_0)) # [1, 256, 32, 32] 78 | x4_0 = self.conv4_0(self.pool(x3_0)) # [1, 512, 16, 16] 79 | # print("x30 shape:", x3_0.shape) 80 | return [x0_0, x1_0, x2_0, x3_0, x4_0] 81 | 82 | 83 | class Decoder(nn.Module): 84 | def __init__(self): 85 | super().__init__() 86 | out_ch = 2 87 | # self.conv3_1 = conv_block_nested(filters[3] * 2 + filters[4], filters[3], filters[3]) 88 | self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[4], filters[4]) 89 | self.conv2_2 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2]) 90 | self.conv1_3 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1]) 91 | self.conv0_4 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0]) 92 | self.Up4_0 = up(filters[4]) 93 | self.Up3_1 = up(filters[3]) 94 | self.Up2_2 = up(filters[2]) 95 | self.Up1_3 = up(filters[1]) 96 | self.conv_final = nn.Conv2d(filters[0], out_ch, kernel_size=1) 97 | 98 | def forward(self, x_A): 99 | self.x0_0A, self.x1_0A, self.x2_0A, self.x3_0A, self.x4_0A = x_A 100 | # self.x0_0B, self.x1_0B, self.x2_0B, self.x3_0B, self.x4_0B = x_B 101 | x3_1 = self.conv3_1(torch.cat([self.x3_0A, self.Up4_0(self.x4_0A)], 1)) # [1,512,32,32] 102 | print("x31 shape:", x3_1.shape) 103 | # x2_2 = self.conv2_2(torch.cat([self.x2_0A, self.x2_0B, self.Up3_1(x3_1)], 1)) # [1, 128, 64, 64] 104 | # x1_3 = self.conv1_3(torch.cat([self.x1_0A, self.x1_0B, self.Up2_2(x2_2)], 1)) # [1, 64, 128, 128] 105 | # x0_4 = self.conv0_4(torch.cat([self.x0_0A, self.x0_0B, self.Up1_3(x1_3)], 1)) # [1, 64, 256, 128] 106 | # print("x04:", x0_4.shape) 107 | # out = self.conv_final(x0_4) 108 | return x3_1 109 | 110 | 111 | class UNet(nn.Module): 112 | def __init__(self): 113 | super(UNet, self).__init__() 114 | torch.nn.Module.dump_patches = True 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 119 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 120 | nn.init.constant_(m.weight, 1) 121 | nn.init.constant_(m.bias, 0) 122 | 123 | '''Encoder return multi feature map to nested 124 | ''' 125 | self.encoder = Encoder() 126 | self.decoder = Decoder() 127 | 128 | 129 | def forward(self, batch): 130 | encode = self.encoder(batch) 131 | output = self.decoder.forward(encode) 132 | 133 | return output 134 | # x0_B = self.encoder(xB) 135 | # out = self.decoder.forward(x0_A, x0_B) 136 | # return (out, ), [out] 137 | # 138 | # if __name__ == "__main__": 139 | # num_classes = [2,2] 140 | # current_task = 1 141 | # unet = UNet() 142 | # image_A = torch.randn((4, 3, 256, 256)) 143 | # image_B = torch.randn((4, 3, 256, 256)) 144 | # for name, m in unet.named_parameters(): 145 | # print(name) 146 | # outputs_change, feature_map = unet(image_A, image_B, current_task) 147 | -------------------------------------------------------------------------------- /lib/dataset_terra.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from PIL import Image 4 | import os 5 | import torch 6 | from 
torch.utils.data import Dataset 7 | import time 8 | from tqdm import tqdm 9 | from lib.utils import preprocess_image 10 | 11 | class TerraDataset(Dataset): 12 | def __init__( 13 | self, 14 | scene_list_path='terratrack_utils/train_scenes_500.txt', 15 | scene_info_path='/home/a409/users/huboni/Projects/dataset/TerraTrack/process_output_query_ref_500', 16 | base_path='/home/a409/users/huboni/Projects/dataset/TerraTrack', 17 | train=True, 18 | preprocessing=None, 19 | min_overlap_ratio=.5, 20 | max_overlap_ratio=1, 21 | max_scale_ratio=np.inf, 22 | pairs_per_scene=5, 23 | image_size=224 24 | ): 25 | self.scenes = [] 26 | with open(scene_list_path, 'r') as f: 27 | lines = f.readlines() 28 | for line in lines: 29 | self.scenes.append(line.strip('\n')) 30 | print("scenes:", self.scenes) 31 | 32 | self.scene_info_path = scene_info_path 33 | self.base_path = base_path 34 | 35 | self.train = train 36 | 37 | self.preprocessing = preprocessing 38 | 39 | self.min_overlap_ratio = min_overlap_ratio 40 | self.max_overlap_ratio = max_overlap_ratio 41 | self.max_scale_ratio = max_scale_ratio 42 | 43 | self.pairs_per_scene = pairs_per_scene 44 | 45 | self.image_size = image_size 46 | 47 | self.dataset = [] 48 | 49 | def build_dataset(self): 50 | self.dataset = [] 51 | if not self.train: 52 | np_random_state = np.random.get_state() 53 | np.random.seed(42) 54 | print('Building the validation dataset...') 55 | else: 56 | print('Building a new training dataset...') 57 | for scene in tqdm(self.scenes, total=len(self.scenes)): 58 | scene_info_path = os.path.join( 59 | self.scene_info_path, '%s.npz' % scene 60 | ) 61 | if not os.path.exists(scene_info_path): 62 | continue 63 | scene_info = np.load(scene_info_path, allow_pickle=True) 64 | overlap_matrix = scene_info['overlap_matrix'] 65 | scale_ratio_matrix = scene_info['scale_ratio_matrix'] 66 | 67 | valid = np.logical_and( 68 | np.logical_and( 69 | overlap_matrix >= self.min_overlap_ratio, 70 | overlap_matrix <= self.max_overlap_ratio 71 | ), 72 | scale_ratio_matrix <= self.max_scale_ratio 73 | ) 74 | # 得到匹配对 75 | pairs = np.vstack(np.where(valid)) 76 | # 如果该场景中配配对= 0) 143 | image_path1 = os.path.join( 144 | self.base_path, pair_metadata['image_path1'] 145 | ) 146 | image1 = Image.open(image_path1) 147 | if image1.mode != 'RGB': 148 | image1 = image1.convert('RGB') 149 | image1 = np.array(image1) 150 | # make sure image and depth have same weight and height 151 | assert(image1.shape[0] == depth1.shape[0] and image1.shape[1] == depth1.shape[1]) 152 | assert(image1.shape[0]>100) 153 | intrinsics1 = pair_metadata['intrinsics1'] 154 | pose1 = pair_metadata['pose1'] 155 | 156 | depth_path2 = os.path.join( 157 | self.base_path, pair_metadata['depth_path2'] 158 | ) 159 | with h5py.File(depth_path2, 'r') as hdf5_file: 160 | depth2 = np.array(hdf5_file['/depth']) 161 | assert(np.min(depth2) >= 0) 162 | image_path2 = os.path.join( 163 | self.base_path, pair_metadata['image_path2'] 164 | ) 165 | image2 = Image.open(image_path2) 166 | if image2.mode != 'RGB': 167 | image2 = image2.convert('RGB') 168 | image2 = np.array(image2) 169 | assert(image2.shape[0] == depth2.shape[0] and image2.shape[1] == depth2.shape[1]) 170 | assert(image2.shape[0]>100) 171 | intrinsics2 = pair_metadata['intrinsics2'] 172 | pose2 = pair_metadata['pose2'] 173 | 174 | central_match = pair_metadata['central_match'] 175 | # 通过centeral_match确定两张图像需要保留的区域,并计算出响应的边界框 176 | image1, bbox1, image2, bbox2 = self.crop(image1, image2, central_match) 177 | 178 | depth1 = depth1[ 179 | bbox1[0] : bbox1[0] + 
self.image_size, 180 | bbox1[1] : bbox1[1] + self.image_size 181 | ] 182 | depth2 = depth2[ 183 | bbox2[0] : bbox2[0] + self.image_size, 184 | bbox2[1] : bbox2[1] + self.image_size 185 | ] 186 | 187 | return ( 188 | image1, depth1, intrinsics1, pose1, bbox1, 189 | image2, depth2, intrinsics2, pose2, bbox2 190 | ) 191 | 192 | def crop(self, image1, image2, central_match): 193 | bbox1_i = max(int(central_match[0]) - self.image_size // 2, 0) 194 | if bbox1_i + self.image_size >= image1.shape[0]: 195 | bbox1_i = image1.shape[0] - self.image_size 196 | bbox1_j = max(int(central_match[1]) - self.image_size // 2, 0) 197 | if bbox1_j + self.image_size >= image1.shape[1]: 198 | bbox1_j = image1.shape[1] - self.image_size 199 | 200 | bbox2_i = max(int(central_match[2]) - self.image_size // 2, 0) 201 | if bbox2_i + self.image_size >= image2.shape[0]: 202 | bbox2_i = image2.shape[0] - self.image_size 203 | bbox2_j = max(int(central_match[3]) - self.image_size // 2, 0) 204 | if bbox2_j + self.image_size >= image2.shape[1]: 205 | bbox2_j = image2.shape[1] - self.image_size 206 | 207 | return ( 208 | image1[ 209 | bbox1_i : bbox1_i + self.image_size, 210 | bbox1_j : bbox1_j + self.image_size 211 | ], 212 | np.array([bbox1_i, bbox1_j]), 213 | image2[ 214 | bbox2_i : bbox2_i + self.image_size, 215 | bbox2_j : bbox2_j + self.image_size 216 | ], 217 | np.array([bbox2_i, bbox2_j]) 218 | ) 219 | 220 | def __getitem__(self, idx): 221 | ( 222 | image1, depth1, intrinsics1, pose1, bbox1, 223 | image2, depth2, intrinsics2, pose2, bbox2 224 | ) = self.recover_pair(self.dataset[idx]) 225 | 226 | # if use model vit ignore preprocess_image 227 | image1 = preprocess_image(image1, preprocessing=self.preprocessing) 228 | image2 = preprocess_image(image2, preprocessing=self.preprocessing) 229 | 230 | return { 231 | 'image1': torch.from_numpy(image1.astype(np.float32)), 232 | 'depth1': torch.from_numpy(depth1.astype(np.float32)), 233 | 'intrinsics1': torch.from_numpy(intrinsics1.astype(np.float32)), 234 | 'pose1': torch.from_numpy(pose1.astype(np.float32)), 235 | 'bbox1': torch.from_numpy(bbox1.astype(np.float32)), 236 | 'image2': torch.from_numpy(image2.astype(np.float32)), 237 | 'depth2': torch.from_numpy(depth2.astype(np.float32)), 238 | 'intrinsics2': torch.from_numpy(intrinsics2.astype(np.float32)), 239 | 'pose2': torch.from_numpy(pose2.astype(np.float32)), 240 | 'bbox2': torch.from_numpy(bbox2.astype(np.float32)) 241 | } 242 | -------------------------------------------------------------------------------- /lib/exceptions.py: -------------------------------------------------------------------------------- 1 | class EmptyTensorError(Exception): 2 | pass 3 | 4 | 5 | class NoGradientError(Exception): 6 | pass 7 | -------------------------------------------------------------------------------- /lib/full_model/__pycache__/model_d2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_d2.cpython-310.pyc -------------------------------------------------------------------------------- /lib/full_model/__pycache__/model_d2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_d2.cpython-38.pyc -------------------------------------------------------------------------------- 
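A minimal usage sketch for the TerraDataset defined in lib/dataset_terra.py above. The data paths, the batch size and the 'torch' preprocessing flag are illustrative assumptions rather than values taken from this repository, and the scene .npz files referenced by scene_info_path must already exist:

from torch.utils.data import DataLoader
from lib.dataset_terra import TerraDataset

# Hypothetical paths -- replace with the locations of your own TerraTrack data.
dataset = TerraDataset(
    scene_list_path='terratrack_utils/train_scenes_500.txt',
    scene_info_path='/path/to/process_output_query_ref_500',
    base_path='/path/to/TerraTrack',
    train=True,
    preprocessing='torch',   # forwarded to lib.utils.preprocess_image (assumed option)
    pairs_per_scene=5,
    image_size=224,
)

# build_dataset() samples the image pairs from the per-scene overlap matrices; it has to
# run before the dataset is indexed and can be re-run to draw a fresh set of pairs.
dataset.build_dataset()

loader = DataLoader(dataset, batch_size=1, shuffle=True)
batch = next(iter(loader))
print(batch['image1'].shape)   # expected: torch.Size([1, 3, 224, 224])
print(batch['depth1'].shape)   # expected: torch.Size([1, 224, 224])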
/lib/full_model/__pycache__/model_swin_unet_d2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_swin_unet_d2.cpython-310.pyc -------------------------------------------------------------------------------- /lib/full_model/__pycache__/model_swin_unet_d2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_swin_unet_d2.cpython-38.pyc -------------------------------------------------------------------------------- /lib/full_model/__pycache__/model_unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boni-hu/CurriculumLoc/f1c53587c23c1f10c20a17bde8ef8d0159391e96/lib/full_model/__pycache__/model_unet.cpython-38.pyc -------------------------------------------------------------------------------- /lib/full_model/model_erf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import ViTFeatureExtractor, ViTModel 5 | import torchvision.models as models 6 | class DenseFeatureExtractionModule(nn.Module): 7 | def __init__(self, finetune_feature_extraction=False, use_cuda=True): 8 | super(DenseFeatureExtractionModule, self).__init__() 9 | # VGG16 10 | model = models.vgg16() 11 | vgg16_layers = [ 12 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 13 | 'pool1', 14 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 15 | 'pool2', 16 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 17 | 'pool3', 18 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 19 | 'pool4', 20 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 21 | 'pool5' 22 | ] 23 | conv4_3_idx = vgg16_layers.index('conv4_3') 24 | 25 | self.model = nn.Sequential( 26 | *list(model.features.children())[: conv4_3_idx + 1] 27 | ) 28 | self.num_channels = 512 29 | 30 | 31 | # Fix forward parameters 32 | for param in self.model.parameters(): 33 | param.requires_grad = False 34 | if finetune_feature_extraction: 35 | # Unlock conv4_3 36 | for param in list(self.model.parameters())[-2 :]: 37 | param.requires_grad = True 38 | 39 | if use_cuda: 40 | self.model = self.model.cuda() 41 | 42 | def forward(self, batch): 43 | 44 | # VGG 45 | output = self.model(batch) 46 | return output 47 | 48 | 49 | class SoftDetectionModule(nn.Module): 50 | def __init__(self, soft_local_max_size=3): 51 | super(SoftDetectionModule, self).__init__() 52 | 53 | self.soft_local_max_size = soft_local_max_size 54 | 55 | self.pad = self.soft_local_max_size // 2 56 | 57 | def forward(self, batch): 58 | b = batch.size(0) 59 | 60 | batch = F.relu(batch) # [2,512,28,28] 61 | 62 | max_per_sample = torch.max(batch.view(b, -1), dim=1)[0] # [1,2] 63 | exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1)) # [2,512,28,28] 64 | 65 | sum_exp = ( 66 | self.soft_local_max_size ** 2 * 67 | F.avg_pool2d( 68 | F.pad(exp, [self.pad] * 4, mode='constant', value=1.),# [2,512,30,30] 69 | self.soft_local_max_size, stride=1 70 | ) # [2,512,28,28] 71 | ) # [2, 512,28,28] 72 | local_max_score = exp / sum_exp # alpha 73 | 74 | depth_wise_max = torch.max(batch, dim=1)[0] # [2,28,28] 75 | 76 | 
depth_wise_max_score = batch / depth_wise_max.unsqueeze(1) # beta [2, 512, 28, 28] 77 | 78 | all_scores = local_max_score * depth_wise_max_score # [2, 512, 28, 28] 79 | score = torch.max(all_scores, dim=1)[0] # r [2,28,28] 80 | score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1) # s [2,28,28] 81 | 82 | return score 83 | 84 | 85 | class D2Net(nn.Module): 86 | def __init__(self, model_file=None, use_cuda=True): 87 | super(D2Net, self).__init__() 88 | 89 | self.dense_feature_extraction = DenseFeatureExtractionModule( 90 | finetune_feature_extraction=True, 91 | use_cuda=use_cuda 92 | ) 93 | 94 | self.detection = SoftDetectionModule() 95 | 96 | if model_file is not None: 97 | if use_cuda: 98 | self.load_state_dict(torch.load(model_file)['model']) 99 | else: 100 | self.load_state_dict(torch.load(model_file, map_location='cpu')['model']) 101 | 102 | def forward(self, batch): 103 | b = batch['image1'].size(0) 104 | 105 | dense_features = self.dense_feature_extraction( 106 | torch.cat([batch['image1'], batch['image2']], dim=0) 107 | ) 108 | 109 | dense_features1 = dense_features[: b, :, :, :] 110 | dense_features2 = dense_features[b :, :, :, :] 111 | scores = self.detection(dense_features) 112 | scores1 = scores[: b, :, :] 113 | scores2 = scores[b :, :, :] 114 | 115 | return { 116 | 'dense_features1': dense_features1, 117 | 'scores1': scores1, 118 | 'dense_features2': dense_features2, 119 | 'scores2': scores2 120 | } -------------------------------------------------------------------------------- /lib/full_model/model_swin_unet_d2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from transformers import ViTFeatureExtractor, ViTModel 5 | # import torchvision.models as models 6 | from ..backbone_model.swin_unet import SwinUnet 7 | 8 | # FixMe 将特征提取的基础网络替换为swin-unet 其他计算scores计算匹配点对部分保持 9 | class DenseFeatureExtractionModule(nn.Module): 10 | def __init__(self, finetune_feature_extraction=False, use_cuda=True): 11 | super(DenseFeatureExtractionModule, self).__init__() 12 | 13 | # TODO model:transformer 14 | ## VGG16 15 | # model = models.vgg16() 16 | # vgg16_layers = [ 17 | # 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 18 | # 'pool1', 19 | # 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 20 | # 'pool2', 21 | # 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 22 | # 'pool3', 23 | # 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 24 | # 'pool4', 25 | # 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 26 | # 'pool5' 27 | # ] 28 | # conv4_3_idx = vgg16_layers.index('conv4_3') 29 | # 30 | # self.model = nn.Sequential( 31 | # *list(model.features.children())[: conv4_3_idx + 1] 32 | # ) 33 | # self.num_channels = 512 34 | 35 | ## TODO Swin Unet 36 | trans_pretrained_path = '/home/a409/users/huboni/Projects/code/d2-net/models/swin_tiny_patch4_window7_224.pth' 37 | trans_pretrained_path = None 38 | self.model = SwinUnet(trans_pretrained_path).cuda() 39 | self.model.load_from(trans_pretrained_path) 40 | 41 | 42 | # # Fix forward parameters 43 | # for param in self.model.parameters(): 44 | # param.requires_grad = False 45 | # if finetune_feature_extraction: 46 | # # Unlock conv4_3 47 | # for param in list(self.model.parameters())[-2 :]: 48 | # param.requires_grad = True 49 | 50 | if use_cuda: 51 | self.model = self.model.cuda() 52 | 53 | def forward(self, batch): 54 | 55 | output = self.model(batch) 56 | 57 | return output 58 | 59 | 
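# --- Annotation added for clarity (not part of the original source) -----------------
# SoftDetectionModule below implements the D2-Net style soft keypoint detection on a
# dense feature map F of shape (B, C, H, W). After a ReLU and a per-sample rescaling
# (features are divided by their maximum so that exp() stays numerically stable) it
# computes, for every channel c and location (i, j):
#   alpha[c, i, j] = exp(F[c, i, j]) / (sum of exp(F[c]) over the local
#                    soft_local_max_size x soft_local_max_size window)   # soft local max
#   beta[c, i, j]  = F[c, i, j] / (max over channels of F[:, i, j])      # channel ratio
# The per-pixel score is gamma[i, j] = max over channels of alpha * beta, and the score
# map is finally normalised so that it sums to one over each image.
# -------------------------------------------------------------------------------------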
60 | class SoftDetectionModule(nn.Module): 61 | def __init__(self, soft_local_max_size=3): 62 | super(SoftDetectionModule, self).__init__() 63 | 64 | self.soft_local_max_size = soft_local_max_size 65 | 66 | self.pad = self.soft_local_max_size // 2 67 | 68 | def forward(self, batch): 69 | b = batch.size(0) 70 | batch = F.relu(batch) 71 | max_per_sample = torch.max(batch.reshape(b, -1), dim=1)[0] 72 | # print("max pre sample:", max_per_sample) 73 | exp = torch.exp(batch / max_per_sample.reshape(b, 1, 1, 1)) 74 | 75 | sum_exp = ( 76 | self.soft_local_max_size ** 2 * 77 | F.avg_pool2d( 78 | F.pad(exp, [self.pad] * 4, mode='constant', value=1.), 79 | self.soft_local_max_size, stride=1 80 | ) 81 | ) 82 | local_max_score = exp / sum_exp 83 | 84 | depth_wise_max = torch.max(batch, dim=1)[0] 85 | depth_wise_max_score = batch / depth_wise_max.unsqueeze(1) 86 | 87 | all_scores = local_max_score * depth_wise_max_score 88 | score = torch.max(all_scores, dim=1)[0] 89 | 90 | score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1) 91 | 92 | 93 | return score 94 | 95 | 96 | class Swin_D2UNet(nn.Module): 97 | def __init__(self, use_cuda=True): 98 | super(Swin_D2UNet, self).__init__() 99 | 100 | self.dense_feature_extraction = DenseFeatureExtractionModule( 101 | finetune_feature_extraction=True, 102 | use_cuda=use_cuda 103 | ) 104 | 105 | self.detection = SoftDetectionModule() 106 | 107 | # if model_file is not None: 108 | # if use_cuda: 109 | # self.load_state_dict(torch.load(model_file)['model']) 110 | # else: 111 | # self.load_state_dict(torch.load(model_file, map_location='cpu')['model']) 112 | 113 | def forward(self, batch): 114 | b = batch['image1'].size(0) 115 | 116 | # dense_features = self.dense_feature_extraction( 117 | # torch.cat([batch['image1'], batch['image2']], dim=0) 118 | # ) 119 | # scores = self.detection(dense_features) 120 | # dense_features1 = dense_features[: b, :, :, :] 121 | # dense_features2 = dense_features[b :, :, :, :] 122 | # scores1 = scores[: b, :, :] 123 | # scores2 = scores[b :, :, :] 124 | 125 | dense_features1 = self.dense_feature_extraction(batch['image1']) # [1,1000,224,224] 126 | # print("dense feature1:", dense_features1) 127 | # print("feature max:", torch.max(dense_features1)) 128 | # print("feature min:", torch.min(dense_features1)) 129 | dense_features2 = self.dense_feature_extraction(batch['image2']) 130 | 131 | # scores = self.detection(torch.cat([dense_features1, dense_features2], dim=0)) 132 | # scores2 = scores[b :, :, :] 133 | # scores1 = scores[: b, :, :] 134 | scores1 = self.detection(dense_features1) 135 | scores2 = self.detection(dense_features2) 136 | 137 | return { 138 | 'dense_features1': dense_features1, 139 | 'scores1': scores1, 140 | 'dense_features2': dense_features2, 141 | 'scores2': scores2 142 | } -------------------------------------------------------------------------------- /lib/full_model/model_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import ViTFeatureExtractor, ViTModel 5 | import torchvision.models as models 6 | from ..backbone_model.unet import UNet 7 | class DenseFeatureExtractionModule(nn.Module): 8 | def __init__(self, finetune_feature_extraction=False, use_cuda=True): 9 | super(DenseFeatureExtractionModule, self).__init__() 10 | # VGG16 11 | # model = models.vgg16() 12 | # vgg16_layers = [ 13 | # 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 14 | # 'pool1', 15 | # 'conv2_1', 
'relu2_1', 'conv2_2', 'relu2_2', 16 | # 'pool2', 17 | # 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 18 | # 'pool3', 19 | # 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 20 | # 'pool4', 21 | # 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 22 | # 'pool5' 23 | # ] 24 | # conv4_3_idx = vgg16_layers.index('conv4_3') 25 | # 26 | # self.model = nn.Sequential( 27 | # *list(model.features.children())[: conv4_3_idx + 1] 28 | # ) 29 | # self.num_channels = 512 30 | 31 | self.model = UNet() 32 | print("model list", self.model) 33 | 34 | 35 | # Fix forward parameters 36 | # for param in self.model.parameters(): 37 | # param.requires_grad = False 38 | # if finetune_feature_extraction: 39 | # # Unlock conv4_3 40 | # for param in list(self.model.parameters())[-2 :]: 41 | # param.requires_grad = True 42 | 43 | if use_cuda: 44 | self.model = self.model.cuda() 45 | 46 | def forward(self, batch): 47 | 48 | # VGG 49 | output = self.model(batch) 50 | return output 51 | 52 | 53 | class SoftDetectionModule(nn.Module): 54 | def __init__(self, soft_local_max_size=3): 55 | super(SoftDetectionModule, self).__init__() 56 | 57 | self.soft_local_max_size = soft_local_max_size 58 | 59 | self.pad = self.soft_local_max_size // 2 60 | 61 | def forward(self, batch): 62 | b = batch.size(0) 63 | 64 | batch = F.relu(batch) # [2,512,28,28] 65 | 66 | max_per_sample = torch.max(batch.view(b, -1), dim=1)[0] # [1,2] 67 | exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1)) # [2,512,28,28] 68 | 69 | sum_exp = ( 70 | self.soft_local_max_size ** 2 * 71 | F.avg_pool2d( 72 | F.pad(exp, [self.pad] * 4, mode='constant', value=1.),# [2,512,30,30] 73 | self.soft_local_max_size, stride=1 74 | ) # [2,512,28,28] 75 | ) # [2, 512,28,28] 76 | local_max_score = exp / sum_exp # alpha 77 | 78 | depth_wise_max = torch.max(batch, dim=1)[0] # [2,28,28] 79 | 80 | depth_wise_max_score = batch / depth_wise_max.unsqueeze(1) # beta [2, 512, 28, 28] 81 | 82 | all_scores = local_max_score * depth_wise_max_score # [2, 512, 28, 28] 83 | score = torch.max(all_scores, dim=1)[0] # r [2,28,28] 84 | score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1) # s [2,28,28] 85 | 86 | return score 87 | 88 | 89 | class U2Net(nn.Module): 90 | def __init__(self, model_file=None, use_cuda=True): 91 | super(U2Net, self).__init__() 92 | 93 | self.dense_feature_extraction = DenseFeatureExtractionModule( 94 | finetune_feature_extraction=True, 95 | use_cuda=use_cuda 96 | ) 97 | 98 | self.detection = SoftDetectionModule() 99 | 100 | # if model_file is not None: 101 | # if use_cuda: 102 | # self.load_state_dict(torch.load(model_file)['model']) 103 | # else: 104 | # self.load_state_dict(torch.load(model_file, map_location='cpu')['model']) 105 | 106 | def forward(self, batch): 107 | b = batch['image1'].size(0) 108 | 109 | # print("image1 origin", batch['image1']) 110 | dense_features1 = self.dense_feature_extraction(batch['image1']) 111 | dense_features2 = self.dense_feature_extraction(batch['image2']) 112 | print("dense_features1 max:", dense_features1.max()) 113 | print("dense_features1 min:", dense_features1.min()) 114 | 115 | scores = self.detection(torch.cat([dense_features1, dense_features2], dim=0)) 116 | 117 | # dense_features = self.dense_feature_extraction( 118 | # torch.cat([batch['image1'], batch['image2']], dim=0) 119 | # ) 120 | # dense_features1 = dense_features[: b, :, :, :] 121 | # dense_features2 = dense_features[b :, :, :, :] 122 | # scores = self.detection(dense_features) 123 | 124 | scores1 
= scores[: b, :, :] 125 | scores2 = scores[b :, :, :] 126 | 127 | 128 | return { 129 | 'dense_features1': dense_features1, 130 | 'scores1': scores1, 131 | 'dense_features2': dense_features2, 132 | 'scores2': scores2 133 | } -------------------------------------------------------------------------------- /lib/full_model/model_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import ViTFeatureExtractor, ViTModel 5 | import torchvision.models as models 6 | class DenseFeatureExtractionModule(nn.Module): 7 | def __init__(self, finetune_feature_extraction=False, use_cuda=True): 8 | super(DenseFeatureExtractionModule, self).__init__() 9 | 10 | # TODO model:transformer 11 | feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k') 12 | model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k') 13 | 14 | self.model = model 15 | self.feature_extractor = feature_extractor 16 | ## VGG16 17 | # model = models.vgg16() 18 | # vgg16_layers = [ 19 | # 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 20 | # 'pool1', 21 | # 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 22 | # 'pool2', 23 | # 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 24 | # 'pool3', 25 | # 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 26 | # 'pool4', 27 | # 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 28 | # 'pool5' 29 | # ] 30 | # conv4_3_idx = vgg16_layers.index('conv4_3') 31 | # 32 | # self.model = nn.Sequential( 33 | # *list(model.features.children())[: conv4_3_idx + 1] 34 | # ) 35 | # self.num_channels = 512 36 | # 37 | # 38 | # # Fix forward parameters 39 | # for param in self.model.parameters(): 40 | # param.requires_grad = False 41 | # if finetune_feature_extraction: 42 | # # Unlock conv4_3 43 | # for param in list(self.model.parameters())[-2 :]: 44 | # param.requires_grad = True 45 | 46 | if use_cuda: 47 | self.model = self.model.cuda() 48 | 49 | def forward(self, batch): 50 | 51 | # VGG 52 | # output = self.model(batch) 53 | # transformer 54 | batch = self.feature_extractor(batch, return_tensors='pt') 55 | batch.to(device="cuda") 56 | output = self.model(**batch).last_hidden_state 57 | return output 58 | 59 | 60 | class SoftDetectionModule(nn.Module): 61 | def __init__(self, soft_local_max_size=3): 62 | super(SoftDetectionModule, self).__init__() 63 | 64 | self.soft_local_max_size = soft_local_max_size 65 | 66 | self.pad = self.soft_local_max_size // 2 67 | 68 | def forward(self, batch): 69 | b = batch.size(0) 70 | 71 | batch = F.relu(batch) # [2,512,28,28] 72 | 73 | max_per_sample = torch.max(batch.view(b, -1), dim=1)[0] # [1,2] 74 | exp = torch.exp(batch / max_per_sample.view(b, 1, 1, 1)) # [2,512,28,28] 75 | 76 | sum_exp = ( 77 | self.soft_local_max_size ** 2 * 78 | F.avg_pool2d( 79 | F.pad(exp, [self.pad] * 4, mode='constant', value=1.),# [2,512,30,30] 80 | self.soft_local_max_size, stride=1 81 | ) # [2,512,28,28] 82 | ) # [2, 512,28,28] 83 | local_max_score = exp / sum_exp # alpha 84 | 85 | depth_wise_max = torch.max(batch, dim=1)[0] # [2,28,28] 86 | 87 | depth_wise_max_score = batch / depth_wise_max.unsqueeze(1) # beta [2, 512, 28, 28] 88 | 89 | all_scores = local_max_score * depth_wise_max_score # [2, 512, 28, 28] 90 | score = torch.max(all_scores, dim=1)[0] # r [2,28,28] 91 | score = score / torch.sum(score.view(b, -1), dim=1).view(b, 1, 1) # s [2,28,28] 92 | 93 | return score 94 | 95 
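# --- Annotation added for clarity (not part of the original source) -----------------
# In this ViT variant the class token is dropped and the remaining 14 x 14 patch tokens
# are reshaped to (B, 14, 14, 768), i.e. channels-last. SoftDetectionModule above treats
# dim=1 as the channel axis, so scoring the 768-dimensional descriptors the same way as
# the VGG/Swin variants would first require the (currently commented-out) permute to
# (B, 768, 14, 14); as written, the forward() below feeds channels-last tensors into the
# detector.
# -------------------------------------------------------------------------------------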
| 96 | class D2Net(nn.Module): 97 | def __init__(self, model_file=None, use_cuda=True): 98 | super(D2Net, self).__init__() 99 | 100 | self.dense_feature_extraction = DenseFeatureExtractionModule( 101 | finetune_feature_extraction=True, 102 | use_cuda=use_cuda 103 | ) 104 | 105 | self.detection = SoftDetectionModule() 106 | self.conv_out = nn.Conv2d(in_channels=768, out_channels=512, kernel_size=3, stride=2, padding=1) 107 | 108 | 109 | # if model_file is not None: 110 | # if use_cuda: 111 | # self.load_state_dict(torch.load(model_file)['model']) 112 | # else: 113 | # self.load_state_dict(torch.load(model_file, map_location='cpu')['model']) 114 | 115 | def forward(self, batch): 116 | b = batch['image1'].size(0) 117 | 118 | # dense_features = self.dense_feature_extraction( 119 | # torch.cat([batch['image1'], batch['image2']], dim=0) 120 | # ) 121 | dense_features1 = self.dense_feature_extraction(batch['image1'])[:,1:, :] 122 | dense_features2 = self.dense_feature_extraction(batch['image2'])[:,1:, :] 123 | dense_features1 = dense_features1.view(1, 14, 14, 768) 124 | dense_features2 = dense_features2.view(1, 14, 14, 768) 125 | # dense_features1 = dense_features1.permute(0,3,1,2) 126 | # dense_features2 = dense_features2.permute(0,3,1,2) 127 | print("dense_features1 max 1", dense_features1.max) 128 | print("dense_features1 min 1", dense_features1.min) 129 | 130 | # dense_features1 = self.conv_out(dense_features1) 131 | # print("dense_features1 shape conv", dense_features1.shape) 132 | # dense_features2 = self.conv_out(dense_features2) 133 | 134 | 135 | scores = self.detection(torch.cat([dense_features1, dense_features2], dim=0)) 136 | print("scores shape:", scores.shape) 137 | # scores1 = self.detection(dense_features1) 138 | # scores2 = self.detection(dense_features2) 139 | 140 | 141 | # dense_features1 = dense_features[: b, :, :, :] 142 | # print("dense features:", dense_features1) 143 | print("feature max:", torch.max(dense_features1)) 144 | print("feature min:", torch.min(dense_features1)) 145 | 146 | # print("dense feaures 1 shape:", dense_features1.shape) [1,512,28,28] 147 | 148 | # dense_features2 = dense_features[b :, :, :, :] 149 | scores1 = scores[: b, :, :] 150 | scores2 = scores[b :, :, :] 151 | 152 | return { 153 | 'dense_features1': dense_features1, 154 | 'scores1': scores1, 155 | 'dense_features2': dense_features2, 156 | 'scores2': scores2 157 | } 158 | -------------------------------------------------------------------------------- /lib/model_swin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .swin_unet import SwinUnet 5 | 6 | class DenseFeatureExtractionModule(nn.Module): 7 | def __init__(self, use_relu=True, use_cuda=True): 8 | super(DenseFeatureExtractionModule, self).__init__() 9 | self.model = SwinUnet().cuda() 10 | 11 | if use_cuda: 12 | self.model = self.model.cuda() 13 | 14 | self.num_channels = 512 15 | 16 | self.use_relu = use_relu 17 | 18 | def forward(self, batch): 19 | # img_size = batch.size()[-1] 20 | # print("img size:", img_size) 21 | output = self.model(batch) 22 | if self.use_relu: 23 | output = F.relu(output) 24 | return output 25 | 26 | 27 | class SwinU2Net(nn.Module): 28 | def __init__(self, model_file=None, use_relu=True, use_cuda=True): 29 | super(SwinU2Net, self).__init__() 30 | 31 | self.dense_feature_extraction = DenseFeatureExtractionModule( 32 | use_relu=use_relu, use_cuda=use_cuda 33 | ) 34 | 35 | self.detection = 
HardDetectionModule() 36 | 37 | self.localization = HandcraftedLocalizationModule() 38 | 39 | if model_file is not None: 40 | # self.load_state_dict(torch.load(model_file)['model']) 41 | self.load_state_dict(torch.load(model_file)['model']) 42 | def forward(self, batch): 43 | _, _, h, w = batch.size() 44 | dense_features = self.dense_feature_extraction(batch, h) 45 | 46 | detections = self.detection(dense_features) 47 | 48 | displacements = self.localization(dense_features) 49 | 50 | return { 51 | 'dense_features': dense_features, 52 | 'detections': detections, 53 | 'displacements': displacements 54 | } 55 | 56 | 57 | class HardDetectionModule(nn.Module): 58 | def __init__(self, edge_threshold=5): 59 | super(HardDetectionModule, self).__init__() 60 | 61 | self.edge_threshold = edge_threshold 62 | 63 | self.dii_filter = torch.tensor( 64 | [[0, 1., 0], [0, -2., 0], [0, 1., 0]] 65 | ).view(1, 1, 3, 3) 66 | self.dij_filter = 0.25 * torch.tensor( 67 | [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]] 68 | ).view(1, 1, 3, 3) 69 | self.djj_filter = torch.tensor( 70 | [[0, 0, 0], [1., -2., 1.], [0, 0, 0]] 71 | ).view(1, 1, 3, 3) 72 | 73 | def forward(self, batch): 74 | b, c, h, w = batch.size() 75 | device = batch.device 76 | 77 | depth_wise_max = torch.max(batch, dim=1)[0] 78 | is_depth_wise_max = (batch == depth_wise_max) 79 | del depth_wise_max 80 | 81 | local_max = F.max_pool2d(batch, 3, stride=1, padding=1) 82 | is_local_max = (batch == local_max) 83 | del local_max 84 | 85 | dii = F.conv2d( 86 | batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1 87 | ).view(b, c, h, w) 88 | dij = F.conv2d( 89 | batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1 90 | ).view(b, c, h, w) 91 | djj = F.conv2d( 92 | batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1 93 | ).view(b, c, h, w) 94 | 95 | det = dii * djj - dij * dij 96 | tr = dii + djj 97 | del dii, dij, djj 98 | 99 | threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold 100 | is_not_edge = torch.min(tr * tr / det <= threshold, det > 0) 101 | 102 | detected = torch.min( 103 | is_depth_wise_max, 104 | torch.min(is_local_max, is_not_edge) 105 | ) 106 | del is_depth_wise_max, is_local_max, is_not_edge 107 | 108 | return detected 109 | 110 | 111 | class HandcraftedLocalizationModule(nn.Module): 112 | def __init__(self): 113 | super(HandcraftedLocalizationModule, self).__init__() 114 | 115 | self.di_filter = torch.tensor( 116 | [[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]] 117 | ).view(1, 1, 3, 3) 118 | self.dj_filter = torch.tensor( 119 | [[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]] 120 | ).view(1, 1, 3, 3) 121 | 122 | self.dii_filter = torch.tensor( 123 | [[0, 1., 0], [0, -2., 0], [0, 1., 0]] 124 | ).view(1, 1, 3, 3) 125 | self.dij_filter = 0.25 * torch.tensor( 126 | [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]] 127 | ).view(1, 1, 3, 3) 128 | self.djj_filter = torch.tensor( 129 | [[0, 0, 0], [1., -2., 1.], [0, 0, 0]] 130 | ).view(1, 1, 3, 3) 131 | 132 | def forward(self, batch): 133 | b, c, h, w = batch.size() 134 | device = batch.device 135 | 136 | dii = F.conv2d( 137 | batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1 138 | ).view(b, c, h, w) 139 | dij = F.conv2d( 140 | batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1 141 | ).view(b, c, h, w) 142 | djj = F.conv2d( 143 | batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1 144 | ).view(b, c, h, w) 145 | det = dii * djj - dij * dij 146 | 147 | inv_hess_00 = djj / det 148 | inv_hess_01 = -dij / det 149 | inv_hess_11 = dii / det 150 | del dii, dij, 
djj, det 151 | 152 | di = F.conv2d( 153 | batch.view(-1, 1, h, w), self.di_filter.to(device), padding=1 154 | ).view(b, c, h, w) 155 | dj = F.conv2d( 156 | batch.view(-1, 1, h, w), self.dj_filter.to(device), padding=1 157 | ).view(b, c, h, w) 158 | 159 | step_i = -(inv_hess_00 * di + inv_hess_01 * dj) 160 | step_j = -(inv_hess_01 * di + inv_hess_11 * dj) 161 | del inv_hess_00, inv_hess_01, inv_hess_11, di, dj 162 | 163 | return torch.stack([step_i, step_j], dim=1) 164 | -------------------------------------------------------------------------------- /lib/model_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DenseFeatureExtractionModule(nn.Module): 7 | def __init__(self, use_relu=True, use_cuda=True): 8 | super(DenseFeatureExtractionModule, self).__init__() 9 | 10 | self.model = nn.Sequential( 11 | nn.Conv2d(3, 64, 3, padding=1), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(64, 64, 3, padding=1), 14 | nn.ReLU(inplace=True), 15 | nn.MaxPool2d(2, stride=2), 16 | nn.Conv2d(64, 128, 3, padding=1), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(128, 128, 3, padding=1), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(2, stride=2), 21 | nn.Conv2d(128, 256, 3, padding=1), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(256, 256, 3, padding=1), 24 | nn.ReLU(inplace=True), 25 | nn.Conv2d(256, 256, 3, padding=1), 26 | nn.ReLU(inplace=True), 27 | nn.AvgPool2d(2, stride=1), 28 | nn.Conv2d(256, 512, 3, padding=2, dilation=2), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(512, 512, 3, padding=2, dilation=2), 31 | nn.ReLU(inplace=True), 32 | nn.Conv2d(512, 512, 3, padding=2, dilation=2), 33 | ) 34 | self.num_channels = 512 35 | 36 | self.use_relu = use_relu 37 | 38 | if use_cuda: 39 | self.model = self.model.cuda() 40 | 41 | def forward(self, batch): 42 | output = self.model(batch) 43 | if self.use_relu: 44 | output = F.relu(output) 45 | return output 46 | 47 | 48 | class D2Net(nn.Module): 49 | def __init__(self, model_file=None, use_relu=True, use_cuda=True): 50 | super(D2Net, self).__init__() 51 | 52 | self.dense_feature_extraction = DenseFeatureExtractionModule( 53 | use_relu=use_relu, use_cuda=use_cuda 54 | ) 55 | 56 | self.detection = HardDetectionModule() 57 | 58 | self.localization = HandcraftedLocalizationModule() 59 | 60 | if model_file is not None: 61 | if use_cuda: 62 | self.load_state_dict(torch.load(model_file)['model']) 63 | else: 64 | self.load_state_dict(torch.load(model_file, map_location='cpu')['model']) 65 | 66 | def forward(self, batch): 67 | _, _, h, w = batch.size() 68 | dense_features = self.dense_feature_extraction(batch) 69 | 70 | detections = self.detection(dense_features) 71 | 72 | displacements = self.localization(dense_features) 73 | 74 | return { 75 | 'dense_features': dense_features, 76 | 'detections': detections, 77 | 'displacements': displacements 78 | } 79 | 80 | 81 | class HardDetectionModule(nn.Module): 82 | def __init__(self, edge_threshold=5): 83 | super(HardDetectionModule, self).__init__() 84 | 85 | self.edge_threshold = edge_threshold 86 | 87 | self.dii_filter = torch.tensor( 88 | [[0, 1., 0], [0, -2., 0], [0, 1., 0]] 89 | ).view(1, 1, 3, 3) 90 | self.dij_filter = 0.25 * torch.tensor( 91 | [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]] 92 | ).view(1, 1, 3, 3) 93 | self.djj_filter = torch.tensor( 94 | [[0, 0, 0], [1., -2., 1.], [0, 0, 0]] 95 | ).view(1, 1, 3, 3) 96 | 97 | def forward(self, batch): 98 | b, c, h, w = batch.size() 99 | device = 
batch.device 100 | 101 | depth_wise_max = torch.max(batch, dim=1)[0] 102 | is_depth_wise_max = (batch == depth_wise_max) 103 | del depth_wise_max 104 | 105 | local_max = F.max_pool2d(batch, 3, stride=1, padding=1) 106 | is_local_max = (batch == local_max) 107 | del local_max 108 | 109 | dii = F.conv2d( 110 | batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1 111 | ).view(b, c, h, w) 112 | dij = F.conv2d( 113 | batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1 114 | ).view(b, c, h, w) 115 | djj = F.conv2d( 116 | batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1 117 | ).view(b, c, h, w) 118 | 119 | det = dii * djj - dij * dij 120 | tr = dii + djj 121 | del dii, dij, djj 122 | 123 | threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold 124 | is_not_edge = torch.min(tr * tr / det <= threshold, det > 0) 125 | 126 | detected = torch.min( 127 | is_depth_wise_max, 128 | torch.min(is_local_max, is_not_edge) 129 | ) 130 | del is_depth_wise_max, is_local_max, is_not_edge 131 | 132 | return detected 133 | 134 | 135 | class HandcraftedLocalizationModule(nn.Module): 136 | def __init__(self): 137 | super(HandcraftedLocalizationModule, self).__init__() 138 | 139 | self.di_filter = torch.tensor( 140 | [[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]] 141 | ).view(1, 1, 3, 3) 142 | self.dj_filter = torch.tensor( 143 | [[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]] 144 | ).view(1, 1, 3, 3) 145 | 146 | self.dii_filter = torch.tensor( 147 | [[0, 1., 0], [0, -2., 0], [0, 1., 0]] 148 | ).view(1, 1, 3, 3) 149 | self.dij_filter = 0.25 * torch.tensor( 150 | [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]] 151 | ).view(1, 1, 3, 3) 152 | self.djj_filter = torch.tensor( 153 | [[0, 0, 0], [1., -2., 1.], [0, 0, 0]] 154 | ).view(1, 1, 3, 3) 155 | 156 | def forward(self, batch): 157 | b, c, h, w = batch.size() 158 | device = batch.device 159 | 160 | dii = F.conv2d( 161 | batch.view(-1, 1, h, w), self.dii_filter.to(device), padding=1 162 | ).view(b, c, h, w) 163 | dij = F.conv2d( 164 | batch.view(-1, 1, h, w), self.dij_filter.to(device), padding=1 165 | ).view(b, c, h, w) 166 | djj = F.conv2d( 167 | batch.view(-1, 1, h, w), self.djj_filter.to(device), padding=1 168 | ).view(b, c, h, w) 169 | det = dii * djj - dij * dij 170 | 171 | inv_hess_00 = djj / det 172 | inv_hess_01 = -dij / det 173 | inv_hess_11 = dii / det 174 | del dii, dij, djj, det 175 | 176 | di = F.conv2d( 177 | batch.view(-1, 1, h, w), self.di_filter.to(device), padding=1 178 | ).view(b, c, h, w) 179 | dj = F.conv2d( 180 | batch.view(-1, 1, h, w), self.dj_filter.to(device), padding=1 181 | ).view(b, c, h, w) 182 | 183 | step_i = -(inv_hess_00 * di + inv_hess_01 * dj) 184 | step_j = -(inv_hess_01 * di + inv_hess_11 * dj) 185 | del inv_hess_00, inv_hess_01, inv_hess_11, di, dj 186 | 187 | return torch.stack([step_i, step_j], dim=1) 188 | -------------------------------------------------------------------------------- /lib/swin_transformer_unet_skip_expand_decoder_sys.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | from einops import rearrange 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | 7 | 8 | class Mlp(nn.Module): 9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 10 | super().__init__() 11 | out_features = out_features or in_features 12 | hidden_features = hidden_features or in_features 13 | self.fc1 = nn.Linear(in_features, 
hidden_features) 14 | self.act = act_layer() 15 | self.fc2 = nn.Linear(hidden_features, out_features) 16 | self.drop = nn.Dropout(drop) 17 | 18 | def forward(self, x): 19 | x = self.fc1(x) 20 | x = self.act(x) 21 | x = self.drop(x) 22 | x = self.fc2(x) 23 | x = self.drop(x) 24 | return x 25 | 26 | 27 | def window_partition(x, window_size): 28 | """ 29 | Args: 30 | x: (B, H, W, C) 31 | window_size (int): window size 32 | 33 | Returns: 34 | windows: (num_windows*B, window_size, window_size, C) 35 | """ 36 | B, H, W, C = x.shape 37 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 38 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 39 | return windows 40 | 41 | 42 | def window_reverse(windows, window_size, H, W): 43 | """ 44 | Args: 45 | windows: (num_windows*B, window_size, window_size, C) 46 | window_size (int): Window size 47 | H (int): Height of image 48 | W (int): Width of image 49 | 50 | Returns: 51 | x: (B, H, W, C) 52 | """ 53 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 54 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 55 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 56 | return x 57 | 58 | 59 | class WindowAttention(nn.Module): 60 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 61 | It supports both of shifted and non-shifted window. 62 | 63 | Args: 64 | dim (int): Number of input channels. 65 | window_size (tuple[int]): The height and width of the window. 66 | num_heads (int): Number of attention heads. 67 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 68 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 69 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 70 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 71 | """ 72 | 73 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 74 | 75 | super().__init__() 76 | self.dim = dim 77 | self.window_size = window_size # Wh, Ww 78 | self.num_heads = num_heads 79 | head_dim = dim // num_heads 80 | self.scale = qk_scale or head_dim ** -0.5 81 | 82 | # define a parameter table of relative position bias 83 | self.relative_position_bias_table = nn.Parameter( 84 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 85 | 86 | # get pair-wise relative position index for each token inside the window 87 | coords_h = torch.arange(self.window_size[0]) 88 | coords_w = torch.arange(self.window_size[1]) 89 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 90 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 91 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 92 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 93 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 94 | relative_coords[:, :, 1] += self.window_size[1] - 1 95 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 96 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 97 | self.register_buffer("relative_position_index", relative_position_index) 98 | 99 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 100 | self.attn_drop = nn.Dropout(attn_drop) 101 | self.proj = nn.Linear(dim, dim) 102 | self.proj_drop = nn.Dropout(proj_drop) 103 | 104 | trunc_normal_(self.relative_position_bias_table, std=.02) 105 | self.softmax = nn.Softmax(dim=-1) 106 | 107 | def forward(self, x, mask=None): 108 | """ 109 | Args: 110 | x: input features with shape of (num_windows*B, N, C) 111 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 112 | """ 113 | B_, N, C = x.shape 114 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 116 | 117 | q = q * self.scale 118 | attn = (q @ k.transpose(-2, -1)) 119 | 120 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 121 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 122 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 123 | attn = attn + relative_position_bias.unsqueeze(0) 124 | 125 | if mask is not None: 126 | nW = mask.shape[0] 127 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 128 | attn = attn.view(-1, self.num_heads, N, N) 129 | attn = self.softmax(attn) 130 | else: 131 | attn = self.softmax(attn) 132 | 133 | attn = self.attn_drop(attn) 134 | 135 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 136 | x = self.proj(x) 137 | x = self.proj_drop(x) 138 | return x 139 | 140 | def extra_repr(self) -> str: 141 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 142 | 143 | def flops(self, N): 144 | # calculate flops for 1 window with token length of N 145 | flops = 0 146 | # qkv = self.qkv(x) 147 | flops += N * self.dim * 3 * self.dim 148 | # attn = (q @ k.transpose(-2, -1)) 149 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 150 | # x = (attn @ v) 151 | flops += self.num_heads * N * N * (self.dim // 
self.num_heads) 152 | # x = self.proj(x) 153 | flops += N * self.dim * self.dim 154 | return flops 155 | 156 | 157 | class SwinTransformerBlock(nn.Module): 158 | r""" Swin Transformer Block. 159 | 160 | Args: 161 | dim (int): Number of input channels. 162 | input_resolution (tuple[int]): Input resulotion. 163 | num_heads (int): Number of attention heads. 164 | window_size (int): Window size. 165 | shift_size (int): Shift size for SW-MSA. 166 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 167 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 168 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 169 | drop (float, optional): Dropout rate. Default: 0.0 170 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 171 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 172 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 173 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 174 | """ 175 | 176 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 177 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 178 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 179 | super().__init__() 180 | self.dim = dim 181 | self.input_resolution = input_resolution 182 | self.num_heads = num_heads 183 | self.window_size = window_size 184 | self.shift_size = shift_size 185 | self.mlp_ratio = mlp_ratio 186 | if min(self.input_resolution) <= self.window_size: 187 | # if window size is larger than input resolution, we don't partition windows 188 | self.shift_size = 0 189 | self.window_size = min(self.input_resolution) 190 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 191 | 192 | self.norm1 = norm_layer(dim) 193 | self.attn = WindowAttention( 194 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 195 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 196 | 197 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 198 | self.norm2 = norm_layer(dim) 199 | mlp_hidden_dim = int(dim * mlp_ratio) 200 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 201 | 202 | if self.shift_size > 0: 203 | # calculate attention mask for SW-MSA 204 | H, W = self.input_resolution 205 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 206 | h_slices = (slice(0, -self.window_size), 207 | slice(-self.window_size, -self.shift_size), 208 | slice(-self.shift_size, None)) 209 | w_slices = (slice(0, -self.window_size), 210 | slice(-self.window_size, -self.shift_size), 211 | slice(-self.shift_size, None)) 212 | cnt = 0 213 | for h in h_slices: 214 | for w in w_slices: 215 | img_mask[:, h, w, :] = cnt 216 | cnt += 1 217 | 218 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 219 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 220 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 221 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 222 | else: 223 | attn_mask = None 224 | 225 | self.register_buffer("attn_mask", attn_mask) 226 | 227 | def forward(self, x): 228 | H, W = self.input_resolution 229 | B, L, C = x.shape 230 | assert L == H * W, "input feature has wrong size" 231 | 232 | shortcut = x 233 | x = self.norm1(x) 234 | x = x.view(B, H, W, C) 235 | 236 | # cyclic shift 237 | if self.shift_size > 0: 238 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 239 | else: 240 | shifted_x = x 241 | 242 | # partition windows 243 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 244 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 245 | 246 | # W-MSA/SW-MSA 247 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 248 | 249 | # merge windows 250 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 251 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 252 | 253 | # reverse cyclic shift 254 | if self.shift_size > 0: 255 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 256 | else: 257 | x = shifted_x 258 | x = x.view(B, H * W, C) 259 | 260 | # FFN 261 | x = shortcut + self.drop_path(x) 262 | x = x + self.drop_path(self.mlp(self.norm2(x))) 263 | 264 | return x 265 | 266 | def extra_repr(self) -> str: 267 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 268 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 269 | 270 | def flops(self): 271 | flops = 0 272 | H, W = self.input_resolution 273 | # norm1 274 | flops += self.dim * H * W 275 | # W-MSA/SW-MSA 276 | nW = H * W / self.window_size / self.window_size 277 | flops += nW * self.attn.flops(self.window_size * self.window_size) 278 | # mlp 279 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 280 | # norm2 281 | flops += self.dim * H * W 282 | return flops 283 | 284 | 285 | class PatchMerging(nn.Module): 286 | r""" Patch Merging Layer. 287 | 288 | Args: 289 | input_resolution (tuple[int]): Resolution of input feature. 290 | dim (int): Number of input channels. 291 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 292 | """ 293 | 294 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 295 | super().__init__() 296 | self.input_resolution = input_resolution 297 | self.dim = dim 298 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 299 | self.norm = norm_layer(4 * dim) 300 | 301 | def forward(self, x): 302 | """ 303 | x: B, H*W, C 304 | """ 305 | H, W = self.input_resolution 306 | B, L, C = x.shape 307 | assert L == H * W, "input feature has wrong size" 308 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 309 | 310 | x = x.view(B, H, W, C) 311 | 312 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 313 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 314 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 315 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 316 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 317 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 318 | 319 | x = self.norm(x) 320 | x = self.reduction(x) 321 | 322 | return x 323 | 324 | def extra_repr(self) -> str: 325 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 326 | 327 | def flops(self): 328 | H, W = self.input_resolution 329 | flops = H * W * self.dim 330 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 331 | return flops 332 | 333 | class PatchExpand(nn.Module): 334 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 335 | super().__init__() 336 | self.input_resolution = input_resolution 337 | self.dim = dim 338 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() 339 | self.norm = norm_layer(dim // dim_scale) 340 | 341 | def forward(self, x): 342 | """ 343 | x: B, H*W, C 344 | """ 345 | H, W = self.input_resolution 346 | x = self.expand(x) 347 | B, L, C = x.shape 348 | assert L == H * W, "input feature has wrong size" 349 | 350 | x = x.view(B, H, W, C) 351 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) 352 | x = x.view(B,-1,C//4) 353 | x= self.norm(x) 354 | 355 | return x 356 | 357 | class FinalPatchExpand_X4(nn.Module): 358 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 359 | super().__init__() 360 | self.input_resolution = input_resolution 361 | self.dim = dim 362 | self.dim_scale = dim_scale 363 | self.expand = nn.Linear(dim, 16*dim, bias=False) 364 | self.output_dim = dim 365 | self.norm = norm_layer(self.output_dim) 366 | 367 | def forward(self, x): 368 | """ 369 | x: B, H*W, C 370 | """ 371 | H, W = self.input_resolution 372 | x = self.expand(x) 373 | B, L, C = x.shape 374 | assert L == H * W, "input feature has wrong size" 375 | 376 | x = x.view(B, H, W, C) 377 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) 378 | x = x.view(B,-1,self.output_dim) 379 | x= self.norm(x) 380 | 381 | return x 382 | 383 | class BasicLayer(nn.Module): 384 | """ A basic Swin Transformer layer for one stage. 385 | 386 | Args: 387 | dim (int): Number of input channels. 388 | input_resolution (tuple[int]): Input resolution. 389 | depth (int): Number of blocks. 390 | num_heads (int): Number of attention heads. 391 | window_size (int): Local window size. 392 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 393 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 394 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 395 | drop (float, optional): Dropout rate. 
Default: 0.0 396 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 397 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 398 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 399 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 400 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 401 | """ 402 | 403 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 404 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 405 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 406 | 407 | super().__init__() 408 | self.dim = dim 409 | self.input_resolution = input_resolution 410 | self.depth = depth 411 | self.use_checkpoint = use_checkpoint 412 | 413 | # build blocks 414 | self.blocks = nn.ModuleList([ 415 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 416 | num_heads=num_heads, window_size=window_size, 417 | shift_size=0 if (i % 2 == 0) else window_size // 2, 418 | mlp_ratio=mlp_ratio, 419 | qkv_bias=qkv_bias, qk_scale=qk_scale, 420 | drop=drop, attn_drop=attn_drop, 421 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 422 | norm_layer=norm_layer) 423 | for i in range(depth)]) 424 | 425 | # patch merging layer 426 | if downsample is not None: 427 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 428 | else: 429 | self.downsample = None 430 | 431 | def forward(self, x): 432 | for blk in self.blocks: 433 | if self.use_checkpoint: 434 | x = checkpoint.checkpoint(blk, x) 435 | else: 436 | x = blk(x) 437 | if self.downsample is not None: 438 | x = self.downsample(x) 439 | return x 440 | 441 | def extra_repr(self) -> str: 442 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 443 | 444 | def flops(self): 445 | flops = 0 446 | for blk in self.blocks: 447 | flops += blk.flops() 448 | if self.downsample is not None: 449 | flops += self.downsample.flops() 450 | return flops 451 | 452 | class BasicLayer_up(nn.Module): 453 | """ A basic Swin Transformer layer for one stage. 454 | 455 | Args: 456 | dim (int): Number of input channels. 457 | input_resolution (tuple[int]): Input resolution. 458 | depth (int): Number of blocks. 459 | num_heads (int): Number of attention heads. 460 | window_size (int): Local window size. 461 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 462 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 463 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 464 | drop (float, optional): Dropout rate. Default: 0.0 465 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 466 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 467 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 468 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 469 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
470 | """ 471 | 472 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 473 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 474 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 475 | 476 | super().__init__() 477 | self.dim = dim 478 | self.input_resolution = input_resolution 479 | self.depth = depth 480 | self.use_checkpoint = use_checkpoint 481 | 482 | # build blocks 483 | self.blocks = nn.ModuleList([ 484 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 485 | num_heads=num_heads, window_size=window_size, 486 | shift_size=0 if (i % 2 == 0) else window_size // 2, 487 | mlp_ratio=mlp_ratio, 488 | qkv_bias=qkv_bias, qk_scale=qk_scale, 489 | drop=drop, attn_drop=attn_drop, 490 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 491 | norm_layer=norm_layer) 492 | for i in range(depth)]) 493 | 494 | # patch merging layer 495 | if upsample is not None: 496 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) 497 | else: 498 | self.upsample = None 499 | 500 | def forward(self, x): 501 | for blk in self.blocks: 502 | if self.use_checkpoint: 503 | x = checkpoint.checkpoint(blk, x) 504 | else: 505 | x = blk(x) 506 | if self.upsample is not None: 507 | x = self.upsample(x) 508 | return x 509 | 510 | class PatchEmbed(nn.Module): 511 | r""" Image to Patch Embedding 512 | 513 | Args: 514 | img_size (int): Image size. Default: 224. 515 | patch_size (int): Patch token size. Default: 4. 516 | in_chans (int): Number of input image channels. Default: 3. 517 | embed_dim (int): Number of linear projection output channels. Default: 96. 518 | norm_layer (nn.Module, optional): Normalization layer. Default: None 519 | """ 520 | 521 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 522 | super().__init__() 523 | img_size = to_2tuple(img_size) 524 | patch_size = to_2tuple(patch_size) 525 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 526 | self.img_size = img_size 527 | self.patch_size = patch_size 528 | self.patches_resolution = patches_resolution 529 | self.num_patches = patches_resolution[0] * patches_resolution[1] 530 | 531 | self.in_chans = in_chans 532 | self.embed_dim = embed_dim 533 | 534 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 535 | if norm_layer is not None: 536 | self.norm = norm_layer(embed_dim) 537 | else: 538 | self.norm = None 539 | 540 | def forward(self, x): 541 | B, C, H, W = x.shape 542 | # FIXME look at relaxing size constraints 543 | assert H == self.img_size[0] and W == self.img_size[1], \ 544 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 545 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 546 | if self.norm is not None: 547 | x = self.norm(x) 548 | return x 549 | 550 | def flops(self): 551 | Ho, Wo = self.patches_resolution 552 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 553 | if self.norm is not None: 554 | flops += Ho * Wo * self.embed_dim 555 | return flops 556 | 557 | 558 | class SwinTransformerSys(nn.Module): 559 | r""" Swin Transformer 560 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 561 | https://arxiv.org/pdf/2103.14030 562 | 563 | Args: 564 | img_size (int | tuple(int)): Input image size. Default 224 565 | patch_size (int | tuple(int)): Patch size. 
Default: 4 566 | in_chans (int): Number of input image channels. Default: 3 567 | num_classes (int): Number of classes for classification head. Default: 1000 568 | embed_dim (int): Patch embedding dimension. Default: 96 569 | depths (tuple(int)): Depth of each Swin Transformer layer. 570 | num_heads (tuple(int)): Number of attention heads in different layers. 571 | window_size (int): Window size. Default: 7 572 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 573 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 574 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 575 | drop_rate (float): Dropout rate. Default: 0 576 | attn_drop_rate (float): Attention dropout rate. Default: 0 577 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 578 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 579 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 580 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 581 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 582 | """ 583 | 584 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 585 | embed_dim=512, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], 586 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 587 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 588 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 589 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 590 | super().__init__() 591 | 592 | # print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, 593 | # depths_decoder,drop_path_rate,num_classes)) 594 | # print("img --- size:", img_size) 595 | 596 | self.num_classes = num_classes 597 | self.num_layers = len(depths) 598 | self.embed_dim = embed_dim 599 | self.ape = ape 600 | self.patch_norm = patch_norm 601 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 602 | self.num_features_up = int(embed_dim * 2) 603 | self.mlp_ratio = mlp_ratio 604 | self.final_upsample = final_upsample 605 | 606 | 607 | # split image into non-overlapping patches 608 | self.patch_embed = PatchEmbed( 609 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 610 | norm_layer=norm_layer if self.patch_norm else None) 611 | num_patches = self.patch_embed.num_patches 612 | patches_resolution = self.patch_embed.patches_resolution 613 | self.patches_resolution = patches_resolution 614 | 615 | # absolute position embedding 616 | if self.ape: 617 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 618 | trunc_normal_(self.absolute_pos_embed, std=.02) 619 | 620 | self.pos_drop = nn.Dropout(p=drop_rate) 621 | 622 | # stochastic depth 623 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 624 | 625 | # build encoder and bottleneck layers 626 | self.layers = nn.ModuleList() 627 | for i_layer in range(self.num_layers): 628 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 629 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 630 | patches_resolution[1] // (2 ** i_layer)), 631 | depth=depths[i_layer], 632 | num_heads=num_heads[i_layer], 633 | window_size=window_size, 634 | mlp_ratio=self.mlp_ratio, 635 | qkv_bias=qkv_bias, qk_scale=qk_scale, 636 | 
drop=drop_rate, attn_drop=attn_drop_rate, 637 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 638 | norm_layer=norm_layer, 639 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 640 | use_checkpoint=use_checkpoint) 641 | self.layers.append(layer) 642 | 643 | # build decoder layers 644 | self.layers_up = nn.ModuleList() 645 | self.concat_back_dim = nn.ModuleList() 646 | for i_layer in range(self.num_layers): 647 | concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), 648 | int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() 649 | if i_layer ==0 : 650 | layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 651 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) 652 | else: 653 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), 654 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 655 | patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), 656 | depth=depths[(self.num_layers-1-i_layer)], 657 | num_heads=num_heads[(self.num_layers-1-i_layer)], 658 | window_size=window_size, 659 | mlp_ratio=self.mlp_ratio, 660 | qkv_bias=qkv_bias, qk_scale=qk_scale, 661 | drop=drop_rate, attn_drop=attn_drop_rate, 662 | drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], 663 | norm_layer=norm_layer, 664 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, 665 | use_checkpoint=use_checkpoint) 666 | self.layers_up.append(layer_up) 667 | self.concat_back_dim.append(concat_linear) 668 | 669 | self.norm = norm_layer(self.num_features) 670 | self.norm_up= norm_layer(self.embed_dim) 671 | # self.norm_up= norm_layer(192) 672 | 673 | if self.final_upsample == "expand_first": 674 | # print("---final upsample expand_first---") # FIXME 675 | self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim) 676 | # self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False) 677 | self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=3, stride=2, padding=1, bias=False) 678 | 679 | self.output_2 = nn.Conv2d(in_channels=self.num_classes,out_channels=self.num_classes,kernel_size=3, stride=2, padding=1, bias=False) 680 | # self.output_3 = nn.Conv2d(in_channels=self.num_classes,out_channels=self.num_classes,kernel_size=3, dilation=2, padding=2, bias=False) 681 | 682 | self.apply(self._init_weights) 683 | 684 | def _init_weights(self, m): 685 | if isinstance(m, nn.Linear): 686 | trunc_normal_(m.weight, std=.02) 687 | if isinstance(m, nn.Linear) and m.bias is not None: 688 | nn.init.constant_(m.bias, 0) 689 | elif isinstance(m, nn.LayerNorm): 690 | nn.init.constant_(m.bias, 0) 691 | nn.init.constant_(m.weight, 1.0) 692 | 693 | @torch.jit.ignore 694 | def no_weight_decay(self): 695 | return {'absolute_pos_embed'} 696 | 697 | @torch.jit.ignore 698 | def no_weight_decay_keywords(self): 699 | return {'relative_position_bias_table'} 700 | 701 | #Encoder and Bottleneck 702 | def forward_features(self, x): 703 | x = self.patch_embed(x) 704 | if self.ape: 705 | x = x + self.absolute_pos_embed 706 | x = self.pos_drop(x) 707 | x_downsample = [] 708 | 709 | for layer in self.layers: 710 | x_downsample.append(x) 711 | x = layer(x) 712 | 713 | x = 
self.norm(x) # B L C [1,49,768] 714 | 715 | return x, x_downsample 716 | 717 | #Dencoder and Skip connection 718 | ## FIXME 修改网络结构 保证预训练的参数还可以用 跳过一些层实时 719 | def forward_up_features(self, x, x_downsample): 720 | for inx, layer_up in enumerate(self.layers_up): 721 | # print("swin decoder",self.layers_up) 722 | if inx == 0: 723 | x = layer_up(x) 724 | else: 725 | x = torch.cat([x,x_downsample[3-inx]],-1) 726 | x = self.concat_back_dim[inx](x) 727 | x = layer_up(x) 728 | 729 | x = self.norm_up(x) # B L C [1,3136,96] 730 | 731 | return x 732 | 733 | def up_x4(self, x): 734 | H, W = self.patches_resolution 735 | 736 | B, L, C = x.shape 737 | assert L == H*W, "input features has wrong size" 738 | 739 | if self.final_upsample=="expand_first": 740 | x = self.up(x) 741 | x = x.view(B,4*H,4*W,-1) 742 | # x = x.view(B,H,W,-1) 743 | x = x.permute(0,3,1,2) #B,C,H,W 744 | x = self.output(x) # 1/2 745 | x = self.output_2(x) # 1/4 746 | # x = self.output_2(x) # 1/8 # FIXME 最后一层空洞卷积保持尺寸还是直接去掉 747 | # x = self.output_3(x) # 1/4 # FIXME 这里注意训练的时候保持一致 748 | 749 | 750 | return x 751 | 752 | def forward(self, x): 753 | x, x_downsample = self.forward_features(x) 754 | x = self.forward_up_features(x,x_downsample) 755 | x = self.up_x4(x) 756 | # print("out shape", x.shape) 757 | # x = x.view(1,28,28,-1) 758 | # x = x.permute(0,3,2,1) 759 | # x = self.output(x) 760 | # print("output:", x.shape) 761 | # x = self.up_x4(x) 762 | 763 | return x 764 | 765 | def flops(self): 766 | flops = 0 767 | flops += self.patch_embed.flops() 768 | for i, layer in enumerate(self.layers): 769 | flops += layer.flops() 770 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 771 | flops += self.num_features * self.num_classes 772 | return flops 773 | -------------------------------------------------------------------------------- /lib/swin_unet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | 10 | from os.path import join as pjoin 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | 16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 17 | from torch.nn.modules.utils import _pair 18 | from scipy import ndimage 19 | from .swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class SwinUnet(nn.Module): 25 | def __init__(self, img_size=224, num_classes=21843, zero_head=False, vis=False): 26 | super(SwinUnet, self).__init__() 27 | self.num_classes = num_classes 28 | self.zero_head = zero_head 29 | 30 | self.swin_unet = SwinTransformerSys(img_size=224, 31 | patch_size=4, 32 | in_chans=3, 33 | num_classes=512, 34 | embed_dim=96, 35 | depths=[2, 2, 2, 2], 36 | num_heads=[3, 6, 12, 24], 37 | window_size=7, 38 | mlp_ratio=4, 39 | qkv_bias=True, 40 | qk_scale=None, 41 | drop_rate=0, 42 | drop_path_rate=0.1, 43 | ape=True, 44 | patch_norm=True, 45 | use_checkpoint=True) 46 | 47 | def forward(self, x): 48 | if x.size()[1] == 1: 49 | x = x.repeat(1, 3, 1, 1) 50 | logits = self.swin_unet(x) 51 | return logits 52 | 53 | # def load_from(self, pretrained_path): 54 | # pretrained_path = pretrained_path 55 | # if pretrained_path is not None: 56 | # print("pretrained_path:{}".format(pretrained_path)) 57 | # device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | # pretrained_dict = torch.load(pretrained_path, map_location=device) 59 | # if "model" not in pretrained_dict: 60 | # print("---start load pretrained modle by splitting---") 61 | # pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()} 62 | # for k in list(pretrained_dict.keys()): 63 | # if "output" in k: 64 | # print("delete key:{}".format(k)) 65 | # del pretrained_dict[k] 66 | # msg = self.swin_unet.load_state_dict(pretrained_dict, strict=False) 67 | # # print(msg) 68 | # return 69 | # pretrained_dict = pretrained_dict['model'] 70 | # print("---start load pretrained modle of swin encoder---") 71 | # 72 | # model_dict = self.swin_unet.state_dict() 73 | # full_dict = copy.deepcopy(pretrained_dict) 74 | # for k, v in pretrained_dict.items(): 75 | # if "layers." in k: 76 | # current_layer_num = 3 - int(k[7:8]) 77 | # current_k = "layers_up." + str(current_layer_num) + k[8:] 78 | # full_dict.update({current_k: v}) 79 | # for k in list(full_dict.keys()): 80 | # if k in model_dict: 81 | # if full_dict[k].shape != model_dict[k].shape: 82 | # print("delete:{};shape pretrain:{};shape model:{}".format(k, v.shape, model_dict[k].shape)) 83 | # del full_dict[k] 84 | # 85 | # msg = self.swin_unet.load_state_dict(full_dict, strict=False) 86 | # # print(msg) 87 | # else: 88 | # print("none pretrain") 89 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import numpy as np 4 | 5 | import torch 6 | 7 | from lib.exceptions import EmptyTensorError 8 | 9 | 10 | def preprocess_image(image, preprocessing=None): 11 | image = image.astype(np.float32) 12 | image = np.transpose(image, [2, 0, 1]) 13 | if preprocessing is None: 14 | pass 15 | elif preprocessing == 'caffe': 16 | # RGB -> BGR 17 | image = image[:: -1, :, :] 18 | # Zero-center by mean pixel 19 | mean = np.array([103.939, 116.779, 123.68]) 20 | image = image - mean.reshape([3, 1, 1]) 21 | elif preprocessing == 'torch': 22 | image /= 255.0 23 | mean = np.array([0.559, 0.573, 0.555]) 24 | std = np.array([0.185, 0.172, 0.176]) 25 | image = (image - mean.reshape([3, 1, 1])) / std.reshape([3, 1, 1]) 26 | else: 27 | raise ValueError('Unknown preprocessing parameter.') 28 | return image 29 | 30 | 31 | def imshow_image(image, preprocessing=None): 32 | if preprocessing is None: 33 | pass 34 | elif preprocessing == 'caffe': 35 | mean = np.array([103.939, 116.779, 123.68]) 36 | image = image + mean.reshape([3, 1, 1]) 37 | # RGB -> BGR 38 | image = image[:: -1, :, :] 39 | elif preprocessing == 'torch': 40 | # mean = np.array([0.485, 0.456, 0.406]) 41 | # std = np.array([0.229, 0.224, 0.225]) 42 | mean = np.array([0.559, 0.573, 0.555]) 43 | std = np.array([0.185, 0.172, 0.176]) 44 | image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1]) 45 | image *= 255.0 46 | else: 47 | raise ValueError('Unknown preprocessing parameter.') 48 | image = np.transpose(image, [1, 2, 0]) 49 | image = np.round(image).astype(np.uint8) 50 | return image 51 | 52 | 53 | def grid_positions(h, w, device, matrix=False): 54 | lines = torch.arange( 55 | 0, h, device=device 56 | ).view(-1, 1).float().repeat(1, w) 57 | columns = torch.arange( 58 | 0, w, device=device 59 | ).view(1, -1).float().repeat(h, 1) 60 | if matrix: 61 | return torch.stack([lines, columns], dim=0) 62 | else: 63 | return torch.cat([lines.view(1, -1), columns.view(1, -1)], 
dim=0) 64 | 65 | 66 | def upscale_positions(pos, scaling_steps=0): 67 | for _ in range(scaling_steps): 68 | pos = pos * 2 + 0.5 69 | return pos 70 | 71 | 72 | def downscale_positions(pos, scaling_steps=0): 73 | for _ in range(scaling_steps): 74 | pos = (pos - 0.5) / 2 75 | return pos 76 | 77 | 78 | def interpolate_dense_features(pos, dense_features, return_corners=False): 79 | device = pos.device 80 | 81 | ids = torch.arange(0, pos.size(1), device=device) 82 | 83 | _, h, w = dense_features.size() 84 | 85 | i = pos[0, :] 86 | j = pos[1, :] 87 | 88 | # Valid corners 89 | i_top_left = torch.floor(i).long() 90 | j_top_left = torch.floor(j).long() 91 | valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0) 92 | 93 | i_top_right = torch.floor(i).long() 94 | j_top_right = torch.ceil(j).long() 95 | valid_top_right = torch.min(i_top_right >= 0, j_top_right < w) 96 | 97 | i_bottom_left = torch.ceil(i).long() 98 | j_bottom_left = torch.floor(j).long() 99 | valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0) 100 | 101 | i_bottom_right = torch.ceil(i).long() 102 | j_bottom_right = torch.ceil(j).long() 103 | valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w) 104 | 105 | valid_corners = torch.min( 106 | torch.min(valid_top_left, valid_top_right), 107 | torch.min(valid_bottom_left, valid_bottom_right) 108 | ) 109 | 110 | i_top_left = i_top_left[valid_corners] 111 | j_top_left = j_top_left[valid_corners] 112 | 113 | i_top_right = i_top_right[valid_corners] 114 | j_top_right = j_top_right[valid_corners] 115 | 116 | i_bottom_left = i_bottom_left[valid_corners] 117 | j_bottom_left = j_bottom_left[valid_corners] 118 | 119 | i_bottom_right = i_bottom_right[valid_corners] 120 | j_bottom_right = j_bottom_right[valid_corners] 121 | 122 | ids = ids[valid_corners] 123 | if ids.size(0) == 0: 124 | raise EmptyTensorError 125 | 126 | # Interpolation 127 | i = i[ids] 128 | j = j[ids] 129 | dist_i_top_left = i - i_top_left.float() 130 | dist_j_top_left = j - j_top_left.float() 131 | w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) 132 | w_top_right = (1 - dist_i_top_left) * dist_j_top_left 133 | w_bottom_left = dist_i_top_left * (1 - dist_j_top_left) 134 | w_bottom_right = dist_i_top_left * dist_j_top_left 135 | 136 | descriptors = ( 137 | w_top_left * dense_features[:, i_top_left, j_top_left] + 138 | w_top_right * dense_features[:, i_top_right, j_top_right] + 139 | w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] + 140 | w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right] 141 | ) 142 | 143 | pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0) 144 | 145 | if not return_corners: 146 | return [descriptors, pos, ids] 147 | else: 148 | corners = torch.stack([ 149 | torch.stack([i_top_left, j_top_left], dim=0), 150 | torch.stack([i_top_right, j_top_right], dim=0), 151 | torch.stack([i_bottom_left, j_bottom_left], dim=0), 152 | torch.stack([i_bottom_right, j_bottom_right], dim=0) 153 | ], dim=0) 154 | return [descriptors, pos, ids, corners] 155 | 156 | 157 | def savefig(filepath, fig=None, dpi=None): 158 | # TomNorway - https://stackoverflow.com/a/53516034 159 | if not fig: 160 | fig = plt.gcf() 161 | 162 | plt.subplots_adjust(0, 0, 1, 1, 0, 0) 163 | for ax in fig.axes: 164 | ax.axis('off') 165 | ax.margins(0, 0) 166 | ax.xaxis.set_major_locator(plt.NullLocator()) 167 | ax.yaxis.set_major_locator(plt.NullLocator()) 168 | 169 | fig.savefig(filepath, pad_inches=0, bbox_inches='tight', dpi=dpi) 170 | 
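
A minimal usage sketch (not part of the repository) for the helpers in lib/utils.py above. It assumes a random RGB image and a random dense feature map purely to illustrate the expected shapes of preprocess_image, imshow_image, grid_positions, interpolate_dense_features, and upscale_positions; the normalization constants are the dataset-specific ones hard-coded in the file, and all sizes below are illustrative.

import numpy as np
import torch

from lib.utils import (preprocess_image, imshow_image, grid_positions,
                       upscale_positions, interpolate_dense_features)

# Dummy (H, W, 3) uint8 image and dummy (C, h, w) dense feature map.
image = np.random.randint(0, 256, size=(224, 224, 3)).astype(np.uint8)
prep = preprocess_image(image, preprocessing='torch')   # (3, 224, 224), normalized CHW
restored = imshow_image(prep, preprocessing='torch')    # back to uint8 (224, 224, 3)

dense = torch.rand(512, 56, 56)                         # C x h x w descriptor map
# Sub-pixel (row, col) sample grid strictly inside the 56 x 56 map, shape (2, N).
pos = grid_positions(55, 55, dense.device) + 0.5
# Bilinearly interpolated descriptors at the valid positions.
descriptors, valid_pos, ids = interpolate_dense_features(pos, dense)
print(descriptors.shape, valid_pos.shape, ids.shape)    # (512, N'), (2, N'), (N',)

# upscale_positions maps feature-map coordinates back toward image coordinates,
# applying pos * 2 + 0.5 once per scaling step (here two halvings: 224 -> 112 -> 56).
img_pos = upscale_positions(valid_pos, scaling_steps=2)
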
-------------------------------------------------------------------------------- /match_localization.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | from os.path import exists 4 | import numpy as np 5 | import matchs 6 | from sklearn.neighbors import NearestNeighbors 7 | from tqdm.auto import tqdm 8 | import pandas as pd 9 | import time 10 | import argparse 11 | import cv2 12 | import configparser 13 | 14 | import warnings 15 | warnings.filterwarnings("ignore", category=Warning) 16 | 17 | def parse_gt_file(gtfile): 18 | print('Parsing ground truth data file...') 19 | gtdata = np.load(gtfile) 20 | return gtdata['utmQ'], gtdata['utmDb'], gtdata['posDistThr'] 21 | 22 | def get_positives(utmQ, utmDb, posDistThr): 23 | # positives for evaluation are those within trivial threshold range 24 | # fit NN to find them, search by radius 25 | knn = NearestNeighbors(n_jobs=-1) 26 | knn.fit(utmDb) 27 | distances, positives = knn.radius_neighbors(utmQ, radius=posDistThr) 28 | 29 | return positives 30 | 31 | def compute_recall(query_idx_list, gt, predictions, numQ, n_values, recall_str=''): 32 | start2 = time.perf_counter() 33 | correct_at_n = np.zeros(len(n_values)) 34 | 35 | for i, pred in enumerate(predictions): 36 | for j, n in enumerate(n_values): 37 | # if in top N then also in top NN, where NN > N 38 | if np.any(np.in1d(pred[:n], gt[query_idx_list[i]])): 39 | correct_at_n[i:] += 1 40 | break 41 | recall_at_n = correct_at_n / numQ 42 | all_recalls = {} # make dict for output 43 | for i, n in enumerate(n_values): 44 | all_recalls[n] = recall_at_n[i] 45 | tqdm.write("====> Recall {}@{}: {:.4f}".format(recall_str, n, recall_at_n[i])) 46 | 47 | recall_time = time.perf_counter() 48 | print('Compute recall time is %6.3f' % (recall_time - start2)) 49 | 50 | return all_recalls 51 | 52 | def write_recalls(opt, netvlad_recalls, cnnmatch_recalls, n_values, rec_file): 53 | with open(rec_file, 'w') as res_out: 54 | res_out.write(str(opt)+'\n') 55 | res_out.write("n_values: "+str(n_values)+'\n') 56 | for n in n_values: 57 | res_out.write("Recall {}@{}: {:.4f}\n".format('netvlad_match', n, netvlad_recalls[n])) 58 | res_out.write("Recall {}@{}: {:.4f}\n".format('cnn_match', n, cnnmatch_recalls[n])) 59 | 60 | def image_point2gps_point(keypoint_position: np.array, image_position: np.array, gsd, img_size): 61 | half_img_size = img_size/2 62 | point_position = np.array([image_position[0]+(keypoint_position[0]-half_img_size)*gsd, image_position[1]+(half_img_size - keypoint_position[1])*gsd]) 63 | alt = np.zeros(keypoint_position.shape[1]) 64 | point_position= np.vstack([*point_position,alt]) 65 | return point_position 66 | 67 | 68 | 69 | def main(): 70 | parser = argparse.ArgumentParser(description='Patch-NetVLAD-Feature-Match') 71 | parser.add_argument('--config_path', type=str, default='performance.ini', 72 | help='File name (with extension) to an ini file that stores most of the configuration data for cnn-matching') 73 | parser.add_argument('--ground_truth_path', type=str, default='/home/a409/users/huboni/Projects/code/Patch-NetVLAD/patchnetvlad/dataset_gt_files/mavic-xjtu_dist50.npz', 74 | help='ground truth file dist 50') 75 | parser.add_argument('--netvlad_result_path', type=str, default='/home/a409/users/huboni/Projects/code/Patch-NetVLAD/patchnetvlad/results/mavic-xjtu_2048_7_dist50/NetVLAD_predictions.txt', 76 | help='netvlad predictions result path') 77 | parser.add_argument('--reference_origin_path', type=str, 
default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic-xjtu/reference.csv', 78 | help='reference csv path') 79 | parser.add_argument('--query_origin_path', type=str, default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic-xjtu/query.csv', 80 | help='query csv path') 81 | parser.add_argument('--out_dir', type=str, default='/home/a409/users/huboni/Projects/code/cnn-matching/result/mavic-xjtu/', 82 | help='Dir to save recall') 83 | parser.add_argument('--out_1019_dir', type=str, default='/home/a409/users/huboni/Projects/code/cnn-matching/result/mavic-xjtu/', 84 | help='Dir to save recall') 85 | parser.add_argument('--out_file', type=str, default='cnn_mavic-xjtu_dist50_v50', 86 | help='Dir to save recall') 87 | parser.add_argument('--fig_res_file', type=str, default='pnp_fig_res_npu_dist20.csv') 88 | parser.add_argument('--global_res_file', type=str, default='pnp_fig_res_npu_dist20.csv') 89 | parser.add_argument('--rerank_res_file', type=str, default='pnp_fig_res_npu_dist20.csv') 90 | parser.add_argument('--rerank1_res_file', type=str, default='pnp_fig_res_npu_dist20.csv') 91 | parser.add_argument('--pinpoint_res_file', type=str, default='pnp_fig_res_npu_dist20.csv') 92 | 93 | 94 | parser.add_argument('--model_type', type=str, default='cnn', help='or cnn') 95 | parser.add_argument('--img_size', type=int, default=500, help='or 224') 96 | parser.add_argument('--checkpoint', type=str, default='/home/a409/users/huboni/Projects/code/cnn-matching/models/d2_tf.pth', 97 | help='or ./models/d2_tf.pth') 98 | 99 | opt = parser.parse_args() 100 | print(opt) 101 | 102 | n_values = [1, 5, 10, 20, 50] # FIXME 103 | if opt.img_size == 500: 104 | # gsd = 300 / 500 # alto 105 | # gsd = 300 / 500 # xjtu 106 | gsd = 0.167 # npu 107 | print("gsd:", gsd) 108 | # k = np.array([[345, 0, 247.328], [0, 345, 245], [0, 0, 1]]) # ALTO 109 | # k = np.array([[350, 0, 250], [0, 350, 250], [0, 0, 1]]) # xjtu 110 | k = np.array([[475.48, 0, 250], [0, 475.48, 250], [0, 0, 1]]) # npu 111 | # k = np.array([[3051.6 / (3000/500), 0, 1500/(3000/500)], [0, 3051.6/(3000/500), 1500/(3000/500)], [0, 0, 1]]) # TerraTrack 112 | elif opt.img_size == 224: 113 | gsd = 300 / 500 * (500/224) 114 | print("gsd:", gsd) 115 | k = np.array([[345 / (500/224), 0, 247.328 / (500/224)], [0, 345 / (500/224), 245 / (500/224)], [0, 0, 1]]) 116 | num_rec = (n_values[-1]) 117 | with open(opt.netvlad_result_path, 'r') as file: 118 | lines = file.readlines()[2:] # 前两行是注释数据 119 | numQ = len(lines)//num_rec 120 | print("numQ:", numQ) 121 | 122 | query_idx_list_all = [] 123 | refidx_list = [] 124 | refidx_inliners_list = [] 125 | 126 | start0 = time.perf_counter() 127 | # rerank by cnn match + flann + ransac 128 | qData = pd.read_csv(opt.query_origin_path) 129 | # qData = qData.sort_values(by='name', ascending=True) 130 | print("qData", qData) 131 | dbData = pd.read_csv(opt.reference_origin_path) 132 | # dbData = dbData.sort_values(by='name', ascending=True) 133 | # print("daData", dbData) 134 | for i in tqdm(range(len(lines)), desc='cnn_match_ing'): 135 | 136 | query_file = lines[i].split(',')[0].strip() # query_file 按照idx的顺序逐个存取的,所以只需要将ref的idx关联起来 137 | query_name = os.path.basename(query_file) 138 | reference_file = lines[i].split(',')[1].strip() # reference_name 139 | # reference_name = os.path.join(reference_file.split('/')[-2], reference_file.split('/')[-1]) # alto 140 | reference_name = os.path.basename(reference_file) # terratrack FIXME 141 | # print("reference_name", reference_name) 142 | query_idx = qData[qData.name == 
query_name].index.to_list()[0] 143 | query_idx_list_all.append(query_idx) 144 | ref_idx = dbData[dbData.name == reference_name].index.to_list()[0] 145 | inliners, avg_dist, q_kps, r_kps = matchs.cnn_match(query_file, reference_file, opt.model_type, opt.checkpoint, opt.img_size) 146 | # print("-------current avg_list------:", avg_dist) 147 | # print("--------num inliners---------:", inliners) 148 | refidx_list.append(ref_idx) 149 | refidx_inliners_list.append([ref_idx, avg_dist, q_kps, r_kps]) 150 | query_idx_list = query_idx_list_all[0: len(query_idx_list_all) : num_rec] 151 | print("-----------q idx list:", query_idx_list) 152 | match_time = time.perf_counter() 153 | print('CNN match time is %6.3f' % (match_time - start0)) 154 | 155 | # gen origin input netvlad predictions 156 | refidx_list_split = [refidx_list[i:i + num_rec] for i in range(0, len(refidx_list), num_rec)] 157 | netvlad_predictions = np.array(refidx_list_split) 158 | 159 | # split every query rematch[ref_idx inliners] 160 | refidx_inliners_list_split =[refidx_inliners_list[i:i+num_rec] for i in range(0, len(refidx_inliners_list), num_rec)] 161 | # reranking every query rematch[ref_idx inliners] by inliners avg_dist 162 | start1 = time.perf_counter() 163 | cnnmatch_predictions = [] 164 | pinpoint_utms = [] 165 | name_l = [] 166 | arr_ppl = [] 167 | arr_r1l = [] 168 | global_candidates_utm = [] 169 | rerank_candidates_utm = [] 170 | rerank_1_utm = [] 171 | pinpoint_utm_list = [] 172 | pass_count = 0 173 | fig_csv_file = os.path.join(opt.out_dir, opt.fig_res_file) 174 | global_csv_file = os.path.join(opt.out_1019_dir, opt.global_res_file) 175 | rerank_csv_file = os.path.join(opt.out_1019_dir, opt.rerank_res_file) 176 | rerank1_csv_file = os.path.join(opt.out_1019_dir, opt.rerank1_res_file) 177 | pinpoint_csv_file = os.path.join(opt.out_1019_dir, opt.pinpoint_res_file) 178 | 179 | 180 | 181 | pnp_select_res_dict = pd.DataFrame() 182 | for i, idx_inliners_list in enumerate(refidx_inliners_list_split): 183 | print("-----------------",i) 184 | reranked_rf_idx = np.array(sorted(idx_inliners_list, key=lambda x:x[1], reverse=False))[:, 0][:5] # 按照第二位 降序排列 并输出第一位 FIXME Terratrack 185 | cnnmatch_predictions.append(reranked_rf_idx) 186 | reranked_query_kps = np.array(sorted(idx_inliners_list, key=lambda x:x[1], reverse=False))[:, 2][:5] # FIXME 187 | reranked_ref_kps = np.array(sorted(idx_inliners_list, key=lambda x:x[1], reverse=False))[:, 3][:5] # FIXME 188 | query_utm = np.array(qData.loc[query_idx_list[i], ["easting", "northing"]]) 189 | query_utm = query_utm[::-1] # terratrack swap easting northing 190 | print("query name:", qData.loc[query_idx_list[i], ["name"]]) 191 | 192 | # print("query name:", qData.loc[query_id, ["name"]]) 193 | ref_image_position_o = dbData.loc[reranked_rf_idx.tolist(), ["easting", "northing"]] 194 | # ref_image_position = np.array([ref_image_position_o["easting"], ref_image_position_o["northing"]]).transpose() # alto 195 | ref_image_position = np.array([ref_image_position_o["northing"], ref_image_position_o["easting"]]).transpose() # terratrack 196 | 197 | ref_kpts_position = [] 198 | for j, image_position in enumerate(ref_image_position): 199 | ref_kpts_position.append(image_point2gps_point(np.array(reranked_ref_kps[j]).transpose(), image_position, gsd, opt.img_size)) 200 | ref_kpts_position = np.hstack(ref_kpts_position).transpose() # 关键点GPS坐标 201 | query_kps = [np.array(ii) for ii in reranked_query_kps] 202 | query_kps_position = np.vstack(query_kps) 203 | 204 | success, R_vec, t, inliers = 
cv2.solvePnPRansac(ref_kpts_position, query_kps_position, k, np.zeros(4), 205 | flags=cv2.SOLVEPNP_ITERATIVE, iterationsCount=5000, 206 | reprojectionError=10) 207 | # print("R_vec:", R_vec) 208 | print("success:", success) 209 | if not success: 210 | continue 211 | r_w2c, _ = cv2.Rodrigues(R_vec) 212 | t_w2c = t 213 | r_c2w = np.linalg.inv(r_w2c) 214 | t_c2w = -r_c2w @ t_w2c 215 | # print("---------t_c2w-------:", t_c2w) 216 | pinpoint_utm = t_c2w[:, 0][:2] 217 | name_l.append(qData.loc[query_idx_list[i], ["name"]]) 218 | pinpoint_utms.append(pinpoint_utm) 219 | # print("pinpoint utm:", pinpoint_utm) 220 | pp_loss = np.linalg.norm(query_utm - pinpoint_utm) 221 | print("pp_loss", pp_loss) 222 | # if pp_loss>10: 223 | # pass_count += 1 224 | # continue 225 | arr_ppl.append(pp_loss) 226 | r1_utm = dbData.loc[reranked_rf_idx.tolist()[0], ["easting", "northing"]] 227 | print("r1 name:", dbData.loc[reranked_rf_idx.tolist()[0], ["name"]]) 228 | r2_utm = dbData.loc[reranked_rf_idx.tolist()[1], ["easting", "northing"]] 229 | r3_utm = dbData.loc[reranked_rf_idx.tolist()[2], ["easting", "northing"]] 230 | r4_utm = dbData.loc[reranked_rf_idx.tolist()[3], ["easting", "northing"]] 231 | r5_utm = dbData.loc[reranked_rf_idx.tolist()[4], ["easting", "northing"]] 232 | # swap easting northing below 233 | r1_loss = np.linalg.norm(query_utm - np.array(r1_utm)[::-1]) 234 | r2_loss = np.linalg.norm(query_utm - np.array(r2_utm)[::-1]) 235 | r3_loss = np.linalg.norm(query_utm - np.array(r3_utm)[::-1]) 236 | r4_loss = np.linalg.norm(query_utm - np.array(r4_utm)[::-1]) 237 | r5_loss = np.linalg.norm(query_utm - np.array(r5_utm)[::-1]) 238 | 239 | r_mean_loss = sum([r1_loss,r2_loss,r3_loss,r4_loss,r5_loss])/5 240 | arr_r1l.append(r1_loss) 241 | print("localization loss:", pp_loss) 242 | print("recall@1 loss:", r1_loss) 243 | # save csv for localization fig 244 | # if r_mean_loss-pp_loss>=50.0 and pp_loss<2.0: 245 | # if r1_loss-pp_loss>10.0: 246 | if True: 247 | global_candidates_utm.extend((dbData.loc[np.array(idx_inliners_list)[:,0][:5], ["easting", "northing"]]).to_numpy().tolist()) 248 | rerank_candidates_utm.extend((dbData.loc[reranked_rf_idx.tolist()[:5], ["easting", "northing"]]).to_numpy().tolist()) 249 | rerank_1_utm.append((dbData.loc[reranked_rf_idx.tolist()[0], ["easting", "northing"]]).to_numpy().tolist()) 250 | pinpoint_utm_list.append(pinpoint_utm.tolist()) 251 | 252 | fig_res = dict(query_name = (qData.loc[query_idx_list[i], ["name"]]).to_numpy(), 253 | global_recall5_idx = (np.array(idx_inliners_list)[:, 0][:5]), global_recall5_name = (dbData.loc[np.array(idx_inliners_list)[:, 0][:5],["name"]]).to_numpy(), 254 | global_recall5_utm = (dbData.loc[np.array(idx_inliners_list)[:,0][:5], ["easting", "northing"]]).to_numpy(), 255 | rerank_recall5_idx = (reranked_rf_idx.tolist()[:5]), rerank_recall5_name = (dbData.loc[reranked_rf_idx.tolist()[:5], ["name"]]).to_numpy(), 256 | rerank_recall5_utm = (dbData.loc[reranked_rf_idx.tolist()[:5], ["easting", "northing"]]).to_numpy(), 257 | rerank_recall1_idx = (reranked_rf_idx.tolist()[0]), rerank_recall1_name = (dbData.loc[reranked_rf_idx.tolist()[:1], ["name"]]).to_numpy(), 258 | rerank_recall1_utm = (dbData.loc[reranked_rf_idx.tolist()[0], ["easting", "northing"]]).to_numpy(), 259 | pinpoint_utm = pinpoint_utm, 260 | query_utm = query_utm, r1_loss = r1_loss, pp_loss=pp_loss) 261 | # print(fig_res) 262 | pnp_select_res_dict = pnp_select_res_dict._append(fig_res, ignore_index=True) 263 | 264 | pd.DataFrame(pnp_select_res_dict).to_csv(fig_csv_file, index=True, 
encoding='gbk',float_format='%.6f') 265 | pd.DataFrame(global_candidates_utm).to_csv(global_csv_file, index=True, encoding='gbk',float_format='%.6f') 266 | pd.DataFrame(rerank_candidates_utm).to_csv(rerank_csv_file, index=True, encoding='gbk',float_format='%.6f') 267 | pd.DataFrame(rerank_1_utm).to_csv(rerank1_csv_file, index=True, encoding='gbk',float_format='%.6f') 268 | pd.DataFrame(pinpoint_utm_list).to_csv(pinpoint_csv_file, index=True, encoding='gbk',float_format='%.6f') 269 | 270 | 271 | 272 | # print("global candidates list:", global_candidates_utm) 273 | # print("rerank candidates list:", rerank_candidates_utm) 274 | # print("rerank 1 utm list:", rerank_1_utm) 275 | # print("pinpoint utm list", pinpoint_utm_list) 276 | 277 | print("pass count:", pass_count) 278 | print("avg pinpoint loss:", np.mean(np.array(arr_ppl))) 279 | print("avg recall@1 loss:", np.mean(np.array(arr_r1l))) 280 | print("var pinpoint loss:", np.var(np.array(arr_ppl))) 281 | print("var recall@1 loss:", np.var(np.array(arr_r1l))) 282 | print("std pinpoint loss:", np.std(np.array(arr_ppl))) 283 | print("std recall@1 loss:", np.std(np.array(arr_r1l))) 284 | # 保存pinpoint.csv 285 | pinpoint_dict = {"easting":np.array(pinpoint_utms)[:,0], 286 | "northing":np.array(pinpoint_utms)[:,1], 287 | "name":name_l} 288 | 289 | out_csv_file = os.path.join(opt.out_file + '.csv') 290 | print("pinpoint utm save to %s" % out_csv_file) 291 | pd.DataFrame(pinpoint_dict).to_csv(os.path.join(opt.out_dir, 'pinpoint_utm', out_csv_file), index=False) 292 | 293 | 294 | cnnmatch_predictions = np.array(cnnmatch_predictions) 295 | cnn_pred_time = time.perf_counter() 296 | print('rerank match get cnn predict time is %6.3f' % (cnn_pred_time - start1)) 297 | 298 | # get ground truth 299 | utmQ,utmDb,posDistThr = parse_gt_file(opt.ground_truth_path) 300 | gt = get_positives(utmQ, utmDb, posDistThr) 301 | 302 | netvlad_recalls = compute_recall(query_idx_list, gt, netvlad_predictions, numQ, n_values, 'netvlad_match') 303 | cnnmatch_recalls = compute_recall(query_idx_list, gt, cnnmatch_predictions, numQ, n_values, 'cnn_match') 304 | 305 | # out_recall_file = os.path.join(opt.out_dir, 'recalls', opt.checkpoint.split('/')[-3] + '_' + str(opt.img_size) + '.txt') 306 | out_recall_file = os.path.join(opt.out_dir, 'recalls', opt.out_file + '.txt') 307 | 308 | print('Writing recalls to', out_recall_file) 309 | write_recalls(opt, netvlad_recalls, cnnmatch_recalls, n_values, out_recall_file) 310 | 311 | 312 | 313 | 314 | if __name__ == "__main__": 315 | main() 316 | 317 | -------------------------------------------------------------------------------- /matching.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import imageio 5 | import plotmatch 6 | from lib.cnn_feature import cnn_feature_extract 7 | import matplotlib.pyplot as plt 8 | import time 9 | from skimage import measure 10 | from skimage import transform 11 | import math 12 | import os 13 | from os.path import exists 14 | from os.path import join 15 | 16 | 17 | #time count 18 | start = time.perf_counter() 19 | 20 | _RESIDUAL_THRESHOLD = 20 21 | 22 | imgfile1 = '/home/a409/users/huboni/Paper/locally_global_match_pnp/fig/fig-matching/q000000.png' 23 | imgfile2 = '/home/a409/users/huboni/Paper/locally_global_match_pnp/fig/fig-matching/r000000.png' 24 | 25 | start = time.perf_counter() 26 | 27 | # read left image 28 | image1_o = cv2.imread(imgfile1) 29 | image1 = cv2.resize(image1_o, (224,224)) 30 | image2_o 
= cv2.imread(imgfile2) 31 | image2 = cv2.resize(image2_o, (224,224)) 32 | print('read image time is %6.3f' % (time.perf_counter() - start)) 33 | 34 | start0 = time.perf_counter() 35 | 36 | kps_left, sco_left, des_left = cnn_feature_extract(image1_o, nfeatures = -1) 37 | kps_right, sco_right, des_right = cnn_feature_extract(image2_o, nfeatures = -1) 38 | 39 | print('Feature_extract time is %6.3f, left: %6.3f,right %6.3f' % ((time.perf_counter() - start), len(kps_left), len(kps_right))) 40 | start = time.perf_counter() 41 | 42 | #Flann特征匹配 43 | FLANN_INDEX_KDTREE = 1 44 | index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5) 45 | search_params = dict(checks=40) 46 | flann = cv2.FlannBasedMatcher(index_params, search_params) 47 | matches = flann.knnMatch(des_left, des_right, k=2) 48 | matches_reverse = flann.knnMatch(des_right, des_left, k=2) 49 | 50 | 51 | goodMatch = [] 52 | locations_1_to_use = [] 53 | locations_2_to_use = [] 54 | 55 | disdif_avg = 0 56 | for m, n in matches: 57 | disdif_avg += n.distance - m.distance 58 | disdif_avg = disdif_avg / len(matches) 59 | 60 | for m, n in matches: 61 | if n.distance > m.distance + disdif_avg and matches_reverse[m.trainIdx][1].distance > matches_reverse[m.trainIdx][0].distance+disdif_avg and matches_reverse[m.trainIdx][0].trainIdx == m.queryIdx: 62 | goodMatch.append(m) 63 | p2 = cv2.KeyPoint(kps_right[m.trainIdx][0], kps_right[m.trainIdx][1], 1) 64 | p1 = cv2.KeyPoint(kps_left[m.queryIdx][0], kps_left[m.queryIdx][1], 1) 65 | locations_1_to_use.append([p1.pt[0], p1.pt[1]]) 66 | locations_2_to_use.append([p2.pt[0], p2.pt[1]]) 67 | print('match num is %d' % len(goodMatch)) 68 | locations_1_to_use = np.array(locations_1_to_use) 69 | locations_2_to_use = np.array(locations_2_to_use) 70 | 71 | # Perform geometric verification using RANSAC. 72 | _, inliers = measure.ransac((locations_1_to_use, locations_2_to_use), 73 | transform.AffineTransform, 74 | min_samples=3, 75 | residual_threshold=_RESIDUAL_THRESHOLD, 76 | max_trials=1000) 77 | 78 | print('Found %d inliers' % sum(inliers)) 79 | 80 | inlier_idxs = np.nonzero(inliers)[0] 81 | #最终匹配结果 82 | matches = np.column_stack((inlier_idxs, inlier_idxs)) 83 | print('whole time is %6.3f' % (time.perf_counter() - start0)) 84 | 85 | # save inliners 像素坐标 86 | coordinate_out_dir = '/home/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/VPR_results/' 87 | coordinate_out_ref = [] 88 | coordinate_out_query = [] 89 | 90 | dist = 0 91 | for i in range(matches.shape[0]): 92 | idx1 = matches[i, 0] 93 | idx2 = matches[i, 1] 94 | # 计算两幅匹配图像最终匹配点对的像素距离 95 | dist += math.sqrt(((locations_1_to_use[idx1,0]-locations_2_to_use[idx2,0])**2)+((locations_1_to_use[idx1,1]-locations_2_to_use[idx2,1])**2)) 96 | avg_dist = dist/matches.shape[0] 97 | print("avg_dist:", avg_dist) 98 | # coordinate_out_query.append([locations_1_to_use[idx1, 0], locations_1_to_use[idx1, 1]]) 99 | # coordinate_out_ref.append([locations_2_to_use[idx2, 0], locations_2_to_use[idx2, 1]]) 100 | 101 | # 输出最终两幅匹配图像匹配点对坐标 102 | # if not exists(coordinate_out_dir): 103 | # os.mkdir(coordinate_out_dir) 104 | # out_file = join(coordinate_out_dir, 'coordinate.txt') 105 | # print('Writing recalls to', out_file) 106 | # with open(out_file, 'a+') as res_out: 107 | # res_out.write(imgfile1 + '\n' + str(coordinate_out_query) + '\n') 108 | # res_out.write(imgfile2 + '\n' + str(coordinate_out_ref) + '\n') 109 | 110 | # Visualize correspondences, and save to file. 
111 | #1 绘制匹配连线 112 | plt.rcParams['savefig.dpi'] = 500 #图片像素 113 | plt.rcParams['figure.dpi'] = 500 #分辨率 114 | plt.rcParams['figure.figsize'] = (3.0, 2.0) # 设置figure_size尺寸 115 | _, ax = plt.subplots() 116 | plotmatch.plot_matches( 117 | ax, 118 | image1_o, 119 | image2_o, 120 | locations_1_to_use, 121 | locations_2_to_use, 122 | np.column_stack((inlier_idxs, inlier_idxs)), 123 | plot_matche_points = False, 124 | matchline = True, 125 | matchlinewidth = 0.1) 126 | ax.axis('off') 127 | ax.set_title('') 128 | # plt.show() 129 | plt.savefig('/home/a409/users/huboni/Paper/locally_global_match_pnp/fig/fig-matching/0.png') -------------------------------------------------------------------------------- /performance.ini: -------------------------------------------------------------------------------- 1 | [path] 2 | input_predictions = '/home/a409/users/huboni/Projects/code/Patch-NetVLAD/patchnetvlad/results/mavic_npu_2048_7_dist20/NetVLAD_predictions.txt' 3 | ground_truth = '/home/a409/users/huboni/Projects/code/Patch-NetVLAD/patchnetvlad/dataset_gt_files/mavic_npu_dist20.npz' 4 | reference_csv = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/reference.csv' 5 | out_dir = '/home/a409/users/huboni/Projects/code/cnn-matching/result/TerraTrack_rs/' 6 | 7 | [feature_match] 8 | n_values_all = 1,5,10,20,50 9 | -------------------------------------------------------------------------------- /plotmatch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def plot_matches(ax, image1, image2, keypoints1, keypoints2, matches, 5 | keypoints_color='r', matches_color=None, plot_matche_points=True, matchline = True, matchlinewidth = 0.5, 6 | alignment='horizontal'): 7 | """Plot matched features. 8 | 9 | Parameters 10 | ---------- 11 | ax : matplotlib.axes.Axes 12 | Matches and image are drawn in this ax. 13 | image1 : (N, M [, 3]) array 14 | First grayscale or color image. 15 | image2 : (N, M [, 3]) array 16 | Second grayscale or color image. 17 | keypoints1 : (K1, 2) array 18 | First keypoint coordinates as ``(row, col)``. 19 | keypoints2 : (K2, 2) array 20 | Second keypoint coordinates as ``(row, col)``. 21 | matches : (Q, 2) array 22 | Indices of corresponding matches in first and second set of 23 | descriptors, where ``matches[:, 0]`` denote the indices in the first 24 | and ``matches[:, 1]`` the indices in the second set of descriptors. 25 | keypoints_color : matplotlib color, optional 26 | Color for keypoint locations. 27 | matches_color : matplotlib color, optional 28 | Color for lines which connect keypoint matches. By default the 29 | color is chosen randomly. 30 | only_matches : bool, optional 31 | Whether to only plot matches and not plot the keypoint locations. 32 | alignment : {'horizontal', 'vertical'}, optional 33 | Whether to show images side by side, ``'horizontal'``, or one above 34 | the other, ``'vertical'``. 
35 | 36 | """ 37 | 38 | #image1 = img_as_float(image1) 39 | #image2 = img_as_float(image2) 40 | 41 | new_shape1 = list(image1.shape) 42 | new_shape2 = list(image2.shape) 43 | 44 | if image1.shape[0] < image2.shape[0]: 45 | new_shape1[0] = image2.shape[0] 46 | elif image1.shape[0] > image2.shape[0]: 47 | new_shape2[0] = image1.shape[0] 48 | 49 | if image1.shape[1] < image2.shape[1]: 50 | new_shape1[1] = image2.shape[1] 51 | elif image1.shape[1] > image2.shape[1]: 52 | new_shape2[1] = image1.shape[1] 53 | 54 | if new_shape1 != image1.shape: 55 | #new_image1 = np.zeros(new_shape1, dtype=image1.dtype) 56 | new_image1 = np.full(new_shape1, 255) 57 | new_image1[:image1.shape[0], :image1.shape[1]] = image1 58 | image1 = new_image1 59 | 60 | if new_shape2 != image2.shape: 61 | #new_image2 = np.zeros(new_shape2, dtype=image2.dtype) 62 | new_image2 = np.full(new_shape2, 255) 63 | new_image2[:image2.shape[0], :image2.shape[1]] = image2 64 | image2 = new_image2 65 | 66 | offset = np.array(image1.shape) 67 | if alignment == 'horizontal': 68 | if image2.ndim == 3: 69 | blank = np.full((new_shape2[0], 10, 3), 255) 70 | if image1.ndim == 2 or image2.ndim == 2: 71 | blank = np.full((new_shape2[0], 10), 255) 72 | image = np.concatenate([image1, blank, image2], axis=1) 73 | offset[0] = 0 74 | offset[1] += 10 75 | elif alignment == 'vertical': 76 | if image2.ndim == 3 : 77 | blank = np.full((10, new_shape2[1], 3), 255) 78 | if image1.ndim == 2: 79 | blank = np.full((10, new_shape2[1]), 255) 80 | image = np.concatenate([image1, blank, image2], axis=0) 81 | offset[1] = 0 82 | offset[0] += 10 83 | else: 84 | mesg = ("plot_matches accepts either 'horizontal' or 'vertical' for " 85 | "alignment, but '{}' was given. See " 86 | "https://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.plot_matches " # noqa 87 | "for details.").format(alignment) 88 | raise ValueError(mesg) 89 | 90 | 91 | if plot_matche_points: 92 | ax.scatter(keypoints1[:, 0], keypoints1[:, 1], 93 | facecolors='none', edgecolors=keypoints_color, marker = '.') 94 | ax.scatter(keypoints2[:, 0] + offset[1], keypoints2[:, 1] + offset[0], 95 | facecolors='none', edgecolors=keypoints_color, marker = '.') 96 | 97 | ax.imshow(image, interpolation='nearest', cmap='gray') 98 | ax.axis((0, image1.shape[1] + offset[1], image1.shape[0] + offset[0], 0)) 99 | 100 | if matchline == True: 101 | for i in range(matches.shape[0]): 102 | idx1 = matches[i, 0] 103 | idx2 = matches[i, 1] 104 | 105 | if matches_color is None: 106 | color = np.random.rand(3) 107 | else: 108 | color = matches_color 109 | 110 | ax.plot((keypoints1[idx1, 0], keypoints2[idx2, 0] + offset[1]), 111 | (keypoints1[idx1, 1], keypoints2[idx2, 1] + offset[0]), 112 | '-', color=color, linewidth=matchlinewidth, marker='+', markersize=4) 113 | 114 | 115 | def plot_matches2(ax, image1, image2, keypoints1, keypoints2, 116 | keypoints_color='r', matches_color=None, plot_matche_points=True, matchline = True, matchlinewidth = 0.5, 117 | alignment='horizontal'): 118 | 119 | 120 | new_shape1 = list(image1.shape) 121 | new_shape2 = list(image2.shape) 122 | 123 | if image1.shape[0] < image2.shape[0]: 124 | new_shape1[0] = image2.shape[0] 125 | elif image1.shape[0] > image2.shape[0]: 126 | new_shape2[0] = image1.shape[0] 127 | 128 | if image1.shape[1] < image2.shape[1]: 129 | new_shape1[1] = image2.shape[1] 130 | elif image1.shape[1] > image2.shape[1]: 131 | new_shape2[1] = image1.shape[1] 132 | 133 | if new_shape1 != image1.shape: 134 | #new_image1 = np.zeros(new_shape1, dtype=image1.dtype) 135 | 
new_image1 = np.full(new_shape1, 255) 136 | new_image1[:image1.shape[0], :image1.shape[1]] = image1 137 | image1 = new_image1 138 | 139 | if new_shape2 != image2.shape: 140 | #new_image2 = np.zeros(new_shape2, dtype=image2.dtype) 141 | new_image2 = np.full(new_shape2, 255) 142 | new_image2[:image2.shape[0], :image2.shape[1]] = image2 143 | image2 = new_image2 144 | 145 | offset = np.array(image1.shape) 146 | if alignment == 'horizontal': 147 | if image2.ndim == 3: 148 | blank = np.full((new_shape2[0], 10, 3), 255) 149 | if image1.ndim == 2 or image2.ndim == 2: 150 | blank = np.full((new_shape2[0], 10), 255) 151 | image = np.concatenate([image1, blank, image2], axis=1) 152 | offset[0] = 0 153 | offset[1] += 10 154 | elif alignment == 'vertical': 155 | if image2.ndim == 3 : 156 | blank = np.full((10, new_shape2[1], 3), 255) 157 | if image1.ndim == 2: 158 | blank = np.full((10, new_shape2[1]), 255) 159 | image = np.concatenate([image1, blank, image2], axis=0) 160 | offset[1] = 0 161 | offset[0] += 10 162 | else: 163 | mesg = ("plot_matches accepts either 'horizontal' or 'vertical' for " 164 | "alignment, but '{}' was given. See " 165 | "https://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.plot_matches " # noqa 166 | "for details.").format(alignment) 167 | raise ValueError(mesg) 168 | 169 | 170 | if plot_matche_points: 171 | ax.scatter(keypoints1[:, 0], keypoints1[:, 1], 172 | facecolors='none', edgecolors=keypoints_color, marker = '.') 173 | ax.scatter(keypoints2[:, 0] + offset[1], keypoints2[:, 1] + offset[0], 174 | facecolors='none', edgecolors=keypoints_color, marker = '.') 175 | 176 | ax.imshow(image, interpolation='nearest', cmap='gray') 177 | ax.axis((0, image1.shape[1] + offset[1], image1.shape[0] + offset[0], 0)) 178 | 179 | if matchline == True: 180 | for i in range(keypoints1.shape[0]): 181 | 182 | if matches_color is None: 183 | color = np.random.rand(3) 184 | else: 185 | color = matches_color 186 | 187 | ax.plot((keypoints1[i, 0], keypoints2[i, 0] + offset[1]), 188 | (keypoints1[i, 1], keypoints2[i, 1] + offset[0]), 189 | '-', color=color, linewidth=matchlinewidth, marker='+', markersize=2) 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /result/cv-match.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from matplotlib import pyplot as plt 4 | 5 | imgname1 = '/mnt/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/query_images/000000.png' 6 | imgname2 = '/mnt/a409/users/chenlin/VPR_huboni/Val/GPR_Dataset/reference_images/offset_0_None/000000.png' 7 | 8 | sift = cv2.SIFT_create() 9 | 10 | # FLANN parameter setup 11 | FLANN_INDEX_KDTREE = 0 12 | index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5) 13 | search_params = dict(checks=50) 14 | flann = cv2.FlannBasedMatcher(index_params,search_params) 15 | 16 | img1 = cv2.imread(imgname1) 17 | gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) # convert to grayscale 18 | kp1, des1 = sift.detectAndCompute(img1,None) # des1 holds the descriptors 19 | 20 | img2 = cv2.imread(imgname2) 21 | gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY) 22 | kp2, des2 = sift.detectAndCompute(img2,None) 23 | 24 | hmerge = np.hstack((gray1, gray2)) # horizontal concatenation 25 | cv2.imshow("./gray.png", hmerge) # show the grayscale pair 26 | cv2.waitKey(0) 27 | 28 | img3 = cv2.drawKeypoints(img1,kp1,img1,color=(255,0,255)) 29 | img4 = cv2.drawKeypoints(img2,kp2,img2,color=(255,0,255)) 30 | 31 | hmerge = np.hstack((img3, img4)) # horizontal concatenation 32 | cv2.imshow("./point.png", hmerge) # show the keypoint visualization 33 | 
cv2.waitKey(0) 34 | matches = flann.knnMatch(des1,des2,k=2) 35 | matchesMask = [[0,0] for i in range(len(matches))] 36 | 37 | good = [] 38 | for m,n in matches: 39 | if m.distance < 0.1*n.distance: 40 | good.append([m]) 41 | 42 | img5 = cv2.drawMatchesKnn(img1,kp1,img2,kp2,good,None,flags=2) # draw only the ratio-test matches 43 | cv2.imshow("./FLANN.png", img5) 44 | cv2.waitKey(0) 45 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /result/plot_rs.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | 6 | 7 | x_axis_data = [1,5,10,20,50,100,200] 8 | y_dist20_patch = [0.2043,0.3290,0.3545,0.3622,0.3697,0.3355,0.3314] 9 | y_dist20_swin = [0.2043,0.5629,0.5806,0.5987,0.6259,0.6021,0.5903] 10 | y_dist50_patch = [0.6200,0.8023,0.8266,0.8468,0.8504,0.8456,0.8325] 11 | y_dist50_swin = [0.6200,0.8634,0.9030,0.9231,0.9448,0.9293,0.9086] 12 | 13 | 14 | # Plot the recall curves 15 | 16 | plt.plot(x_axis_data, y_dist20_patch, 'go--', alpha=0.5, linewidth=1, label='Patch-NetVLAD-2048 dist20')#' 17 | plt.plot(x_axis_data, y_dist20_swin, 'ms-', alpha=0.5, linewidth=1, label='CurriculumLoc(ours) dist20') 18 | plt.plot(x_axis_data, y_dist50_patch, 'bo--', alpha=0.5, linewidth=1, label='Patch-NetVLAD-2048 dist50') 19 | plt.plot(x_axis_data, y_dist50_swin, 'rs-', alpha=0.5, linewidth=1, label='CurriculumLoc(ours) dist50') 20 | plt.xticks([1,5,10,20,50,100,200]) 21 | plt.yticks([0,0.2,0.4,0.6,0.8,1.0]) 22 | 23 | for a, b in zip(x_axis_data, y_dist20_patch): 24 | plt.text(a, b, str(b), ha='center', va='bottom', fontsize=8) # ha='center', va='top' 25 | for a, b1 in zip(x_axis_data, y_dist20_swin): 26 | plt.text(a, b1, str(b1), ha='center', va='bottom', fontsize=8) 27 | for a, b2 in zip(x_axis_data, y_dist50_patch): 28 | plt.text(a, b2, str(b2), ha='center', va='bottom', fontsize=8) 29 | for a, b3 in zip(x_axis_data, y_dist50_swin): 30 | plt.text(a, b3, str(b3), ha='center', va='bottom', fontsize=8) 31 | 32 | 33 | plt.legend() # show the labels defined above 34 | plt.xlabel('Candidates number') 35 | plt.ylabel('rerank R@1')#accuracy 36 | 37 | 38 | plt.savefig("./plot_rs.png") 39 | plt.show() 40 | 41 | -------------------------------------------------------------------------------- /terratrack_utils/compute_mean_std.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | from matplotlib.pyplot import imread 5 | import numpy as np 6 | # from scipy.misc import imread 7 | filepath = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/query_images_512' # dataset directory 8 | pathDir = os.listdir(filepath) 9 | R_channel = 0 10 | G_channel = 0 11 | B_channel = 0 12 | for idx in range(len(pathDir)): 13 | filename = pathDir[idx] 14 | img = imread(os.path.join(filepath, filename)) / 255.0 15 | R_channel = R_channel + np.sum(img[:, :, 0]) 16 | G_channel = G_channel + np.sum(img[:, :, 1]) 17 | B_channel = B_channel + np.sum(img[:, :, 2]) 18 | num = len(pathDir) * 512 * 512 # (512, 512) is the size of each image; all images are assumed to share this size 19 | R_mean = R_channel / num 20 | G_mean = G_channel / num 21 | B_mean = B_channel / num 22 | R_channel = 0 23 | G_channel = 0 24 | B_channel = 0 25 | for idx in range(len(pathDir)): 26 | filename = pathDir[idx] 27 | img = imread(os.path.join(filepath, filename)) / 255.0 28 | R_channel = R_channel + np.sum((img[:, :, 0] - R_mean) ** 2) 29 | G_channel = G_channel + np.sum((img[:, :, 1] - G_mean) ** 2) 30 | B_channel = B_channel + 
np.sum((img[:, :, 2] - B_mean) ** 2) 31 | R_var = np.sqrt(R_channel / num) 32 | G_var = np.sqrt(G_channel / num) 33 | B_var = np.sqrt(B_channel / num) 34 | print("R_mean is %f, G_mean is %f, B_mean is %f" % (R_mean, G_mean, B_mean)) 35 | print("R_var is %f, G_var is %f, B_var is %f" % (R_var, G_var, B_var)) 36 | -------------------------------------------------------------------------------- /terratrack_utils/depth_bin2array.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import warnings 4 | import os 5 | import h5py 6 | import argparse 7 | 8 | 9 | ''' 10 | convert colmap depth map from .bin to .png or to .h5 11 | ''' 12 | warnings.filterwarnings('ignore') # 屏蔽nan与min_depth比较时产生的警告 13 | 14 | # camnum = 12 15 | # fB = 32504; 16 | min_depth_percentile = 2 17 | max_depth_percentile = 98 18 | 19 | parser = argparse.ArgumentParser(description='convert depth bin to array and save h5') 20 | 21 | parser.add_argument( 22 | '--depthmapsdir', type=str, 23 | help='path to the origin colmap output depth', 24 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/query_reference_SfM_model_500/dense/stereo/depth_maps' 25 | ) 26 | parser.add_argument( 27 | '--output_h5_path', type=str, 28 | help='path to the save h5 depth', 29 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/query_reference_SfM_model_500/dense/stereo/depth_bin_h5' 30 | ) 31 | args = parser.parse_args() 32 | 33 | 34 | if not os.path.exists(args.output_h5_path): 35 | os.mkdir(args.output_h5_path) 36 | 37 | def read_array(path): 38 | with open(path, "rb") as fid: 39 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 40 | usecols=(0, 1, 2), dtype=int) 41 | # print("width:", width) 42 | # print("height:", height) 43 | # print("channels:", channels) 44 | fid.seek(0) 45 | num_delimiter = 0 46 | byte = fid.read(1) 47 | while True: 48 | if byte == b"&": 49 | num_delimiter += 1 50 | if num_delimiter >= 3: 51 | break 52 | byte = fid.read(1) 53 | array = np.fromfile(fid, np.float32) 54 | # print(array.shape) 55 | array = array.reshape((width, height, channels), order="F") 56 | return np.transpose(array, (1, 0, 2)).squeeze() 57 | 58 | def bin2depth(inputdepth, output_h5_path): 59 | # depth_map = '0.png.geometric.bin' 60 | # print(depthdir) 61 | # if min_depth_percentile > max_depth_percentile: 62 | # raise ValueError("min_depth_percentile should be less than or equal " 63 | # "to the max_depth_perceintile.") 64 | 65 | # Read depth and normal maps corresponding to the same image. 66 | if not os.path.exists(inputdepth): 67 | raise FileNotFoundError("file not found: {}".format(inputdepth)) 68 | 69 | # np.set_printoptions(threshold=np.inf) 70 | 71 | depth_map = read_array(inputdepth) 72 | # depth_map[depth_map<=0] = 0 73 | min_depth, max_depth = np.percentile(depth_map[depth_map>0], [min_depth_percentile, max_depth_percentile]) 74 | depth_map[depth_map <= 0] = np.nan # 把0和负数都设置为nan,防止被min_depth取代 75 | depth_map[depth_map < min_depth] = min_depth 76 | depth_map[depth_map > max_depth] = max_depth 77 | # depth_map[depth_map<=0] = 0 78 | depth_map = np.nan_to_num(depth_map) 79 | # print("input_bin_path:", inputdepth) 80 | depth_map_shape = depth_map.shape[0]*depth_map.shape[1] 81 | if depth_map_shape<100: 82 | print("depth_map_shape < 480*480!!!! 
and is:", depth_map_shape) 83 | print("depth name:", inputdepth) 84 | print(np.any(depth_map<0)) # 深度图存在负数 存在0, 取值在0~1之间 85 | 86 | # save depth as h5 FIXME 87 | h5_path = os.path.join(output_h5_path, '.'.join(os.path.basename(inputdepth).split('.')[:-1])+'.h5') 88 | with h5py.File(h5_path, 'w') as f: 89 | f.create_dataset('depth', data=depth_map) 90 | 91 | # bin 2 png 92 | # min_depth, max_depth = np.percentile(depth_map[depth_map>0], [min_depth_percentile, max_depth_percentile]) 93 | # depth_map[depth_map <= 0] = np.nan # 把0和负数都设置为nan,防止被min_depth取代 94 | # depth_map[depth_map < min_depth] = min_depth 95 | # depth_map[depth_map > max_depth] = max_depth 96 | 97 | # maxdisp = fB / min_depth; 98 | # mindisp = fB / max_depth; 99 | # depth_map = (fB/depth_map - mindisp) * 255 / (maxdisp - mindisp); 100 | # depth_map = np.nan_to_num(depth_map*255) # nan全都变为0 101 | # depth_map = depth_map.astype(int) 102 | 103 | # image = Image.fromarray(np.uint8(depth_map)).convert('L') 104 | # image = image.resize((500, 500), Image.ANTIALIAS) # 保证resize为500*500 105 | # print(image) 106 | # print(np.array(image)) 107 | # ouputdepth = os.path.join(outputdir, '.'.join(os.path.basename(inputdepth).split('.')[:-1])+'.png') 108 | # print(ouputdepth) 109 | # image.save(ouputdepth) 110 | 111 | 112 | for depthbin in os.listdir(args.depthmapsdir): 113 | inputdepth = os.path.join(args.depthmapsdir, depthbin) 114 | if os.path.exists(inputdepth): 115 | bin2depth(inputdepth, args.output_h5_path) -------------------------------------------------------------------------------- /terratrack_utils/depth_png2h5.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import os 4 | from PIL import Image 5 | 6 | # convert depth from png to h5 7 | depth_png_path = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/reference_SfM_model_500/dense/stereo/depth_maps_png' 8 | output_h5_path = '/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu/reference_SfM_model_500/dense/stereo/depth_bin_h5' 9 | if not os.path.exists(output_h5_path): 10 | os.mkdir(output_h5_path) 11 | 12 | for png_name in os.listdir(depth_png_path): 13 | print("png_name:", png_name) 14 | depth_png = Image.open(os.path.join(depth_png_path, png_name)) 15 | print(depth_png) 16 | png_array = np.array(depth_png) 17 | print(png_array) 18 | 19 | # Save depth map as HDF5 file 20 | h5_path = os.path.join(output_h5_path, '.'.join(png_name.split('.')[:3])+'.h5') 21 | with h5py.File(h5_path, 'w') as f: 22 | f.create_dataset('depth', data=png_array) 23 | 24 | ## Read h5 depth 25 | # for h5_file in os.listdir(output_h5_path): 26 | # print("h5 name:", h5_file) 27 | # depth_h5 = h5py.File(os.path.join(output_h5_path, h5_file), 'r') 28 | # depth = depth_h5['depth'] 29 | # print("h5_depth:",depth) 30 | # depth_array = np.array(depth) 31 | # print("depth_array:", depth_array) 32 | # depth_h5.close() -------------------------------------------------------------------------------- /terratrack_utils/preprocess_terra.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import imagesize 4 | 5 | import numpy as np 6 | 7 | import os 8 | 9 | parser = argparse.ArgumentParser(description='MegaDepth preprocessing script') 10 | 11 | parser.add_argument( 12 | '--base_path', type=str, 13 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu', 14 | help='path to mavic_npu' 15 | ) 16 | 17 | parser.add_argument( 18 | '--output_path', 
type=str, 19 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/process_output_query_ref_500', 20 | help='path to the output directory' 21 | ) 22 | 23 | 24 | args = parser.parse_args() 25 | 26 | if not os.path.exists(args.output_path): 27 | os.mkdir(args.output_path) 28 | 29 | base_path = args.base_path 30 | 31 | # megadepth TODO 32 | base_depth_path = os.path.join( 33 | base_path, 'query_reference_SfM_model_500/dense/stereo' 34 | ) 35 | 36 | base_undistorted_sfm_path = os.path.join( 37 | base_path, 'query_reference_Undistorted_SfM_500' 38 | ) 39 | 40 | undistorted_sparse_path = os.path.join( 41 | base_undistorted_sfm_path, 'sparse-txt' 42 | ) 43 | 44 | depths_path = os.path.join( 45 | base_depth_path, 'depth_bin_h5' 46 | ) 47 | 48 | 49 | images_path = os.path.join( 50 | base_undistorted_sfm_path, 'images' 51 | ) 52 | 53 | 54 | # Process cameras.txt 55 | with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f: 56 | raw = f.readlines()[3 :] # skip the header 57 | 58 | camera_intrinsics = {} 59 | for camera in raw: 60 | camera = camera.split(' ') 61 | camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]] 62 | 63 | # Process points3D.txt 64 | with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f: 65 | raw = f.readlines()[3 :] # skip the header 66 | 67 | points3D = {} # 3d点 字典:{点id:x,y,z} 68 | for point3D in raw: 69 | point3D = point3D.split(' ') 70 | points3D[int(point3D[0])] = np.array([ 71 | float(point3D[1]), float(point3D[2]), float(point3D[3]) 72 | ]) 73 | 74 | # Process images.txt 75 | with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f: 76 | raw = f.readlines()[4 :] # skip the header 77 | 78 | image_id_to_idx = {} 79 | image_names = [] 80 | raw_pose = [] 81 | camera = [] 82 | points3D_id_to_2D = [] 83 | n_points3D = [] 84 | for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])): 85 | image = image.split(' ') 86 | points = points.split(' ') 87 | 88 | image_id_to_idx[int(image[0])] = idx 89 | 90 | image_name = image[-1].strip('\n') 91 | image_names.append(image_name) 92 | # print("image_name:", image_name) 93 | 94 | raw_pose.append([float(elem) for elem in image[1 : -2]]) # image pose[ QW, QX, QY, QZ, TX, TY, TZ] 95 | camera.append(int(image[-2])) 96 | current_points3D_id_to_2D = {} 97 | for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]): 98 | if int(point3D_id) == -1: # 3d点id为-1的跳过,说明图像中该2d点没有对应的3d点 99 | continue 100 | current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)] 101 | points3D_id_to_2D.append(current_points3D_id_to_2D) # 3d点的id与对应图像中2d点的坐标 102 | n_points3D.append(len(current_points3D_id_to_2D)) 103 | n_images = len(image_names) 104 | 105 | # Image and depthmaps paths 106 | image_paths = [] 107 | depth_paths = [] 108 | for image_name in image_names: 109 | image_path = os.path.join(images_path, image_name) 110 | 111 | # Path to the depth file 112 | depth_path = os.path.join( 113 | depths_path, image_name + '.geometric.h5' 114 | ) 115 | # print("depth_path:", depth_path) 116 | 117 | if os.path.exists(depth_path): 118 | # Check if depth map or background / foreground mask 119 | depth_paths.append(depth_path) 120 | print("depth path:", depth_path) 121 | image_paths.append(image_path) 122 | print("image path:", image_path) 123 | else: 124 | depth_paths.append(None) 125 | image_paths.append(None) 126 | 127 | # Camera configuration 128 | intrinsics = [] 129 | poses = [] 130 | principal_axis = [] 131 | points3D_id_to_ndepth = [] 132 | for idx, 
image_name in enumerate(image_names): 133 | if image_paths[idx] is None: 134 | intrinsics.append(None) 135 | poses.append(None) 136 | principal_axis.append([0, 0, 0]) 137 | points3D_id_to_ndepth.append({}) 138 | continue 139 | image_intrinsics = camera_intrinsics[camera[idx]] 140 | K = np.zeros([3, 3]) 141 | K[0, 0] = image_intrinsics[2] 142 | K[0, 2] = image_intrinsics[4] 143 | K[1, 1] = image_intrinsics[3] 144 | K[1, 2] = image_intrinsics[5] 145 | K[2, 2] = 1 146 | intrinsics.append(K) 147 | 148 | image_pose = raw_pose[idx] 149 | qvec = image_pose[: 4] 150 | qvec = qvec / np.linalg.norm(qvec) 151 | w, x, y, z = qvec 152 | R = np.array([ 153 | [ 154 | 1 - 2 * y * y - 2 * z * z, 155 | 2 * x * y - 2 * z * w, 156 | 2 * x * z + 2 * y * w 157 | ], 158 | [ 159 | 2 * x * y + 2 * z * w, 160 | 1 - 2 * x * x - 2 * z * z, 161 | 2 * y * z - 2 * x * w 162 | ], 163 | [ 164 | 2 * x * z - 2 * y * w, 165 | 2 * y * z + 2 * x * w, 166 | 1 - 2 * x * x - 2 * y * y 167 | ] 168 | ]) 169 | principal_axis.append(R[2, :]) 170 | t = image_pose[4 : 7] 171 | # World-to-Camera pose 172 | current_pose = np.zeros([4, 4]) 173 | current_pose[: 3, : 3] = R 174 | current_pose[: 3, 3] = t 175 | current_pose[3, 3] = 1 176 | # Camera-to-World pose 177 | # pose = np.zeros([4, 4]) 178 | # pose[: 3, : 3] = np.transpose(R) 179 | # pose[: 3, 3] = -np.matmul(np.transpose(R), t) 180 | # pose[3, 3] = 1 181 | poses.append(current_pose) 182 | 183 | current_points3D_id_to_ndepth = {} 184 | for point3D_id in points3D_id_to_2D[idx].keys(): 185 | p3d = points3D[point3D_id] 186 | current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1])) 187 | points3D_id_to_ndepth.append(current_points3D_id_to_ndepth) 188 | principal_axis = np.array(principal_axis) 189 | # 相机的朝向? 190 | angles = np.rad2deg(np.arccos( 191 | np.clip( 192 | np.dot(principal_axis, np.transpose(principal_axis)), 193 | -1, 1 194 | ) 195 | )) 196 | 197 | # Compute overlap score 198 | overlap_matrix = np.full([n_images, n_images], -1.) 199 | scale_ratio_matrix = np.full([n_images, n_images], -1.) 
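# Descriptive note on the pair-scoring loop below: 'matches' is the set of 3D point IDs
# observed by both images; each overlap entry is the fraction of one image's observed
# points that are shared with the other, and the scale-ratio entry is the minimum over
# shared points of max(nd1/nd2, nd2/nd1), a proxy for how similar the viewing scales are.
# Pairs without a depth map keep the initial -1; pairs with no shared points get an
# overlap of 0 and keep a scale ratio of -1.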
200 | for idx1 in range(n_images): 201 | if image_paths[idx1] is None or depth_paths[idx1] is None: 202 | continue 203 | for idx2 in range(idx1 + 1, n_images): 204 | if image_paths[idx2] is None or depth_paths[idx2] is None: 205 | continue 206 | # FIXME 计算match按位&的话,如果两个对应的id的顺序不一直呢 207 | # 共同观测到的特征点对应的三维点的 ID 208 | matches = ( 209 | points3D_id_to_2D[idx1].keys() & 210 | points3D_id_to_2D[idx2].keys() 211 | ) 212 | min_num_points3D = min( 213 | len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2]) 214 | ) # FIXME 计算两个图像之间的重叠率 这个参数并没有用到 215 | overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1]) # min_num_points3D 216 | overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2]) # min_num_points3D 217 | if len(matches) == 0: 218 | continue 219 | points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1] 220 | points3D_id_to_ndepth2 = points3D_id_to_ndepth[idx2] 221 | nd1 = np.array([points3D_id_to_ndepth1[match] for match in matches]) 222 | nd2 = np.array([points3D_id_to_ndepth2[match] for match in matches]) 223 | # FIXME 这个比例缩放倍率最小值的计算可以用来判断这两个点在相机坐标系下的物理尺寸比例是否接近,从而用于后续的相对姿态估计。 224 | min_scale_ratio = np.min(np.maximum(nd1 / nd2, nd2 / nd1)) 225 | scale_ratio_matrix[idx1, idx2] = min_scale_ratio 226 | scale_ratio_matrix[idx2, idx1] = min_scale_ratio 227 | print(overlap_matrix) 228 | np.savez( 229 | os.path.join(args.output_path, os.path.basename(base_path)+'.npz'), 230 | image_paths=image_paths, 231 | depth_paths=depth_paths, 232 | intrinsics=intrinsics, 233 | poses=poses, 234 | overlap_matrix=overlap_matrix, 235 | scale_ratio_matrix=scale_ratio_matrix, 236 | angles=angles, 237 | n_points3D=n_points3D, 238 | points3D_id_to_2D=points3D_id_to_2D, 239 | points3D_id_to_ndepth=points3D_id_to_ndepth 240 | ) 241 | -------------------------------------------------------------------------------- /terratrack_utils/train_scenes_500.txt: -------------------------------------------------------------------------------- 1 | mavic_npu 2 | -------------------------------------------------------------------------------- /terratrack_utils/train_scenes_origin.txt: -------------------------------------------------------------------------------- 1 | mavic-river 2 | mavic-factory 3 | mavic-fengniao 4 | mavic-hongkong 5 | inspire1-rail-kfs 6 | phantom3-grass-kfs 7 | phantom3-centralPark-kfs 8 | phantom3-npu-kfs 9 | phantom3-freeway-kfs 10 | phantom3-village-kfs 11 | gopro-npu-kfs 12 | gopro-saplings-kfs 13 | -------------------------------------------------------------------------------- /terratrack_utils/undistort_reconstructions_terra.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import imagesize 4 | 5 | import os 6 | 7 | import subprocess 8 | 9 | parser = argparse.ArgumentParser(description='MegaDepth Undistortion') 10 | 11 | parser.add_argument( 12 | '--colmap_path', type=str, 13 | default='/usr/bin', 14 | help='path to colmap executable' 15 | ) 16 | parser.add_argument( 17 | '--base_path', type=str, 18 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/mavic_npu', 19 | help='path to mavic_npu' 20 | ) 21 | 22 | args = parser.parse_args() 23 | 24 | sfm_path = os.path.join( 25 | args.base_path, 'query_reference_SfM_model_500' 26 | ) 27 | base_depth_path = os.path.join( 28 | sfm_path, 'dense' 29 | ) 30 | output_path = os.path.join( 31 | args.base_path, 'query_reference_Undistorted_SfM_500' 32 | ) 33 | 34 | os.mkdir(output_path) 35 | 36 | image_path = os.path.join( 37 | base_depth_path, 'images' 38 | ) 
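# Note: the two colmap calls further below (1) undistort the raw images into the
# query_reference_Undistorted_SfM_500 output directory and (2) convert the undistorted
# sparse model to plain-text cameras.txt / images.txt / points3D.txt under sparse-txt,
# which is the format that preprocess_terra.py parses.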
39 | 40 | # Find the maximum image size in scene. 41 | max_image_size = 0 42 | for image_name in os.listdir(image_path): 43 | max_image_size = max( 44 | max_image_size, 45 | max(imagesize.get(os.path.join(image_path, image_name))) 46 | ) 47 | 48 | # Undistort the images and update the reconstruction. 49 | subprocess.call([ 50 | os.path.join(args.colmap_path, 'colmap'), 'image_undistorter', 51 | '--image_path', os.path.join(args.base_path, 'query_reference_images_500'), 52 | '--input_path', os.path.join(sfm_path, 'sparse', '0'), 53 | '--output_path', output_path, 54 | '--max_image_size', str(max_image_size) 55 | ]) 56 | 57 | # Transform the reconstruction to raw text format. 58 | sparse_txt_path = os.path.join(output_path, 'sparse-txt') 59 | os.mkdir(sparse_txt_path) 60 | subprocess.call([ 61 | os.path.join(args.colmap_path, 'colmap'), 'model_converter', 62 | '--input_path', os.path.join(output_path, 'sparse'), 63 | '--output_path', sparse_txt_path, 64 | '--output_type', 'TXT' 65 | ]) -------------------------------------------------------------------------------- /terratrack_utils/valid_scenes_500.txt: -------------------------------------------------------------------------------- 1 | mavic_npu 2 | -------------------------------------------------------------------------------- /terratrack_utils/valid_scenes_origin.txt: -------------------------------------------------------------------------------- 1 | mavic-xjtu 2 | phantom3-huangqi-kfs 3 | -------------------------------------------------------------------------------- /train_terra.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import os 6 | 7 | import shutil 8 | 9 | import torch 10 | import torch.optim as optim 11 | 12 | from torch.utils.data import DataLoader 13 | 14 | from tqdm import tqdm 15 | 16 | import warnings 17 | 18 | from lib.dataset_terra import TerraDataset 19 | from lib.exceptions import NoGradientError 20 | from lib.loss import loss_function 21 | from lib.full_model.model_swin_unet_d2 import Swin_D2UNet 22 | from lib.full_model.model_unet import U2Net 23 | from torch.utils.tensorboard import SummaryWriter 24 | from datetime import datetime 25 | 26 | # CUDA 27 | use_cuda = torch.cuda.is_available() 28 | device = torch.device("cuda:0" if use_cuda else "cpu") 29 | 30 | # Seed 31 | torch.manual_seed(1) 32 | if use_cuda: 33 | torch.cuda.manual_seed(1) 34 | np.random.seed(1) 35 | 36 | # Argument parsing 37 | parser = argparse.ArgumentParser(description='Training script') 38 | 39 | parser.add_argument( 40 | '--dataset_path', type=str, 41 | help='path to the dataset', 42 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack' 43 | ) 44 | parser.add_argument( 45 | '--scene_info_path', type=str, 46 | help='path to the processed scenes', 47 | default='/home/a409/users/huboni/Projects/dataset/TerraTrack/process_output_query_ref_500' 48 | ) 49 | 50 | parser.add_argument( 51 | '--preprocessing', type=str, default='torch', 52 | help='image preprocessing (caffe or torch)' 53 | ) 54 | parser.add_argument( 55 | '--model_file', type=str, default='models/d2_ots.pth', 56 | help='path to the full model' 57 | ) 58 | 59 | parser.add_argument( 60 | '--num_epochs', type=int, default=100, 61 | help='number of training epochs' 62 | ) 63 | # default = le-3 64 | parser.add_argument( 65 | '--lr', type=float, default=1e-3, 66 | help='initial learning rate' 67 | ) 68 | parser.add_argument( 69 | '--batch_size', type=int, default=1, 70 | help='batch size' 
71 | ) 72 | parser.add_argument( 73 | '--num_workers', type=int, default=4, 74 | help='number of workers for data loading' 75 | ) 76 | 77 | parser.add_argument( 78 | '--use_validation', dest='use_validation', action='store_true', 79 | help='use the validation split' 80 | ) 81 | parser.set_defaults(use_validation=True) 82 | 83 | parser.add_argument( 84 | '--log_interval', type=int, default=2000, 85 | help='loss logging interval' 86 | ) 87 | parser.add_argument( 88 | '--plot', dest='plot', action='store_true', 89 | help='plot training pairs' 90 | ) 91 | parser.set_defaults(plot=True) 92 | 93 | parser.add_argument( 94 | '--checkpoint_directory', type=str, default='checkpoints', 95 | help='directory for training checkpoints' 96 | ) 97 | 98 | parser.add_argument( 99 | '--net', type=str, default='vgg', 100 | help='choose net vgg or swin' 101 | ) 102 | args = parser.parse_args() 103 | 104 | print(args) 105 | 106 | # Create the folders for plotting if need be 107 | if args.plot: 108 | plot_path = 'train_vis_56_true' 109 | if os.path.isdir(plot_path): 110 | print('[Warning] Plotting directory already exists.') 111 | else: 112 | os.mkdir(plot_path) 113 | 114 | if args.net=='swin': 115 | model = Swin_D2UNet( 116 | # model_file=args.model_file, 117 | use_cuda=use_cuda 118 | ) 119 | elif args.net=='unet': 120 | model = U2Net( 121 | model_file=args.model_file, 122 | use_cuda=use_cuda 123 | ) 124 | 125 | # Optimizer 126 | optimizer = optim.Adam( 127 | filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr 128 | ) 129 | # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.00001) 130 | 131 | # Dataset 132 | if args.use_validation: 133 | validation_dataset = TerraDataset( 134 | scene_list_path='terratrack_utils/valid_scenes_500.txt', 135 | scene_info_path=args.scene_info_path, 136 | base_path=args.dataset_path, 137 | train=False, 138 | preprocessing=args.preprocessing, 139 | pairs_per_scene=2 140 | ) 141 | validation_dataloader = DataLoader( 142 | validation_dataset, 143 | batch_size=args.batch_size, 144 | num_workers=args.num_workers 145 | ) 146 | 147 | training_dataset = TerraDataset( 148 | scene_list_path='terratrack_utils/train_scenes_500.txt', 149 | scene_info_path=args.scene_info_path, 150 | base_path=args.dataset_path, 151 | preprocessing=args.preprocessing 152 | ) 153 | training_dataloader = DataLoader( 154 | training_dataset, 155 | batch_size=args.batch_size, 156 | num_workers=args.num_workers 157 | ) 158 | 159 | 160 | # Define epoch function 161 | def process_epoch( 162 | epoch_idx, 163 | model, loss_function, optimizer, dataloader, device, 164 | log_file, args, writer, train=True 165 | ): 166 | epoch_losses = [] 167 | 168 | torch.set_grad_enabled(train) 169 | 170 | progress_bar = tqdm(enumerate(dataloader), total=len(dataloader)) 171 | # max_iterations = 20*len(progress_bar) # 20 is epoch 172 | nBatches = len(progress_bar) 173 | for batch_idx, batch in progress_bar: 174 | if train: 175 | optimizer.zero_grad() 176 | 177 | batch['train'] = train 178 | batch['epoch_idx'] = epoch_idx 179 | batch['batch_idx'] = batch_idx 180 | batch['batch_size'] = args.batch_size 181 | batch['preprocessing'] = args.preprocessing 182 | batch['log_interval'] = args.log_interval 183 | 184 | try: 185 | loss = loss_function(model, batch, device, plot=args.plot) 186 | except NoGradientError: 187 | continue 188 | 189 | current_loss = loss.data.cpu().numpy()[0] 190 | if train: 191 | writer.add_scalar('Train/CurrentBatchLoss', current_loss, (epoch_idx - 1) * nBatches + batch_idx) 192 
| else: 193 | writer.add_scalar('Valid/CurrentBatchLoss', current_loss, (epoch_idx - 1) * nBatches + batch_idx) 194 | 195 | epoch_losses.append(current_loss) 196 | 197 | progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses))) 198 | 199 | if batch_idx % args.log_interval == 0: 200 | log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % ( 201 | 'train' if train else 'valid', 202 | epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses) 203 | )) 204 | 205 | if train: 206 | loss.backward() 207 | optimizer.step() 208 | # lr_ = args.lr * (1.0 - batch_idx*epoch_idx / max_iterations) ** 0.9 209 | # for param_group in optimizer.param_groups: 210 | # param_group['lr'] = lr_ 211 | 212 | 213 | log_file.write('[%s] epoch %d - avg_loss: %f\n' % ( 214 | 'train' if train else 'valid', 215 | epoch_idx, 216 | np.mean(epoch_losses) 217 | )) 218 | if train: 219 | writer.add_scalar('Train/AvgLoss', np.mean(epoch_losses), epoch_idx) 220 | else: 221 | writer.add_scalar('Valid/AvgLoss', np.mean(epoch_losses), epoch_idx) 222 | 223 | log_file.flush() 224 | 225 | return np.mean(epoch_losses) 226 | 227 | writer = SummaryWriter(log_dir=os.path.join(args.checkpoint_directory, datetime.now().strftime('%b%d_%H-%M-%S')+'_'+args.net)) 228 | # Create the checkpoint directory 229 | logdir = writer.file_writer.get_logdir() 230 | save_checkpoint_path = os.path.join(logdir, 'checkpoints') 231 | 232 | if os.path.isdir(save_checkpoint_path): 233 | print('[Warning] Checkpoint directory already exists.') 234 | else: 235 | os.mkdir(save_checkpoint_path) 236 | 237 | log_file = os.path.join(logdir, 'log.txt') 238 | # Open the log file for writing 239 | if os.path.exists(log_file): 240 | print('[Warning] Log file already exists.') 241 | log_file = open(log_file, 'a+') 242 | log_file.write("args:" + str(args)) 243 | 244 | # Initialize the history 245 | train_loss_history = [] 246 | validation_loss_history = [] 247 | if args.use_validation: 248 | validation_dataset.build_dataset() 249 | min_validation_loss = process_epoch( 250 | 0, 251 | model, loss_function, optimizer, validation_dataloader, device, 252 | log_file, args, writer, 253 | train=False 254 | ) 255 | 256 | # Start the training 257 | for epoch_idx in range(1, args.num_epochs + 1): 258 | # Process epoch 259 | print("epoch :", epoch_idx) 260 | training_dataset.build_dataset() 261 | train_loss_history.append( 262 | process_epoch( 263 | epoch_idx, 264 | model, loss_function, optimizer, training_dataloader, device, 265 | log_file, args, writer 266 | ) 267 | ) 268 | 269 | if args.use_validation: 270 | validation_loss_history.append( 271 | process_epoch( 272 | epoch_idx, 273 | model, loss_function, optimizer, validation_dataloader, device, 274 | log_file, args, writer, 275 | train=False 276 | ) 277 | ) 278 | 279 | # Save the current checkpoint 280 | checkpoint_path = os.path.join( 281 | save_checkpoint_path, 282 | 'checkpoint.pth' 283 | ) 284 | checkpoint = { 285 | 'args': args, 286 | 'epoch_idx': epoch_idx, 287 | 'model': model.state_dict(), 288 | 'optimizer': optimizer.state_dict(), 289 | 'train_loss_history': train_loss_history, 290 | 'validation_loss_history': validation_loss_history 291 | } 292 | torch.save(checkpoint, checkpoint_path) 293 | 294 | if ( 295 | args.use_validation and 296 | validation_loss_history[-1] < min_validation_loss 297 | ): 298 | min_validation_loss = validation_loss_history[-1] 299 | best_checkpoint_path = os.path.join( 300 | save_checkpoint_path, 301 | 'best.pth' ) 302 | shutil.copy(checkpoint_path, best_checkpoint_path) 303 | 304 
| # Close the log file 305 | log_file.close() 306 | writer.close() 307 | --------------------------------------------------------------------------------
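A minimal usage sketch for the plot_matches helper defined in plotmatch.py above. The synthetic white images, the hand-written keypoints, and the output filename matches_preview.png are illustrative placeholders, not files or values from this repository; keypoints are passed as (x, y) pixel coordinates and matches as index pairs, mirroring how match_localization-style scripts call the function.

import numpy as np
import matplotlib
matplotlib.use('Agg')  # render off-screen; an assumption for headless use, not required by plotmatch
import matplotlib.pyplot as plt

import plotmatch

# Two synthetic white 224x224 RGB images stand in for a query/reference pair.
image1 = np.full((224, 224, 3), 255, dtype=np.uint8)
image2 = np.full((224, 224, 3), 255, dtype=np.uint8)

# Keypoints are (x, y) pixel coordinates; matches pairs indices into the two keypoint sets.
keypoints1 = np.array([[30.0, 40.0], [120.0, 80.0], [200.0, 150.0]])
keypoints2 = np.array([[32.0, 41.0], [118.0, 83.0], [205.0, 148.0]])
matches = np.array([[0, 0], [1, 1], [2, 2]])

_, ax = plt.subplots()
plotmatch.plot_matches(ax, image1, image2, keypoints1, keypoints2, matches,
                       plot_matche_points=True, matchline=True, matchlinewidth=0.5)
ax.axis('off')
plt.savefig('matches_preview.png', dpi=300)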