├── .gitattributes ├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── ckpt └── CrossScore-v1.0.0.ckpt ├── config ├── data │ ├── MapFreeReloc540.yaml │ ├── RealEstate_540.yaml │ ├── SimpleReference.yaml │ ├── combined_testing.yaml │ ├── combined_training.yaml │ └── mip360.yaml ├── default.yaml ├── default_predict.yaml ├── default_test.yaml └── model │ └── model.yaml ├── dataloading ├── data_manager.py ├── dataset │ ├── nvs_dataset.py │ └── simple_reference.py └── transformation │ └── crop.py ├── environment.yaml ├── model ├── cross_reference.py ├── customised_transformer │ └── transformer.py ├── positional_encoding.py └── regression_layer.py ├── predict.sh ├── pyrightconfig.json ├── task ├── core.py ├── predict.py ├── test.py └── train.py └── utils ├── check_config.py ├── data_processing └── split_gaussian_processed.py ├── evaluation ├── metric.py ├── metric_logger.py └── summarise_score_gt.py ├── io ├── batch_writer.py ├── images.py └── score_summariser.py ├── misc └── image.py ├── neighbour └── sampler.py └── plot └── batch_visualiser.py /.gitattributes: -------------------------------------------------------------------------------- 1 | ckpt/CrossScore-v1.0.0.ckpt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | predict/ 3 | 4 | playground.py 5 | tmp.sh 6 | 7 | log* 8 | datadir* 9 | debug* 10 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.watcherExclude": { 3 | "**/datadir/**": true, 4 | "**/log*/**": true 5 | }, 6 | "search.exclude": { 7 | "**/datadir/**": true, 8 | "**/log*/**": true, 9 | "datadir/**": true, 10 | "log*/**": true 11 | }, 12 | "python.analysis.exclude": [ 13 | "**/datadir/**" 14 | ], 15 | "black-formatter.args": [ 16 | "--line-length=100" 17 | ] 18 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Zirui Wang 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CrossScore: Towards Multi-View Image Evaluation and Scoring 2 | 3 | **[Project Page](https://crossscore.active.vision) | 4 | [arXiv](https://arxiv.org/abs/2404.14409)** 5 | 6 | [Zirui Wang](https://scholar.google.com/citations?user=zCBKqa8AAAAJ&hl=en), 7 | [Wenjing Bian](https://scholar.google.com/citations?user=IVfbqkgAAAAJ&hl=en), 8 | [Victor Adrian Prisacariu](http://www.robots.ox.ac.uk/~victor). 9 | 10 | [Active Vision Lab (AVL)](https://www.robots.ox.ac.uk/~lav), 11 | University of Oxford. 12 | 13 | 14 | ## Table of Contents 15 | - [Environment](#environment) 16 | - [Data](#data) 17 | - [Training](#training) 18 | - [Inferencing](#inferencing) 19 | 20 | ## Environment 21 | We provide an `environment.yaml` file to set up a `conda` environment: 22 | ```bash 23 | git clone https://github.com/ActiveVisionLab/CrossScore.git 24 | cd CrossScore 25 | conda env create -f environment.yaml 26 | conda activate CrossScore 27 | ``` 28 | 29 | ## Data 30 | **TLDR**: download this 31 | [file](https://www.robots.ox.ac.uk/~ryan/CrossScore/MFR_subset_demo.tar.gz) (~3GB), 32 | put it in `datadir`: 33 | ```bash 34 | mkdir datadir 35 | cd datadir 36 | wget https://www.robots.ox.ac.uk/~ryan/CrossScore/MFR_subset_demo.tar.gz 37 | tar -xzvf MFR_subset_demo.tar.gz 38 | rm MFR_subset_demo.tar.gz 39 | cd .. 40 | ``` 41 | 42 | To demonstrate a minimum working example for the training and inferencing steps shown below, 43 | we provide a small pre-processed subset. 44 | It is a subset of 45 | [Map-Free Relocalisation (MFR)](https://research.nianticlabs.com/mapfree-reloc-benchmark/dataset) 46 | and is pre-processed using 47 | [3D Gaussian Splatting (3DGS)](https://github.com/graphdeco-inria/gaussian-splatting). 48 | This small demo dataset is available at this 49 | [link](https://www.robots.ox.ac.uk/~ryan/CrossScore/MFR_subset_demo.tar.gz) (~3GB). 50 | This is the same file as in the TLDR above. 51 | We use this demo subset only to present the expected dataloading structure. 52 | 53 | In our actual training, our model is trained on MFR pre-processed by three NVS methods: 54 | [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), 55 | [TensoRF](https://docs.nerf.studio/nerfology/methods/tensorf.html), and 56 | [NeRFacto](https://docs.nerf.studio/nerfology/methods/nerfacto.html). 57 | Due to the size of the pre-processed data (~2TB), it is challenging to share 58 | it directly. One workaround is to release a data pre-processing script 59 | for MFR, which we are still tidying up. 60 | **We aim to release the pre-processing script in Dec 2024.** 61 | 62 | ## Training 63 | We train our model with two NVIDIA A5000 (24GB) GPUs for about two days. 64 | However, the model should perform reasonably well after 12 hours of training. 65 | It is also possible to train with a single GPU.
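Training is configured with Hydra: `config/default.yaml` pulls in `config/data/combined_training.yaml` and `config/model/model.yaml`, so other fields can be overridden from the command line in the same way as `trainer.devices` in the commands below. A minimal sketch of such overrides (the key paths are taken from those config files; the specific values here are illustrative only, not our recommended settings):

```bash
# Hypothetical example: single-GPU run with fewer epochs, a smaller batch size,
# and fewer dataloader workers than the defaults in the config files above.
python task/train.py \
    trainer.devices='[0]' \
    trainer.max_epochs=3 \
    data.loader.train.batch_size=8 \
    data.loader.train.num_workers=4
```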
66 | ```bash 67 | python task/train.py trainer.devices='[0,1]' # 2 GPUs 68 | # python task/train.py trainer.devices='[0]' # 1 GPU 69 | ``` 70 | 71 | ## Inferencing 72 | We provide an example command to predict CrossScore for NVS-rendered images 73 | by referencing real captured images. 74 | ```bash 75 | git lfs install && git lfs pull # get our ckpt using git LFS 76 | bash predict.sh 77 | ``` 78 | After running the script, our CrossScore score maps should be written to the `predict` dir. 79 | The output should be similar to our 80 | [demo video](https://crossscore.active.vision/assets/additional_results.mp4) 81 | on our project page. 82 | 83 | ## Todo 84 | - [ ] Create a HuggingFace demo page. 85 | - [ ] Release scripts for the ECCV quantitative results. 86 | - [ ] Release data processing scripts. 87 | - [ ] Release PyPI and Conda packages. 88 | 89 | ## Acknowledgement 90 | This research is supported by an 91 | [ARIA](https://facebookresearch.github.io/projectaria_tools/docs/intro) 92 | research gift grant from Meta Reality Lab. We gratefully thank 93 | [Shangzhe Wu](http://elliottwu.com), 94 | [Tengda Han](https://tengdahan.github.io/), 95 | [Zihang Lai](https://scholar.google.com/citations?user=31eXgMYAAAAJ&hl=en) for insightful discussions, and 96 | [Michael Hobley](https://portraits.keble.net/2022/michael-hobley) for proofreading. 97 | 98 | ## Citation 99 | ```bibtex 100 | @inproceedings{wang2024crossscore, 101 | title={CrossScore: Towards Multi-View Image Evaluation and Scoring}, 102 | author={Zirui Wang and Wenjing Bian and Victor Adrian Prisacariu}, 103 | booktitle={ECCV}, 104 | year={2024} 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /ckpt/CrossScore-v1.0.0.ckpt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1974040df3a16c0d93a4af90685dd923b1c52ed1b3b2aedd68071e2371514cac 3 | size 129050414 4 | -------------------------------------------------------------------------------- /config/data/MapFreeReloc540.yaml: -------------------------------------------------------------------------------- 1 | loader: 2 | train: 3 | batch_size: 32 4 | num_workers: 8 5 | shuffle: True 6 | pin_memory: True 7 | persistent_workers: True 8 | prefetch_factor: 2 9 | 10 | validation: 11 | batch_size: 32 12 | num_workers: 8 13 | shuffle: True 14 | pin_memory: True 15 | persistent_workers: True 16 | prefetch_factor: 2 17 | 18 | dataset: 19 | path: datadir/processed_training_ready/gaussian/map-free-reloc 20 | resolution: res_540 21 | num_gaussians_iters: -1 # set -1 to use all 22 | zero_reference: False 23 | 24 | neighbour_config: 25 | strategy: random 26 | cross: 5 27 | deterministic: False 28 | 29 | transforms: 30 | crop_size: 518 31 | -------------------------------------------------------------------------------- /config/data/RealEstate_540.yaml: -------------------------------------------------------------------------------- 1 | loader: 2 | train: 3 | batch_size: 32 4 | num_workers: 8 5 | shuffle: True 6 | pin_memory: True 7 | persistent_workers: True 8 | prefetch_factor: 2 9 | 10 | validation: 11 | batch_size: 32 12 | num_workers: 8 13 | shuffle: True 14 | pin_memory: True 15 | persistent_workers: True 16 | prefetch_factor: 2 17 | 18 | dataset: 19 | path: datadir/processed_training_ready/gaussian/RealEstate200 20 | resolution: res_540 21 | num_gaussians_iters: -1 # set -1 to use all 22 | zero_reference: False 23 | 24 | neighbour_config: 25 | strategy: 
random 26 | cross: 5 27 | deterministic: False 28 | 29 | transforms: 30 | crop_size: 518 31 | -------------------------------------------------------------------------------- /config/data/SimpleReference.yaml: -------------------------------------------------------------------------------- 1 | loader: 2 | validation: 3 | batch_size: 8 4 | num_workers: 8 5 | shuffle: True 6 | pin_memory: True 7 | persistent_workers: False 8 | prefetch_factor: 2 9 | 10 | dataset: 11 | query_dir: null 12 | reference_dir: null 13 | resolution: res_540 14 | zero_reference: False 15 | 16 | neighbour_config: 17 | strategy: random 18 | cross: 5 19 | deterministic: False 20 | 21 | transforms: 22 | crop_size: 518 23 | -------------------------------------------------------------------------------- /config/data/combined_testing.yaml: -------------------------------------------------------------------------------- 1 | loader: 2 | train: 3 | batch_size: 24 4 | num_workers: 6 5 | shuffle: True 6 | pin_memory: True 7 | persistent_workers: True 8 | prefetch_factor: 2 9 | 10 | validation: 11 | batch_size: 24 12 | num_workers: 6 13 | shuffle: True 14 | pin_memory: True 15 | persistent_workers: True 16 | prefetch_factor: 2 17 | 18 | dataset: 19 | path: 20 | - datadir/processed_training_ready/gaussian/map-free-reloc 21 | # - datadir/processed_training_ready/nerfacto/map-free-reloc 22 | # - datadir/processed_training_ready/tensorf/map-free-reloc 23 | # - datadir/processed_testing_ready/gaussian/RealEstate200 24 | # - datadir/processed_testing_ready/gaussian/mip360 25 | # - datadir/processed_testing_ready/ibrnet/map-free-reloc 26 | # - datadir/processed_testing_ready/pixelnerf/map-free-reloc 27 | resolution: null # use the first resolution found, usually only one 28 | num_gaussians_iters: -1 # set -1 to use all 29 | zero_reference: False 30 | 31 | neighbour_config: 32 | strategy: random 33 | cross: 5 34 | deterministic: False 35 | 36 | transforms: 37 | crop_size: 518 38 | -------------------------------------------------------------------------------- /config/data/combined_training.yaml: -------------------------------------------------------------------------------- 1 | loader: 2 | train: 3 | batch_size: 24 4 | num_workers: 6 5 | shuffle: True 6 | pin_memory: True 7 | persistent_workers: True 8 | prefetch_factor: 2 9 | 10 | validation: 11 | batch_size: 24 12 | num_workers: 6 13 | shuffle: True 14 | pin_memory: True 15 | persistent_workers: True 16 | prefetch_factor: 2 17 | 18 | dataset: 19 | path: 20 | # -------------- This is only a tiny example for demo. 21 | # We did not train our model using this subset. 22 | - datadir/processed_training_ready/gaussian/map-free-reloc 23 | # -------------- Below are the actual data used for training (~2TB), 24 | # which is challenging to share. See our github repo 25 | # https://github.com/ActiveVisionLab/CrossScore.git for more details and updates. 
26 | # - datadir/processed_training_ready/gaussian/map-free-reloc 27 | # - datadir/processed_training_ready/nerfacto/map-free-reloc 28 | # - datadir/processed_training_ready/tensorf/map-free-reloc 29 | resolution: null # use the first resolution found, usually only one 30 | num_gaussians_iters: -1 # set -1 to use all 31 | zero_reference: False 32 | 33 | neighbour_config: 34 | strategy: random 35 | cross: 5 36 | deterministic: False 37 | 38 | transforms: 39 | crop_size: 518 40 | -------------------------------------------------------------------------------- /config/data/mip360.yaml: -------------------------------------------------------------------------------- 1 | loader: 2 | train: 3 | batch_size: 32 4 | num_workers: 8 5 | shuffle: True 6 | pin_memory: True 7 | persistent_workers: True 8 | prefetch_factor: 2 9 | 10 | validation: 11 | batch_size: 32 12 | num_workers: 8 13 | shuffle: True # to vis diff batches 14 | pin_memory: True 15 | persistent_workers: True 16 | prefetch_factor: 2 17 | 18 | dataset: 19 | path: datadir/processed_training_ready/gaussian/mip360 20 | resolution: res_400 # def res_400 21 | num_gaussians_iters: -1 # set -1 to use all 22 | zero_reference: False 23 | 24 | neighbour_config: 25 | strategy: random 26 | cross: 5 27 | deterministic: False 28 | 29 | transforms: 30 | crop_size: 518 31 | -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ # overriding order, see https://hydra.cc/docs/tutorials/structured_config/defaults/#a-note-about-composition-order 3 | - data: combined_training 4 | - model: model 5 | 6 | hydra: 7 | run: 8 | dir: log/${now:%Y%m%d}_${now:%H%M%S.%f}_${alias} 9 | 10 | lightning: 11 | seed: 1 12 | 13 | project: 14 | name: CrossScore 15 | 16 | alias: "" 17 | 18 | trainer: 19 | accelerator: gpu 20 | devices: [0] 21 | # devices: [0, 1] 22 | precision: 16-mixed 23 | 24 | max_epochs: 9 25 | max_steps: -1 # def -1 26 | 27 | overfit_batches: 0 # def 0 28 | limit_train_batches: 1.0 29 | limit_val_batches: 1.0 30 | num_sanity_val_steps: 2 # def 2 31 | 32 | log_every_n_steps: 50 # def 50 33 | ckpt_path_to_load: null 34 | 35 | checkpointing: 36 | every_n_train_steps: null 37 | every_n_epochs: null 38 | train_time_interval: 2 # hours 39 | save_last: True 40 | save_top_k: -1 41 | 42 | optimizer: 43 | type: AdamW 44 | lr: 5e-4 45 | 46 | lr_scheduler: 47 | type: StepLR 48 | step_size: 100 49 | gamma: 0.5 50 | step_interval: epoch # or step 51 | 52 | do_profiling: False 53 | 54 | logger: 55 | vis_imgs_every_n_train_steps: 100 56 | vis_scalar_every_n_train_steps: 100 57 | vis_histogram_every_n_train_steps: 100 58 | 59 | cache_size: 60 | train: 61 | n_scalar: 10 62 | 63 | validation: 64 | n_fig: 2 65 | 66 | this_main: 67 | resize_short_side: 540 68 | -------------------------------------------------------------------------------- /config/default_predict.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ # overriding order, see https://hydra.cc/docs/tutorials/structured_config/defaults/#a-note-about-composition-order 3 | - data: SimpleReference 4 | - model: model 5 | - override hydra/hydra_logging: disabled 6 | - override hydra/job_logging: disabled 7 | 8 | hydra: 9 | output_subdir: null 10 | run: 11 | dir: . 
12 | 13 | lightning: 14 | seed: 1 15 | 16 | project: 17 | name: CrossScore 18 | 19 | alias: "" 20 | 21 | trainer: 22 | accelerator: gpu 23 | devices: [0] 24 | # devices: [0, 1] 25 | precision: 16-mixed 26 | 27 | limit_test_batches: 1.0 28 | ckpt_path_to_load: null 29 | 30 | logger: 31 | predict: 32 | out_dir: null # if null, use ckpt dir 33 | write: 34 | flag: 35 | batch: True 36 | score_map_prediction: True 37 | item_path_json: False 38 | score_map_gt: False 39 | attn_weights: False 40 | image_query: True 41 | image_reference: True 42 | config: 43 | vis_img_every_n_steps: 1 # -1: off 44 | score_map_colour_mode: rgb # gray or rgb, use rgb for vis 45 | 46 | this_main: 47 | resize_short_side: 518 # set to -1 to disable 48 | crop_mode: null # default no crop 49 | 50 | force_batch_size: False 51 | -------------------------------------------------------------------------------- /config/default_test.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ # overriding order, see https://hydra.cc/docs/tutorials/structured_config/defaults/#a-note-about-composition-order 3 | - data: combined_training 4 | - model: model 5 | - override hydra/hydra_logging: disabled 6 | - override hydra/job_logging: disabled 7 | 8 | hydra: 9 | output_subdir: null 10 | run: 11 | dir: . 12 | 13 | lightning: 14 | seed: 1 15 | 16 | project: 17 | name: CrossScore 18 | 19 | alias: "" 20 | 21 | trainer: 22 | accelerator: gpu 23 | devices: [0] 24 | # devices: [0, 1] 25 | precision: 16-mixed 26 | 27 | limit_test_batches: 1.0 28 | ckpt_path_to_load: null 29 | 30 | logger: 31 | test: 32 | out_dir: null # if null, use ckpt dir 33 | sync_dist: True 34 | on_step: False # def False 35 | write: 36 | flag: 37 | batch: True 38 | score_map_prediction: True 39 | item_path_json: True 40 | score_map_gt: False 41 | attn_weights: False 42 | image_query: False 43 | image_reference: False 44 | config: 45 | vis_img_every_n_steps: 1 # -1: off 46 | score_map_colour_mode: gray # def gray, use rgb for vis 47 | 48 | this_main: 49 | resize_short_side: 518 # set to -1 to disable 50 | crop_mode: integer_patches # default, crop imagess to the closest patchifiable size 51 | # crop_mode: null # no crop at all 52 | # crop_mode: dataset_default # crop using data config 53 | 54 | force_batch_size: False 55 | data_split: test # def test 56 | -------------------------------------------------------------------------------- /config/model/model.yaml: -------------------------------------------------------------------------------- 1 | patch_size: 14 2 | do_reference_cross: True 3 | 4 | decoder_do_self_attn: True 5 | decoder_do_short_cut: True 6 | need_attn_weights: False # def False, requires more gpu mem if True 7 | need_attn_weights_head_id: 0 # check which attn head 8 | 9 | backbone: 10 | from_pretrained: facebook/dinov2-small 11 | 12 | pos_enc: 13 | multi_view: 14 | interpolate_mode: bilinear 15 | req_grad: False 16 | h: 40 # def 40 so we always interpolate in training, could be 37 or 16 too. 
17 | w: 40 18 | 19 | loss: 20 | fn: l1 21 | 22 | predict: 23 | metric: 24 | type: ssim 25 | # type: mae 26 | # type: mse 27 | 28 | # min: -1 29 | min: 0 30 | max: 1 31 | 32 | power_factor: default # can be a scalar -------------------------------------------------------------------------------- /dataloading/data_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import OmegaConf, ListConfig 3 | from dataloading.dataset.nvs_dataset import NvsDataset 4 | from pprint import pprint 5 | 6 | 7 | def get_dataset(cfg, transforms, data_split, return_item_paths=False): 8 | if isinstance(cfg.data.dataset.path, str): 9 | dataset_path_list = [cfg.data.dataset.path] 10 | elif isinstance(cfg.data.dataset.path, ListConfig): 11 | dataset_path_list = OmegaConf.to_object(cfg.data.dataset.path) 12 | else: 13 | raise ValueError("cfg.data.dataset.path should be a string or a ListConfig") 14 | 15 | N_dataset = len(dataset_path_list) 16 | print(f"Get {N_dataset} datasets for {data_split}:") 17 | pprint(f"cfg.data.dataset.path: {dataset_path_list}") 18 | print("==================================") 19 | 20 | dataset_list = [] 21 | for i in range(N_dataset): 22 | dataset_list.append( 23 | NvsDataset( 24 | dataset_path=dataset_path_list[i], 25 | resolution=cfg.data.dataset.resolution, 26 | data_split=data_split, 27 | transforms=transforms, 28 | neighbour_config=cfg.data.neighbour_config, 29 | metric_type=cfg.model.predict.metric.type, 30 | metric_min=cfg.model.predict.metric.min, 31 | metric_max=cfg.model.predict.metric.max, 32 | return_item_paths=return_item_paths, 33 | num_gaussians_iters=cfg.data.dataset.num_gaussians_iters, 34 | zero_reference=cfg.data.dataset.zero_reference, 35 | ) 36 | ) 37 | if N_dataset == 1: 38 | dataset = dataset_list[0] 39 | else: 40 | dataset = torch.utils.data.ConcatDataset(dataset_list) 41 | return dataset 42 | -------------------------------------------------------------------------------- /dataloading/dataset/nvs_dataset.py: -------------------------------------------------------------------------------- 1 | import os, sys, json 2 | from pathlib import Path 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from omegaconf import OmegaConf 7 | 8 | sys.path.append(str(Path(__file__).parents[2])) 9 | from utils.io.images import metric_map_read, image_read 10 | from utils.neighbour.sampler import SamplerFactory 11 | from utils.check_config import ConfigChecker 12 | 13 | 14 | class NeighbourSelector: 15 | """Return paths for neighbouring images (and metric maps) for query and reference.""" 16 | 17 | def __init__(self, paths, neighbour_config): 18 | self.paths = paths 19 | self.neighbour_config = neighbour_config 20 | 21 | self.idx_to_property_mapper = self._build_idx_to_property_mapper(self.paths) 22 | self.all_scene_names = np.array(sorted(self.paths.keys())) 23 | 24 | self.neighbour_sampler_ref_cross = None 25 | if self.neighbour_config["cross"] > 0: 26 | # since query is not in cross ref set, we can only do random sampling 27 | self.neighbour_sampler_ref_cross = SamplerFactory( 28 | strategy_name="random", 29 | N_sample=self.neighbour_config["cross"], 30 | include_query=False, 31 | deterministic=self.neighbour_config["deterministic"], 32 | ) 33 | 34 | @staticmethod 35 | def _build_idx_to_property_mapper(paths): 36 | """Only consider query, since the dataloading idx is based on number of query images""" 37 | 38 | scene_name_list = sorted(paths.keys()) 39 | gaussian_split_list = 
["train", "test"] 40 | global_idx = 0 41 | 42 | idx_to_property_mapper = {} 43 | for sn in scene_name_list: 44 | for gs_split in gaussian_split_list: 45 | if f"gs_{gs_split}" not in paths[sn].keys(): 46 | continue 47 | n_iter = paths[sn][f"gs_{gs_split}"]["query"]["N_iters"] 48 | n_imgs_per_iter = paths[sn][f"gs_{gs_split}"]["query"]["N_imgs_per_iter"] 49 | n_imgs = n_iter * n_imgs_per_iter 50 | for idx in range(n_imgs): 51 | idx_to_property_mapper[global_idx] = { 52 | "scene_name": sn, 53 | "gaussian_split": gs_split, 54 | "iter_idx": idx // n_imgs_per_iter, 55 | "img_idx": idx % n_imgs_per_iter, 56 | } 57 | global_idx += 1 58 | return idx_to_property_mapper 59 | 60 | def __len__(self): 61 | return len(self.idx_to_property_mapper) 62 | 63 | def __getitem__(self, idx): 64 | results = { 65 | "query/img": None, 66 | "query/score_map": None, 67 | "reference/cross/imgs": [], 68 | } 69 | 70 | scene_name, gaussian_split, iter_idx, img_idx = self.idx_to_property_mapper[idx].values() 71 | tmp_split_paths = self.paths[scene_name][f"gs_{gaussian_split}"] 72 | iter_name = list(tmp_split_paths["query"]["images"].keys())[iter_idx] 73 | 74 | results["query/img"] = tmp_split_paths["query"]["images"][iter_name][img_idx] 75 | results["query/score_map"] = tmp_split_paths["query"]["score_map"][iter_name][img_idx] 76 | 77 | # sampling reference images for cross set 78 | if self.neighbour_sampler_ref_cross is not None: 79 | ref_list_cross = tmp_split_paths["reference"]["cross"]["images"][iter_name] 80 | results["reference/cross/imgs"] = self.neighbour_sampler_ref_cross( 81 | query=None, ref_list=ref_list_cross 82 | ) 83 | 84 | return results 85 | 86 | 87 | class NvsDataset(Dataset): 88 | 89 | def __init__( 90 | self, 91 | dataset_path, 92 | resolution, 93 | data_split, 94 | transforms, 95 | neighbour_config, 96 | metric_type, 97 | metric_min, 98 | metric_max, 99 | return_debug_info=False, 100 | return_item_paths=False, 101 | **kwargs, 102 | ): 103 | """ 104 | :param scene_path: Gaussian Splatting output scene dir that contains point cloud, test, train etc. 105 | :param query_split: train or test 106 | :param transforms: a dict of transforms for all, img, metric_map 107 | """ 108 | self.transforms = transforms 109 | self.neighbour_config = neighbour_config 110 | self.return_debug_info = return_debug_info 111 | self.return_item_paths = return_item_paths 112 | self.zero_reference = kwargs.get("zero_reference", False) 113 | self.num_gaussians_iters = kwargs.get("num_gaussians_iters", -1) 114 | 115 | if data_split not in ["train", "test", "val", "val_small", "test_small"]: 116 | raise ValueError(f"Unknown data_split {data_split}") 117 | 118 | self._detect_conflict_transforms() 119 | self.metric_config = self._build_metric_config(metric_type, metric_min, metric_max) 120 | 121 | # read split json for scene names 122 | if resolution is None: 123 | resolution = os.listdir(dataset_path)[0] 124 | self.dataset_path = Path(dataset_path, resolution) 125 | with open(self.dataset_path / "split.json", "r") as f: 126 | scene_names = json.load(f)[data_split] 127 | scene_paths = [self.dataset_path / n for n in sorted(scene_names)] 128 | 129 | # We use same split for all processed methods (e.g. gaussian, nerfacto, etc.). 130 | # Some scenes may not be processed by some methods, so we need to filter out. 131 | scene_paths = [p for p in scene_paths if p.exists()] 132 | 133 | # Define query and ref sets. Get all paths points to images and metric maps. 
134 | self.all_paths = self.get_paths( 135 | scene_paths, self.num_gaussians_iters, self.metric_config.load_dir 136 | ) 137 | 138 | self.neighbour_selector = NeighbourSelector( 139 | self.all_paths, 140 | self.neighbour_config, 141 | ) 142 | 143 | def __getitem__(self, idx): 144 | # neighouring logic 145 | item_paths = self.neighbour_selector[idx] 146 | 147 | # load content from related paths 148 | result = self.load_content(item_paths, self.zero_reference, self.metric_config) 149 | 150 | if "resize" in self.transforms: 151 | result = self.resize_all(result) 152 | 153 | if "crop_integer_patches" in self.transforms: 154 | result = self.adaptive_crop_integer_patches_all(result) 155 | 156 | if self.return_debug_info: 157 | result["debug"] = { 158 | "query/ori_img": result["query/img"], 159 | "query/ori_score_map": result["query/score_map"], 160 | "reference/cross/ori_imgs": result["reference/cross/imgs"], 161 | } 162 | 163 | if self.return_item_paths: 164 | result["item_paths"] = item_paths 165 | 166 | # apply transforms to query 167 | transformed_query = self.transform_query( 168 | result["query/img"], 169 | result["query/score_map"], 170 | ) 171 | result["query/img"] = transformed_query["img"] 172 | result["query/score_map"] = transformed_query["score_map"] 173 | if self.return_debug_info: 174 | result["debug"]["query/crop_param"] = transformed_query["crop_param"] 175 | 176 | if self.neighbour_config["cross"] > 0: 177 | transformed_ref_cross = self.transform_reference(result["reference/cross/imgs"]) 178 | result["reference/cross/imgs"] = transformed_ref_cross["imgs"] 179 | if self.return_debug_info: 180 | result["debug"]["reference/cross/crop_param"] = transformed_ref_cross["crop_param"] 181 | else: 182 | del result["reference/cross/imgs"] 183 | return result 184 | 185 | @staticmethod 186 | def collate_fn_debug(batch): 187 | """Only return the first item in the batch, because the original images 188 | before cropping are in different sizes. 189 | Using [None] to add a batch dimension at the front. 190 | """ 191 | result = { 192 | "query/img": batch[0]["query/img"][None], 193 | "query/score_map": batch[0]["query/score_map"][None], 194 | } 195 | 196 | result["debug"] = { 197 | "query/ori_img": batch[0]["debug"]["query/ori_img"][None], 198 | "query/ori_score_map": batch[0]["debug"]["query/ori_score_map"][None], 199 | "query/crop_param": batch[0]["debug"]["query/crop_param"][None], 200 | } 201 | 202 | result["item_paths"] = batch[0]["item_paths"] 203 | 204 | if "reference/cross/imgs" in batch[0].keys(): 205 | result["reference/cross/imgs"] = batch[0]["reference/cross/imgs"][None] 206 | result["debug"]["reference/cross/ori_imgs"] = batch[0]["debug"][ 207 | "reference/cross/ori_imgs" 208 | ][None] 209 | result["debug"]["reference/cross/crop_param"] = batch[0]["debug"][ 210 | "reference/cross/crop_param" 211 | ][None] 212 | 213 | return result 214 | 215 | def __len__(self): 216 | return len(self.neighbour_selector) 217 | 218 | def resize_all(self, results): 219 | results["query/img"] = self.transforms["resize"](results["query/img"]) 220 | results["query/score_map"] = self.transforms["resize"](results["query/score_map"][None])[0] 221 | if "reference/cross/imgs" in results.keys(): 222 | results["reference/cross/imgs"] = self.transforms["resize"]( 223 | results["reference/cross/imgs"] 224 | ) 225 | return results 226 | 227 | def adaptive_crop_integer_patches_all(self, results): 228 | """ 229 | Adaptively crop all images to the closest integer patch size. 
230 | This is needed for test_steps, where we need to compute loss on images in arbitrary sizes. 231 | """ 232 | P = 14 # dinov2 patch size 233 | ori_h, ori_w = results["query/img"].shape[-2:] 234 | new_h = ori_h - ori_h % P 235 | new_w = ori_w - ori_w % P 236 | results["query/img"] = results["query/img"][:, :new_h, :new_w] 237 | results["query/score_map"] = results["query/score_map"][:new_h, :new_w] 238 | if len(results["reference/cross/imgs"]) > 0: 239 | results["reference/cross/imgs"] = results["reference/cross/imgs"][:, :, :new_h, :new_w] 240 | return results 241 | 242 | def transform_query(self, img, score_map): 243 | if self.transforms.get("query_crop", None) is not None: 244 | crop_results = self.transforms["query_crop"](img, score_map) 245 | img = crop_results["out"][0] 246 | score_map = crop_results["out"][1] 247 | crop_param = crop_results["crop_param"] 248 | else: 249 | crop_param = torch.tensor([0, 0, *img.shape[-2:]]) # (4) 250 | 251 | if self.transforms.get("img", None) is not None: 252 | img = self.transforms["img"](img) 253 | 254 | if self.transforms.get("metric_map", None) is not None: 255 | score_map = self.transforms["metric_map"](score_map[None, None]) 256 | score_map = score_map[0, 0] 257 | 258 | return { 259 | "img": img, 260 | "score_map": score_map, 261 | "crop_param": crop_param, 262 | } 263 | 264 | def transform_reference(self, imgs): 265 | if self.transforms.get("reference_crop", None) is not None: 266 | crop_results = self.transforms["reference_crop"](imgs) 267 | imgs = crop_results["out"] 268 | crop_param = crop_results["crop_param"] 269 | else: 270 | crop_param = torch.stack( 271 | [torch.tensor([0, 0, *img.shape[-2:]]) for img in imgs] 272 | ) # (B, 4) 273 | 274 | if self.transforms.get("img", None) is not None: 275 | imgs = self.transforms["img"](imgs) 276 | return { 277 | "imgs": imgs, 278 | "crop_param": crop_param, 279 | } 280 | 281 | def _detect_conflict_transforms(self): 282 | if "resize" in self.transforms: 283 | crop_sizes = [] 284 | if "query_crop" in self.transforms: 285 | crop_sizes.append(self.transforms["query_crop"].output_size) 286 | if "reference_crop" in self.transforms: 287 | crop_sizes.append(self.transforms["reference_crop"].output_size) 288 | 289 | if len(crop_sizes) > 0: 290 | max_crop_size = np.max(crop_sizes) 291 | min_resize_size = np.min(self.transforms["resize"].size) 292 | if min_resize_size < max_crop_size: 293 | raise ValueError( 294 | f"Required to resize image before crop, " 295 | f"but min_resize_size {min_resize_size} " 296 | f"< max_crop_size {max_crop_size}" 297 | ) 298 | 299 | def _build_metric_config(self, metric_type, metric_min, metric_max): 300 | """ 301 | Convert predict_type to load_dir and vrange for metric map reading. 302 | Supported predict_types: ssim_0_1, ssim_-1_1, mse, mae 303 | """ 304 | vrange = [metric_min, metric_max] 305 | 306 | if metric_type in ["ssim", "mae"]: 307 | load_dir = f"metric_map/{metric_type}" 308 | elif metric_type in ["mse"]: 309 | load_dir = "metric_map/mae" # mse can be derived from mae 310 | else: 311 | raise ValueError(f"Invalid metric type {metric_type}") 312 | 313 | cfg = { 314 | "type": metric_type, 315 | "vrange": vrange, 316 | "load_dir": load_dir, 317 | } 318 | cfg = OmegaConf.create(cfg) 319 | return cfg 320 | 321 | @staticmethod 322 | def get_paths(scene_paths, num_gaussians_iters, metric_load_dir): 323 | """Get paths points to images and metric maps. Define query and refenence sets. 324 | Naming convention: 325 | Query: 326 | The (noisy) image we want to measure. 
327 | Reference: 328 | Images for making predictions. 329 | For cross ref, we consider captured images that from ref splits. 330 | Example: 331 | When query_split is "train", we consider captured test images as cross ref. 332 | When query_split is "test", we consider captured training images as cross ref. 333 | """ 334 | scene_name_list = sorted([scene_path.name for scene_path in scene_paths]) 335 | all_paths = { 336 | scene_name: { 337 | "train": { 338 | "renders": {}, 339 | "gt": {}, 340 | "score_map": {}, 341 | }, 342 | "test": { 343 | "renders": {}, 344 | "gt": {}, 345 | "score_map": {}, 346 | }, 347 | } 348 | for scene_name in scene_name_list 349 | } 350 | 351 | for scene_path in scene_paths: 352 | scene_name = scene_path.name 353 | for gs_split in all_paths[scene_name].keys(): 354 | dir_split = Path(scene_path, gs_split) 355 | dir_iter_list = sorted(os.listdir(dir_split), key=lambda x: int(x.split("_")[-1])) 356 | dir_iter_list = [Path(dir_split, d) for d in dir_iter_list] 357 | 358 | # This dataset contains images rendered from gaussian splatting checkpoints 359 | # at different iterations. Use this to use images from earlier checkpoints, 360 | # which have more artefacts. 361 | if num_gaussians_iters > 0: 362 | dir_iter_list = dir_iter_list[:num_gaussians_iters] 363 | 364 | for dir_iter in dir_iter_list: 365 | iter_num = int(dir_iter.name.split("_")[-1]) 366 | for img_type in all_paths[scene_name][gs_split].keys(): 367 | if img_type in ["renders", "gt"]: 368 | img_dir = dir_iter / img_type 369 | elif img_type == "score_map": 370 | img_dir = dir_iter / metric_load_dir 371 | else: 372 | raise ValueError(f"Unknown img_type {img_type}") 373 | 374 | if os.path.exists(img_dir): 375 | img_names = sorted(os.listdir(img_dir)) 376 | paths = [img_dir / img_n for img_n in img_names] 377 | paths = [str(p) for p in paths] 378 | else: 379 | # if no metric map available, use a placeholder "empty_image" 380 | paths = ["empty_image"] * len(all_paths[scene_name][gs_split]["gt"]) 381 | all_paths[scene_name][gs_split][img_type][iter_num] = paths 382 | 383 | # all types of items should have the same item number as gt 384 | for img_type in all_paths[scene_name][gs_split].keys(): 385 | for iter_num in all_paths[scene_name][gs_split][img_type].keys(): 386 | N_imgs = len(all_paths[scene_name][gs_split][img_type][iter_num]) 387 | N_gt = len(all_paths[scene_name][gs_split]["gt"][iter_num]) 388 | if N_imgs != N_gt: 389 | raise ValueError( 390 | f"Number of items mismatch in " 391 | f"{scene_name}/{gs_split}/{iter_num}/{img_type}" 392 | ) 393 | 394 | # assemble query and reference sets 395 | def get_cross_ref_split(query_split): 396 | all_splits = ["train", "test"] 397 | all_splits.remove(query_split) 398 | cross_ref_split = all_splits[0] 399 | return cross_ref_split 400 | 401 | results = {} 402 | for scene_name in scene_name_list: 403 | results[scene_name] = {} 404 | for gs_split in ["train", "test"]: 405 | cross_ref_split = get_cross_ref_split(gs_split) 406 | 407 | results[scene_name][f"gs_{gs_split}"] = { 408 | "query": { 409 | "images": all_paths[scene_name][gs_split]["renders"], 410 | "score_map": all_paths[scene_name][gs_split]["score_map"], 411 | "N_iters": len(all_paths[scene_name][gs_split]["renders"]), 412 | "N_imgs_per_iter": len( 413 | list(all_paths[scene_name][gs_split]["renders"].values())[0] 414 | ), 415 | }, 416 | "reference": { 417 | "cross": { 418 | "images": all_paths[scene_name][cross_ref_split]["gt"], 419 | "N_iters": len(all_paths[scene_name][cross_ref_split]["gt"]), 420 | 
"N_imgs_per_iter": len( 421 | list(all_paths[scene_name][cross_ref_split]["gt"].values())[0] 422 | ), 423 | }, 424 | }, 425 | } 426 | return results 427 | 428 | @staticmethod 429 | def load_content(item_paths, zero_reference, metric_config): 430 | results = { 431 | "query/img": None, 432 | "query/score_map": None, 433 | "reference/cross/imgs": [], 434 | } 435 | 436 | for k in item_paths.keys(): 437 | if k == "query/img": 438 | results[k] = torch.tensor(image_read(item_paths[k])).permute(2, 0, 1) # (3, H, W) 439 | elif k == "query/score_map": 440 | if metric_config.type == "ssim": 441 | if item_paths[k] == "empty_image": 442 | results[k] = torch.zeros_like(results["query/img"][0]) # (H, W) 443 | else: 444 | # (H, W), always read SSIM in range [-1, 1] 445 | results[k] = torch.tensor(metric_map_read(item_paths[k], vrange=[-1, 1])) 446 | if metric_config.vrange == [0, 1]: 447 | results[k] = results[k].clamp(0, 1) 448 | elif metric_config.type in ["mse", "mae"]: 449 | if item_paths[k] == "empty_image": 450 | results[k] = torch.full_like(results["query/img"][0], torch.nan) # (H, W) 451 | else: 452 | # (H, W), always read MAE in range [0, 1] 453 | results[k] = torch.tensor(metric_map_read(item_paths[k], vrange=[0, 1])) 454 | if metric_config.type == "mse": 455 | results[k] = results[k].square() # create mse from loaded mae 456 | elif metric_config.type is None: 457 | # for SimpleReference, which doesn't need to load score maps 458 | results[k] = torch.zeros_like(results["query/img"][0]) 459 | elif k in ["reference/cross/imgs"]: 460 | if len(item_paths[k]) == 0: 461 | continue # if not loading this reference set, skip 462 | for path in item_paths[k]: 463 | if path == "empty_image": 464 | # NOTE: this assumes ref_img_size == query_img_size 465 | tmp_img = torch.zeros_like(results["query/img"]) # (3, H, W) 466 | else: 467 | tmp_img = torch.tensor(image_read(path)).permute(2, 0, 1) # (3, H, W) 468 | results[k].append(tmp_img) # (3, H, W) 469 | results[k] = torch.stack(results[k], dim=0) # (N, 3, H, W) 470 | if zero_reference: 471 | results[k] = torch.zeros_like(results[k]) 472 | else: 473 | raise ValueError(f"Unknown key {k}") 474 | return results 475 | 476 | 477 | def vis_batch(cfg, batch, metric_type, metric_min, metric_max, e, b): 478 | import matplotlib.pyplot as plt 479 | from matplotlib.patches import Rectangle 480 | 481 | metric_vrange = [metric_min, metric_max] 482 | 483 | # mkdir to save figures 484 | save_fig_dir = Path(cfg.this_main.save_fig_dir).expanduser() 485 | save_fig_dir.mkdir(parents=True, exist_ok=True) 486 | 487 | # Vis batch[0] in two figures: one actual loaded and one with debug info 488 | # First figure with actual loaded data 489 | max_cols = max([3, cfg.data.neighbour_config.cross]) 490 | _, axes = plt.subplots(2, max_cols, figsize=(15, 9)) 491 | for ax in axes.flatten(): 492 | ax.set_axis_off() 493 | 494 | # first row: query 495 | axes[0][0].imshow(batch["query/img"][0].permute(1, 2, 0).clip(0, 1)) 496 | axes[0][0].set_title("query/img") 497 | axes[0][1].imshow( 498 | batch["query/score_map"][0], 499 | vmin=metric_vrange[0], 500 | vmax=metric_vrange[1], 501 | cmap="turbo", 502 | ) 503 | axes[0][1].set_title(f"query/{metric_type}_map") 504 | for i in range(2): 505 | axes[0][i].set_axis_on() 506 | 507 | # second row: cross ref 508 | if "reference/cross/imgs" in batch.keys(): 509 | for i in range(batch["reference/cross/imgs"].shape[1]): 510 | axes[1][i].imshow(batch["reference/cross/imgs"][0, i].permute(1, 2, 0).clip(0, 1)) 511 | 
axes[1][i].set_title(f"reference/cross/imgs_{i}") 512 | axes[1][i].set_axis_on() 513 | 514 | plt.tight_layout() 515 | plt.savefig(save_fig_dir / f"e{e}b{b}.jpg") 516 | plt.close() 517 | 518 | # Second figure with full debug details in three rows 519 | if cfg.this_main.return_debug_info: 520 | _, axes = plt.subplots(2, max_cols, figsize=(20, 10)) 521 | for ax in axes.flatten(): 522 | ax.set_axis_off() 523 | 524 | # first row: query 525 | axes[0][0].imshow(batch["debug"]["query/ori_img"][0].permute(1, 2, 0).clip(0, 1)) 526 | axes[0][0].set_title("query/ori_img") 527 | axes[0][1].imshow( 528 | batch["debug"][f"query/ori_score_map"][0], 529 | vmin=metric_vrange[0], 530 | vmax=metric_vrange[1], 531 | cmap="turbo", 532 | ) 533 | axes[0][1].set_title(f"query/ori_{metric_type}_map") 534 | # crop box 535 | crop_param = batch["debug"]["query/crop_param"][0] 536 | for i in range(2): 537 | rect = Rectangle( 538 | (crop_param[1], crop_param[0]), 539 | crop_param[3], 540 | crop_param[2], 541 | linewidth=3, 542 | edgecolor="r", 543 | facecolor="none", 544 | ) 545 | axes[0][i].add_patch(rect) 546 | axes[0][i].set_axis_on() 547 | 548 | # third row: cross ref 549 | if "reference/cross/imgs" in batch.keys(): 550 | for i in range(batch["debug"]["reference/cross/ori_imgs"].shape[1]): 551 | axes[1][i].imshow( 552 | batch["debug"]["reference/cross/ori_imgs"][0, i].permute(1, 2, 0).clip(0, 1) 553 | ) 554 | axes[1][i].set_title(f"reference/cross/ori_imgs_{i}") 555 | # crop box 556 | crop_param = batch["debug"]["reference/cross/crop_param"][0][i] 557 | rect = Rectangle( 558 | (crop_param[1], crop_param[0]), 559 | crop_param[3], 560 | crop_param[2], 561 | linewidth=3, 562 | edgecolor="r", 563 | facecolor="none", 564 | ) 565 | axes[1][i].add_patch(rect) 566 | axes[1][i].set_axis_on() 567 | 568 | plt.tight_layout() 569 | plt.savefig(save_fig_dir / f"e{e}b{b}_full.jpg") 570 | plt.close() 571 | 572 | 573 | if __name__ == "__main__": 574 | from lightning import seed_everything 575 | from torchvision.transforms import v2 as T 576 | from tqdm import tqdm 577 | from dataloading.transformation.crop import CropperFactory 578 | from utils.io.images import ImageNetMeanStd 579 | from omegaconf import OmegaConf 580 | 581 | seed_everything(1) 582 | 583 | cfg = { 584 | "data": { 585 | "dataset": { 586 | "path": "datadir/processed_training_ready/gaussian/map-free-reloc", 587 | "resolution": "res_540", 588 | "num_gaussians_iters": 1, 589 | "zero_reference": False, 590 | }, 591 | "loader": { 592 | "batch_size": 8, 593 | "num_workers": 4, 594 | "shuffle": True, 595 | "data_split": "train", 596 | "pin_memory": True, 597 | "persistent_workers": True, 598 | }, 599 | "transforms": { 600 | "crop_size": 518, 601 | }, 602 | "neighbour_config": { 603 | "strategy": "random", 604 | "cross": 5, 605 | "deterministic": False, 606 | }, 607 | }, 608 | "this_main": { 609 | "skip_vis": False, 610 | "save_fig_dir": "./debug/dataset/NvsData", 611 | "epochs": 3, 612 | "skip_batches": 5, 613 | "deterministic_crop": False, 614 | "return_debug_info": True, 615 | "return_item_paths": True, 616 | "resize_short_side": -1, # -1 to disable 617 | "crop_mode": "dataset_default", 618 | # "crop_mode": None, 619 | }, 620 | "model": { 621 | "patch_size": 14, 622 | "predict": { 623 | "metric": { 624 | "type": "ssim", 625 | "min": 0, 626 | "max": 1, 627 | }, 628 | }, 629 | }, 630 | } 631 | cfg = OmegaConf.create(cfg) 632 | 633 | # Overwrite cfg in some conditions 634 | if cfg.data.dataset.resolution == "res_540": 635 | cfg.data.transforms.crop_size = 518 636 | elif 
cfg.data.dataset.resolution == "res_400": 637 | cfg.data.transforms.crop_size = 392 638 | elif cfg.data.dataset.resolution == "res_200": 639 | cfg.data.transforms.crop_size = 196 640 | else: 641 | raise ValueError("Unknown resolution") 642 | 643 | if cfg.data.loader.num_workers == 0: 644 | cfg.data.loader.persistent_workers = False 645 | 646 | # Check config 647 | ConfigChecker(cfg).check_dataset() 648 | 649 | # Init transforms for dataset 650 | img_norm_stat = ImageNetMeanStd() 651 | transforms = { 652 | "img": T.Normalize(mean=img_norm_stat.mean, std=img_norm_stat.std), 653 | } 654 | 655 | if cfg.this_main.crop_mode == "dataset_default": 656 | transforms["query_crop"] = CropperFactory( 657 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 658 | same_on_batch=True, 659 | deterministic=cfg.this_main.deterministic_crop, 660 | ) 661 | transforms["reference_crop"] = CropperFactory( 662 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 663 | same_on_batch=False, 664 | deterministic=cfg.this_main.deterministic_crop, 665 | ) 666 | 667 | if cfg.this_main.resize_short_side > 0: 668 | transforms["resize"] = T.Resize( 669 | cfg.this_main.resize_short_side, 670 | interpolation=T.InterpolationMode.BILINEAR, 671 | antialias=True, 672 | ) 673 | 674 | # Init dataset and dataloader 675 | dataset = NvsDataset( 676 | dataset_path=cfg.data.dataset.path, 677 | resolution=cfg.data.dataset.resolution, 678 | data_split=cfg.data.loader.data_split, 679 | transforms=transforms, 680 | neighbour_config=cfg.data.neighbour_config, 681 | metric_type=cfg.model.predict.metric.type, 682 | metric_min=cfg.model.predict.metric.min, 683 | metric_max=cfg.model.predict.metric.max, 684 | return_debug_info=cfg.this_main.return_debug_info, 685 | return_item_paths=cfg.this_main.return_item_paths, 686 | num_gaussians_iters=cfg.data.dataset.num_gaussians_iters, 687 | zero_reference=cfg.data.dataset.zero_reference, 688 | ) 689 | 690 | dataloader = torch.utils.data.DataLoader( 691 | dataset, 692 | batch_size=cfg.data.loader.batch_size, 693 | shuffle=cfg.data.loader.shuffle, 694 | num_workers=cfg.data.loader.num_workers, 695 | pin_memory=cfg.data.loader.pin_memory, 696 | persistent_workers=cfg.data.loader.persistent_workers, 697 | collate_fn=dataset.collate_fn_debug if cfg.this_main.return_debug_info else None, 698 | ) 699 | 700 | # Actual looping dataset 701 | for e in tqdm(range(cfg.this_main.epochs), desc="Epoch", dynamic_ncols=True): 702 | for b, batch in enumerate(tqdm(dataloader, desc="Batch", dynamic_ncols=True)): 703 | 704 | if cfg.this_main.skip_batches > 0 and b >= cfg.this_main.skip_batches: 705 | break 706 | 707 | if cfg.this_main.skip_vis: 708 | continue 709 | 710 | vis_batch( 711 | cfg, 712 | batch, 713 | metric_type=cfg.model.predict.metric.type, 714 | metric_min=cfg.model.predict.metric.min, 715 | metric_max=cfg.model.predict.metric.max, 716 | e=e, 717 | b=b, 718 | ) 719 | -------------------------------------------------------------------------------- /dataloading/dataset/simple_reference.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from pathlib import Path 3 | import torch 4 | from omegaconf import OmegaConf 5 | 6 | sys.path.append(str(Path(__file__).parents[2])) 7 | from dataloading.dataset.nvs_dataset import NvsDataset, NeighbourSelector, vis_batch 8 | 9 | 10 | class SimpleReference(NvsDataset): 11 | def __init__( 12 | self, 13 | query_dir, 14 | reference_dir, 15 | transforms, 16 | neighbour_config, 
17 | return_debug_info=False, 18 | return_item_paths=False, 19 | **kwargs, 20 | ): 21 | self.transforms = transforms 22 | self.neighbour_config = neighbour_config 23 | self.return_debug_info = return_debug_info 24 | self.return_item_paths = return_item_paths 25 | self.zero_reference = kwargs.get("zero_reference", False) 26 | 27 | self._detect_conflict_transforms() 28 | self.metric_config = self._build_empty_metric_config() 29 | 30 | self.all_paths = self.get_paths(query_dir, reference_dir) 31 | self.neighbour_selector = NeighbourSelector(self.all_paths, self.neighbour_config) 32 | 33 | def _build_empty_metric_config(self): 34 | cfg = { 35 | "type": None, 36 | "vrange": None, 37 | "load_dir": None, 38 | } 39 | cfg = OmegaConf.create(cfg) 40 | return cfg 41 | 42 | @staticmethod 43 | def get_paths(query_dir, reference_dir): 44 | """Define query and reference paths for ONE scene. 45 | This function is written in a way that mimics the 46 | NvsDataset.get_paths(), so that we can reuse most NvsDataset methods. 47 | 48 | :param scene_name: str 49 | :param query_dir: str, a dir that contains query images 50 | :param reference_dir: str, a dir that contains reference images 51 | """ 52 | 53 | query_dir = os.path.expanduser(query_dir) 54 | reference_dir = os.path.expanduser(reference_dir) 55 | query_paths = [os.path.join(query_dir, p) for p in sorted(os.listdir(query_dir))] 56 | reference_paths = [ 57 | os.path.join(reference_dir, p) for p in sorted(os.listdir(reference_dir)) 58 | ] 59 | 60 | fake_iter = -1 61 | query = { 62 | "images": {fake_iter: query_paths}, 63 | "score_map": {fake_iter: ["empty_image"] * len(query_paths)}, 64 | "N_iters": 1, 65 | "N_imgs_per_iter": len(query_paths), 66 | } 67 | reference = { 68 | "cross": { 69 | "images": {fake_iter: reference_paths}, 70 | "N_iters": 1, 71 | "N_imgs_per_iter": len(reference_paths), 72 | } 73 | } 74 | 75 | # use query dir as scene name and anonymize the path 76 | scene_name = str(query_dir).replace(str(Path.home()), "~") 77 | results = { 78 | scene_name: { 79 | "gs_test": { 80 | "query": query, 81 | "reference": reference, 82 | }, 83 | } 84 | } 85 | return results 86 | 87 | 88 | if __name__ == "__main__": 89 | from lightning import seed_everything 90 | from torchvision.transforms import v2 as T 91 | from tqdm import tqdm 92 | from dataloading.transformation.crop import CropperFactory 93 | from utils.io.images import ImageNetMeanStd 94 | from omegaconf import OmegaConf 95 | 96 | seed_everything(1) 97 | cfg = { 98 | "data": { 99 | "dataset": { 100 | "query_dir": "datadir/processed_training_ready/gaussian/map-free-reloc/res_540/s00000/test/ours_1000/renders", 101 | "reference_dir": "datadir/processed_training_ready/gaussian/map-free-reloc/res_540/s00000/train/ours_1000/gt", 102 | "resolution": "res_540", 103 | "zero_reference": False, 104 | }, 105 | "loader": { 106 | "batch_size": 8, 107 | "num_workers": 4, 108 | "shuffle": True, 109 | "pin_memory": True, 110 | "persistent_workers": False, 111 | }, 112 | "transforms": { 113 | "crop_size": 392, 114 | }, 115 | "neighbour_config": { 116 | "strategy": "random", 117 | "cross": 5, 118 | "deterministic": False, 119 | }, 120 | }, 121 | "this_main": { 122 | "skip_vis": False, 123 | "save_fig_dir": "./debug/dataset/simple_reference", 124 | "epochs": 1, 125 | "skip_batches": -1, 126 | "deterministic_crop": False, 127 | "return_debug_info": True, 128 | "return_item_paths": True, 129 | "resize_short_side": 518, # -1 to disable 130 | # "crop_mode": "dataset_default", 131 | "crop_mode": None, 132 | }, 133 | 
"model": { 134 | "patch_size": 14, 135 | "do_reference_cross": True, 136 | "predict": { 137 | "metric": { 138 | "type": "ssim", 139 | "min": 0, 140 | "max": 1, 141 | }, 142 | }, 143 | }, 144 | } 145 | cfg = OmegaConf.create(cfg) 146 | 147 | # overwrite cfg in some conditions 148 | if cfg.data.dataset.resolution == "res_540": 149 | cfg.data.transforms.crop_size = 518 150 | elif cfg.data.dataset.resolution == "res_400": 151 | cfg.data.transforms.crop_size = 392 152 | elif cfg.data.dataset.resolution == "res_200": 153 | cfg.data.transforms.crop_size = 196 154 | else: 155 | raise ValueError("Unknown resolution") 156 | 157 | if cfg.data.loader.num_workers == 0: 158 | cfg.data.loader.persistent_workers = False 159 | 160 | # for dataloader 161 | img_norm_stat = ImageNetMeanStd() 162 | transforms = { 163 | "img": T.Normalize(mean=img_norm_stat.mean, std=img_norm_stat.std), 164 | } 165 | 166 | if cfg.this_main.crop_mode == "dataset_default": 167 | transforms["query_crop"] = CropperFactory( 168 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 169 | same_on_batch=True, 170 | deterministic=cfg.this_main.deterministic_crop, 171 | ) 172 | transforms["reference_crop"] = CropperFactory( 173 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 174 | same_on_batch=False, 175 | deterministic=cfg.this_main.deterministic_crop, 176 | ) 177 | 178 | if cfg.this_main.resize_short_side > 0: 179 | transforms["resize"] = T.Resize( 180 | cfg.this_main.resize_short_side, 181 | interpolation=T.InterpolationMode.BILINEAR, 182 | antialias=True, 183 | ) 184 | 185 | dataset = SimpleReference( 186 | query_dir=cfg.data.dataset.query_dir, 187 | reference_dir=cfg.data.dataset.reference_dir, 188 | transforms=transforms, 189 | neighbour_config=cfg.data.neighbour_config, 190 | return_debug_info=cfg.this_main.return_debug_info, 191 | return_item_paths=cfg.this_main.return_item_paths, 192 | zero_reference=cfg.data.dataset.zero_reference, 193 | ) 194 | 195 | dataloader = torch.utils.data.DataLoader( 196 | dataset, 197 | batch_size=cfg.data.loader.batch_size, 198 | shuffle=cfg.data.loader.shuffle, 199 | num_workers=cfg.data.loader.num_workers, 200 | pin_memory=cfg.data.loader.pin_memory, 201 | persistent_workers=cfg.data.loader.persistent_workers, 202 | collate_fn=dataset.collate_fn_debug if cfg.this_main.return_debug_info else None, 203 | ) 204 | 205 | # actual looping dataset 206 | for e in tqdm(range(cfg.this_main.epochs), desc="Epoch", dynamic_ncols=True): 207 | for b, batch in enumerate(tqdm(dataloader, desc="Batch", dynamic_ncols=True)): 208 | if cfg.this_main.skip_batches > 0 and b >= cfg.this_main.skip_batches: 209 | break 210 | 211 | if cfg.this_main.skip_vis: 212 | continue 213 | 214 | vis_batch( 215 | cfg, 216 | batch, 217 | metric_type=cfg.model.predict.metric.type, 218 | metric_min=cfg.model.predict.metric.min, 219 | metric_max=cfg.model.predict.metric.max, 220 | e=e, 221 | b=b, 222 | ) 223 | -------------------------------------------------------------------------------- /dataloading/transformation/crop.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import numpy as np 3 | import torch 4 | from torchvision.transforms import v2 as T 5 | 6 | 7 | def get_crop_params(input_size, output_size, deterministic): 8 | """Get random crop parameters for a given image and output size. 9 | Args: 10 | img: numpy array hwc 11 | output_size (tuple): Expected output size of the crop. 
12 | Returns: 13 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 14 | """ 15 | in_h, in_w = input_size 16 | out_h, out_w = output_size 17 | 18 | # i, j, h, w 19 | if deterministic: 20 | i, j = 0, 0 21 | else: 22 | i = np.random.randint(0, in_h - out_h + 1) 23 | j = np.random.randint(0, in_w - out_w + 1) 24 | return torch.tensor([i, j, out_h, out_w]) 25 | 26 | 27 | class Cropper(ABC): 28 | def __init__(self, output_size, deterministic=False): 29 | self.output_size = output_size 30 | self.deterministic = deterministic 31 | 32 | @abstractmethod 33 | def __call__(self, *args): 34 | raise NotImplementedError 35 | 36 | 37 | class RandomCropperBatchSeparate(Cropper): 38 | """For an input tensor, assuming it's batched, and apply **DIFF** crop params 39 | to each item in the batch. 40 | """ 41 | 42 | def __call__(self, imgs): 43 | # x: (B, C, H, W), (B, H, W) 44 | if imgs.ndim not in [3, 4]: 45 | raise ValueError("imgs.ndim must be one of [3, 4]") 46 | 47 | out_list = [] 48 | crop_param_list = [] 49 | for img in imgs: 50 | crop_param = get_crop_params(img.shape[-2:], self.output_size, self.deterministic) 51 | img = T.functional.crop(img, *crop_param) 52 | out_list.append(img) 53 | crop_param_list.append(crop_param) 54 | out_list = torch.stack(out_list) 55 | crop_param_list = torch.stack(crop_param_list) 56 | return { 57 | "out": out_list, # (B, C, H, W) or (B, H, W) 58 | "crop_param": crop_param_list, # (B, 4) 59 | } 60 | 61 | 62 | class RandomCropperBatchSame(Cropper): 63 | """For a list of input tensors, assuming they're batched, and apply **SAME** 64 | crop params to all. 65 | """ 66 | 67 | def __call__(self, *args): 68 | # use one set of crop params for all input 69 | crop_param = get_crop_params(args[0].shape[-2:], self.output_size, self.deterministic) 70 | out = [T.functional.crop(x, *crop_param) for x in args] 71 | return { 72 | "out": out, 73 | "crop_param": crop_param, 74 | } 75 | 76 | 77 | class CropperFactory: 78 | def __init__(self, output_size, same_on_batch, deterministic=False): 79 | self.output_size = output_size 80 | if same_on_batch: 81 | self.cropper = RandomCropperBatchSame(output_size, deterministic) 82 | else: 83 | self.cropper = RandomCropperBatchSeparate(output_size, deterministic) 84 | 85 | def __call__(self, *args): 86 | return self.cropper(*args) 87 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: CrossScore 2 | channels: 3 | - xformers 4 | - pytorch 5 | - huggingface 6 | - nvidia 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_gnu 12 | - abseil-cpp=20211102.0=h27087fc_1 13 | - accelerate=0.25.0=pyhd8ed1ab_0 14 | - aiohttp=3.9.1=py310h2372a71_0 15 | - aiosignal=1.3.1=pyhd8ed1ab_0 16 | - annotated-types=0.6.0=pyhd8ed1ab_0 17 | - anyio=3.7.1=pyhd8ed1ab_0 18 | - aom=3.7.1=h59595ed_0 19 | - arrow=1.3.0=pyhd8ed1ab_0 20 | - arrow-cpp=8.0.0=py310h3098874_0 21 | - asttokens=2.0.5=pyhd3eb1b0_0 22 | - async-timeout=4.0.3=pyhd8ed1ab_0 23 | - attrs=23.1.0=pyh71513ae_1 24 | - aws-c-common=0.4.57=he1b5a44_1 25 | - aws-c-event-stream=0.1.6=h72b8ae1_3 26 | - aws-checksums=0.1.9=h346380f_0 27 | - aws-sdk-cpp=1.8.185=hce553d0_0 28 | - backcall=0.2.0=pyhd3eb1b0_0 29 | - backoff=2.2.1=pyhd8ed1ab_0 30 | - beautifulsoup4=4.12.2=pyha770c72_0 31 | - blas=1.0=mkl 32 | - blessed=1.19.1=pyhe4f9e05_2 33 | - blosc=1.21.0=h4ff587b_1 34 | - boost-cpp=1.82.0=hdb19cb5_2 35 | 
- boto3=1.34.11=pyhd8ed1ab_0 36 | - botocore=1.34.11=pyhd8ed1ab_0 37 | - brotli=1.0.9=h5eee18b_7 38 | - brotli-bin=1.0.9=h5eee18b_7 39 | - brotli-python=1.0.9=py310h6a678d5_7 40 | - brunsli=0.1=h2531618_0 41 | - bzip2=1.0.8=h7b6447c_0 42 | - c-ares=1.23.0=hd590300_0 43 | - c-blosc2=2.10.2=hb4ffafa_0 44 | - ca-certificates=2024.7.2=h06a4308_0 45 | - cachecontrol=0.13.1=pyhd8ed1ab_0 46 | - cachecontrol-with-filecache=0.13.1=pyhd8ed1ab_0 47 | - cairo=1.18.0=h3faef2a_0 48 | - certifi=2024.8.30=py310h06a4308_0 49 | - cffi=1.16.0=py310h5eee18b_0 50 | - cfitsio=4.0.0=h9a35b8e_0 51 | - charls=2.2.0=h2531618_0 52 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 53 | - cleo=2.1.0=pyhd8ed1ab_0 54 | - click=8.1.7=py310h06a4308_0 55 | - colorama=0.4.6=pyhd8ed1ab_0 56 | - contourpy=1.2.0=py310hdb19cb5_0 57 | - crashtest=0.4.1=pyhd8ed1ab_0 58 | - croniter=1.4.1=pyhd8ed1ab_0 59 | - cryptography=41.0.3=py310h130f0dd_0 60 | - cuda-cudart=12.1.105=0 61 | - cuda-cupti=12.1.105=0 62 | - cuda-libraries=12.1.0=0 63 | - cuda-nvrtc=12.1.105=0 64 | - cuda-nvtx=12.1.105=0 65 | - cuda-opencl=12.3.101=0 66 | - cuda-runtime=12.1.0=0 67 | - cycler=0.11.0=pyhd3eb1b0_0 68 | - cyrus-sasl=2.1.28=h9c0eb46_1 69 | - dataclasses=0.8=pyh6d0b6a4_7 70 | - datasets=2.15.0=py_0 71 | - dateutils=0.6.12=py_0 72 | - dav1d=1.2.1=h5eee18b_0 73 | - dbus=1.13.18=hb2f20db_0 74 | - decorator=5.1.1=pyhd3eb1b0_0 75 | - deepdiff=6.7.1=pyhd8ed1ab_0 76 | - dill=0.3.7=pyhd8ed1ab_0 77 | - distlib=0.3.8=pyhd8ed1ab_0 78 | - dulwich=0.21.7=py310h2372a71_0 79 | - exceptiongroup=1.2.0=pyhd8ed1ab_0 80 | - executing=0.8.3=pyhd3eb1b0_0 81 | - expat=2.5.0=h6a678d5_0 82 | - fastapi=0.103.2=pyhd8ed1ab_0 83 | - ffmpeg=6.1.1=gpl_hf3b701a_101 84 | - ffmpeg-python=0.2.0=py_0 85 | - filelock=3.13.1=py310h06a4308_0 86 | - font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0 87 | - font-ttf-inconsolata=2.001=hcb22688_0 88 | - font-ttf-source-code-pro=2.030=hd3eb1b0_0 89 | - font-ttf-ubuntu=0.83=h8b1ccd4_0 90 | - fontconfig=2.14.2=h14ed4e7_0 91 | - fonts-anaconda=1=h8fa9717_0 92 | - fonts-conda-ecosystem=1=hd3eb1b0_0 93 | - fonttools=4.25.0=pyhd3eb1b0_0 94 | - freetype=2.12.1=h4a9f257_0 95 | - fribidi=1.0.10=h7b6447c_0 96 | - frozenlist=1.4.0=py310h2372a71_1 97 | - fsspec=2023.10.0=py310h06a4308_0 98 | - future=0.18.3=py310h06a4308_0 99 | - gflags=2.2.2=he1b5a44_1004 100 | - giflib=5.2.1=h5eee18b_3 101 | - git-lfs=3.5.1=ha770c72_0 102 | - glib=2.78.4=h6a678d5_0 103 | - glib-tools=2.78.4=h6a678d5_0 104 | - glog=0.5.0=h48cff8f_0 105 | - gmp=6.3.0=h59595ed_1 106 | - gmpy2=2.1.2=py310heeb90bb_0 107 | - gnutls=3.7.9=hb077bed_0 108 | - graphite2=1.3.14=h295c915_1 109 | - grpc-cpp=1.46.1=h33aed49_1 110 | - gst-plugins-base=1.14.1=h6a678d5_1 111 | - gstreamer=1.14.1=h5eee18b_1 112 | - h11=0.14.0=pyhd8ed1ab_0 113 | - harfbuzz=8.3.0=h3d44ed6_0 114 | - huggingface_hub=0.19.4=py_0 115 | - icu=73.2=h59595ed_0 116 | - idna=3.4=py310h06a4308_0 117 | - imagecodecs=2021.11.20=py310ha26f956_1 118 | - imageio=2.31.4=py310h06a4308_0 119 | - importlib-metadata=7.0.1=pyha770c72_0 120 | - importlib_metadata=7.0.1=hd8ed1ab_0 121 | - inquirer=3.1.4=pyhd8ed1ab_0 122 | - intel-openmp=2023.1.0=hdb19cb5_46306 123 | - ipython=8.15.0=py310h06a4308_0 124 | - itsdangerous=2.1.2=pyhd8ed1ab_0 125 | - jaraco.classes=3.3.0=pyhd8ed1ab_0 126 | - jedi=0.18.1=py310h06a4308_1 127 | - jeepney=0.8.0=pyhd8ed1ab_0 128 | - jinja2=3.1.2=py310h06a4308_0 129 | - jmespath=1.0.1=pyhd8ed1ab_0 130 | - joblib=1.2.0=py310h06a4308_0 131 | - jpeg=9e=h5eee18b_1 132 | - jxrlib=1.1=h7b6447c_2 133 | - keyring=24.3.0=py310hff52083_0 134 | - 
kiwisolver=1.4.4=py310h6a678d5_0 135 | - krb5=1.20.1=h568e23c_1 136 | - lame=3.100=h7b6447c_0 137 | - lazy_loader=0.3=py310h06a4308_0 138 | - lcms2=2.12=h3be6417_0 139 | - ld_impl_linux-64=2.38=h1181459_1 140 | - lerc=3.0=h295c915_0 141 | - libaec=1.1.3=h59595ed_0 142 | - libass=0.17.1=h8fe9dca_1 143 | - libboost=1.82.0=h109eef0_2 144 | - libbrotlicommon=1.0.9=h5eee18b_7 145 | - libbrotlidec=1.0.9=h5eee18b_7 146 | - libbrotlienc=1.0.9=h5eee18b_7 147 | - libclang=14.0.6=default_hc6dbbc7_1 148 | - libclang13=14.0.6=default_he11475f_1 149 | - libcublas=12.1.0.26=0 150 | - libcufft=11.0.2.4=0 151 | - libcufile=1.8.1.2=0 152 | - libcups=2.4.2=ha637b67_0 153 | - libcurand=10.3.4.101=0 154 | - libcurl=8.2.1=h91b91d3_0 155 | - libcusolver=11.4.4.55=0 156 | - libcusparse=12.0.2.55=0 157 | - libdeflate=1.8=h7f8727e_5 158 | - libdrm=2.4.120=hd590300_0 159 | - libedit=3.1.20221030=h5eee18b_0 160 | - libev=4.33=h516909a_1 161 | - libevent=2.1.10=h9b69904_4 162 | - libexpat=2.5.0=hcb278e6_1 163 | - libffi=3.4.4=h6a678d5_0 164 | - libgcc-ng=13.2.0=h807b86a_3 165 | - libgfortran-ng=11.2.0=h00389a5_1 166 | - libgfortran5=11.2.0=h1234567_1 167 | - libglib=2.78.4=hdc74915_0 168 | - libgomp=13.2.0=h807b86a_3 169 | - libiconv=1.17=hd590300_2 170 | - libidn2=2.3.4=h5eee18b_0 171 | - libjpeg-turbo=2.0.0=h9bf148f_0 172 | - libllvm14=14.0.6=hdb19cb5_3 173 | - libnghttp2=1.52.0=ha637b67_1 174 | - libnpp=12.0.2.50=0 175 | - libnsl=2.0.0=h5eee18b_0 176 | - libnvjitlink=12.1.105=0 177 | - libnvjpeg=12.1.1.14=0 178 | - libopenvino=2023.2.0=h59595ed_0 179 | - libopenvino-auto-batch-plugin=2023.2.0=h59595ed_0 180 | - libopenvino-auto-plugin=2023.2.0=h59595ed_0 181 | - libopenvino-hetero-plugin=2023.2.0=h59595ed_0 182 | - libopenvino-intel-cpu-plugin=2023.2.0=h59595ed_0 183 | - libopenvino-intel-gpu-plugin=2023.2.0=h59595ed_0 184 | - libopenvino-ir-frontend=2023.2.0=h59595ed_0 185 | - libopenvino-onnx-frontend=2023.2.0=h59595ed_0 186 | - libopenvino-paddle-frontend=2023.2.0=h59595ed_0 187 | - libopenvino-pytorch-frontend=2023.2.0=h59595ed_0 188 | - libopenvino-tensorflow-frontend=2023.2.0=h59595ed_0 189 | - libopenvino-tensorflow-lite-frontend=2023.2.0=h59595ed_0 190 | - libopus=1.3.1=h7b6447c_0 191 | - libpciaccess=0.18=hd590300_0 192 | - libpng=1.6.39=h5eee18b_0 193 | - libpq=12.15=h37d81fd_1 194 | - libprotobuf=3.20.3=he621ea3_0 195 | - libsqlite=3.45.3=h2797004_0 196 | - libssh2=1.10.0=ha56f1ee_2 197 | - libstdcxx-ng=13.2.0=h7e041cc_5 198 | - libtasn1=4.19.0=h5eee18b_0 199 | - libthrift=0.15.0=he6d91bd_0 200 | - libtiff=4.4.0=hecacb30_2 201 | - libunistring=0.9.10=h27cfd23_0 202 | - libuuid=2.38.1=h0b41bf4_0 203 | - libva=2.21.0=hd590300_0 204 | - libvpx=1.13.1=h6a678d5_0 205 | - libwebp=1.3.2=h11a3e52_0 206 | - libwebp-base=1.3.2=h5eee18b_0 207 | - libxcb=1.15=h7f8727e_0 208 | - libxkbcommon=1.7.0=h662e7e4_0 209 | - libxml2=2.12.6=h232c23b_2 210 | - libzlib=1.2.13=hd590300_5 211 | - libzopfli=1.0.3=he6710b0_0 212 | - lightning=2.1.3=pyhd8ed1ab_0 213 | - lightning-cloud=0.5.57=pyhd8ed1ab_0 214 | - lightning-utilities=0.10.0=pyhd8ed1ab_0 215 | - llvm-openmp=14.0.6=h9e868ea_0 216 | - lz4-c=1.9.3=h295c915_1 217 | - markdown-it-py=3.0.0=pyhd8ed1ab_0 218 | - markupsafe=2.1.1=py310h7f8727e_0 219 | - matplotlib=3.8.0=py310h06a4308_0 220 | - matplotlib-base=3.8.0=py310h1128e8f_0 221 | - matplotlib-inline=0.1.6=py310h06a4308_0 222 | - mdurl=0.1.0=pyhd8ed1ab_0 223 | - mkl=2023.1.0=h213fc3f_46344 224 | - mkl-service=2.4.0=py310h5eee18b_1 225 | - mkl_fft=1.3.8=py310h5eee18b_0 226 | - mkl_random=1.2.4=py310hdb19cb5_0 227 | - 
more-itertools=10.1.0=pyhd8ed1ab_0 228 | - mpc=1.1.0=h10f8cd9_1 229 | - mpfr=4.0.2=hb69a4c5_1 230 | - mpmath=1.3.0=py310h06a4308_0 231 | - msgpack-python=1.0.3=py310hbf28c38_1 232 | - multidict=6.0.4=py310h2372a71_1 233 | - multiprocess=0.70.15=py310h2372a71_1 234 | - munkres=1.1.4=py_0 235 | - mysql=5.7.24=he378463_2 236 | - ncurses=6.4=h6a678d5_0 237 | - nettle=3.9.1=h7ab15ed_0 238 | - networkx=3.1=py310h06a4308_0 239 | - numpy=1.26.2=py310h5f9d8c6_0 240 | - numpy-base=1.26.2=py310hb5e798b_0 241 | - ocl-icd=2.3.2=hd590300_1 242 | - ocl-icd-system=1.0.0=1 243 | - openh264=2.4.0=h59595ed_0 244 | - openjpeg=2.4.0=h3ad879b_0 245 | - openssl=1.1.1w=h7f8727e_0 246 | - orc=1.7.4=h07ed6aa_0 247 | - ordered-set=4.1.0=pyhd8ed1ab_0 248 | - orjson=3.9.10=py310h1e2579a_0 249 | - p11-kit=0.24.1=hc5aa10d_0 250 | - packaging=23.1=py310h06a4308_0 251 | - pandas=1.4.2=py310h769672d_1 252 | - parso=0.8.3=pyhd3eb1b0_0 253 | - pcre=8.45=h295c915_0 254 | - pcre2=10.42=hebb0a14_0 255 | - pexpect=4.8.0=pyh1a96a4e_2 256 | - pickleshare=0.7.5=pyhd3eb1b0_1003 257 | - pillow=10.0.1=py310ha6cbd5a_0 258 | - pip=23.3.1=py310h06a4308_0 259 | - pixman=0.43.2=h59595ed_0 260 | - pkginfo=1.9.6=pyhd8ed1ab_0 261 | - platformdirs=3.11.0=pyhd8ed1ab_0 262 | - ply=3.11=py310h06a4308_0 263 | - poetry=1.7.1=linux_pyha804496_0 264 | - poetry-core=1.8.1=pyhd8ed1ab_0 265 | - poetry-plugin-export=1.6.0=pyhd8ed1ab_0 266 | - prompt-toolkit=3.0.43=py310h06a4308_0 267 | - protobuf=3.20.3=py310h6a678d5_0 268 | - ptyprocess=0.7.0=pyhd3deb0d_0 269 | - pugixml=1.14=h59595ed_0 270 | - pure_eval=0.2.2=pyhd3eb1b0_0 271 | - pyarrow=8.0.0=py310h468efa6_0 272 | - pyarrow-hotfix=0.6=pyhd8ed1ab_0 273 | - pycparser=2.21=pyhd3eb1b0_0 274 | - pydantic=2.1.1=pyhd8ed1ab_0 275 | - pydantic-core=2.4.0=py310hcb5633a_0 276 | - pygments=2.17.2=pyhd8ed1ab_0 277 | - pyjwt=2.8.0=pyhd8ed1ab_0 278 | - pyopenssl=23.2.0=py310h06a4308_0 279 | - pyparsing=3.0.9=py310h06a4308_0 280 | - pyproject_hooks=1.0.0=pyhd8ed1ab_0 281 | - pyqt=5.15.10=py310h6a678d5_0 282 | - pyqt5-sip=12.13.0=py310h5eee18b_0 283 | - pysocks=1.7.1=py310h06a4308_0 284 | - python=3.10.8=h257c98d_0_cpython 285 | - python-build=1.0.3=pyhd8ed1ab_0 286 | - python-dateutil=2.8.2=pyhd3eb1b0_0 287 | - python-editor=1.0.4=py_0 288 | - python-fastjsonschema=2.19.1=pyhd8ed1ab_0 289 | - python-installer=0.7.0=pyhd8ed1ab_0 290 | - python-multipart=0.0.6=pyhd8ed1ab_0 291 | - python-xxhash=3.4.1=py310h2372a71_0 292 | - python_abi=3.10=2_cp310 293 | - pytorch=2.1.2=py3.10_cuda12.1_cudnn8.9.2_0 294 | - pytorch-cuda=12.1=ha16c6d3_5 295 | - pytorch-lightning=2.1.3=pyhd8ed1ab_0 296 | - pytorch-mutex=1.0=cuda 297 | - pytz=2023.3.post1=pyhd8ed1ab_0 298 | - pyyaml=6.0.1=py310h5eee18b_0 299 | - qt-main=5.15.2=h110a718_10 300 | - rapidfuzz=3.5.2=py310h6a678d5_0 301 | - re2=2022.04.01=h27087fc_0 302 | - readchar=4.0.5=pyhd8ed1ab_0 303 | - readline=8.2=h5eee18b_0 304 | - regex=2023.10.3=py310h5eee18b_0 305 | - requests=2.31.0=py310h06a4308_0 306 | - requests-toolbelt=1.0.0=pyhd8ed1ab_0 307 | - rich=13.7.0=pyhd8ed1ab_0 308 | - s3transfer=0.10.0=pyhd8ed1ab_0 309 | - sacremoses=master=py_0 310 | - safetensors=0.4.0=py310ha89cbab_0 311 | - scikit-image=0.22.0=py310h1128e8f_0 312 | - scipy=1.11.4=py310h5f9d8c6_0 313 | - secretstorage=3.3.3=py310hff52083_2 314 | - setuptools=68.0.0=py310h06a4308_0 315 | - shellingham=1.5.4=pyhd8ed1ab_0 316 | - sip=6.7.12=py310h6a678d5_0 317 | - six=1.16.0=pyhd3eb1b0_1 318 | - snappy=1.1.10=h6a678d5_1 319 | - sniffio=1.3.0=pyhd8ed1ab_0 320 | - soupsieve=2.5=pyhd8ed1ab_1 321 | - 
sqlite=3.41.2=h5eee18b_0 322 | - stack_data=0.2.0=pyhd3eb1b0_0 323 | - starlette=0.27.0=pyhd8ed1ab_0 324 | - starsessions=1.3.0=pyhd8ed1ab_0 325 | - svt-av1=1.8.0=h59595ed_0 326 | - sympy=1.12=py310h06a4308_0 327 | - tbb=2021.8.0=hdb19cb5_0 328 | - tifffile=2022.10.10=pyhd8ed1ab_0 329 | - tk=8.6.12=h1ccaba5_0 330 | - tokenizers=0.11.4=py310h3dcd8bd_1 331 | - tomli=2.0.1=py310h06a4308_0 332 | - tomlkit=0.12.3=pyha770c72_0 333 | - torchaudio=2.1.2=py310_cu121 334 | - torchmetrics=1.4.0.post0=pyhd8ed1ab_0 335 | - torchtriton=2.1.0=py310 336 | - torchvision=0.16.2=py310_cu121 337 | - tornado=6.3.3=py310h5eee18b_0 338 | - tqdm=4.65.0=py310h2f386ee_0 339 | - traitlets=5.14.0=pyhd8ed1ab_0 340 | - transformers=4.33.3=py_0 341 | - trove-classifiers=2023.11.29=pyhd8ed1ab_0 342 | - types-python-dateutil=2.8.19.14=pyhd8ed1ab_0 343 | - typing-extensions=4.7.1=py310h06a4308_0 344 | - typing_extensions=4.7.1=py310h06a4308_0 345 | - tzdata=2023c=h04d1e81_0 346 | - urllib3=1.26.18=py310h06a4308_0 347 | - utf8proc=2.6.1=h27cfd23_0 348 | - uvicorn=0.25.0=py310hff52083_0 349 | - virtualenv=20.25.0=pyhd8ed1ab_0 350 | - wcwidth=0.2.12=pyhd8ed1ab_0 351 | - websocket-client=1.7.0=pyhd8ed1ab_0 352 | - websockets=12.0=py310h2372a71_0 353 | - wheel=0.41.2=py310h06a4308_0 354 | - x264=1!164.3095=h166bdaf_2 355 | - x265=3.5=h924138e_3 356 | - xformers=0.0.23.post1=py310_cu12.1.0_pyt2.1.2 357 | - xkeyboard-config=2.41=hd590300_0 358 | - xorg-fixesproto=5.0=h7f98852_1002 359 | - xorg-kbproto=1.0.7=h7f98852_1002 360 | - xorg-libice=1.1.1=hd590300_0 361 | - xorg-libsm=1.2.4=h7391055_0 362 | - xorg-libx11=1.8.9=h8ee46fc_0 363 | - xorg-libxau=1.0.11=hd590300_0 364 | - xorg-libxext=1.3.4=h0b41bf4_2 365 | - xorg-libxfixes=5.0.3=h7f98852_1004 366 | - xorg-libxrender=0.9.11=hd590300_0 367 | - xorg-renderproto=0.11.1=h7f98852_1002 368 | - xorg-xextproto=7.3.0=h0b41bf4_1003 369 | - xorg-xproto=7.0.31=h27cfd23_1007 370 | - xxhash=0.8.2=hd590300_0 371 | - xz=5.4.5=h5eee18b_0 372 | - yaml=0.2.5=h7b6447c_0 373 | - yarl=1.9.3=py310h2372a71_0 374 | - zfp=0.5.5=h295c915_6 375 | - zipp=3.17.0=pyhd8ed1ab_0 376 | - zlib=1.2.13=hd590300_5 377 | - zlib-ng=2.0.7=h5eee18b_0 378 | - zstd=1.5.2=ha4553b6_0 379 | - pip: 380 | - addict==2.4.0 381 | - antlr4-python3-runtime==4.9.3 382 | - appdirs==1.4.4 383 | - blinker==1.7.0 384 | - comm==0.2.1 385 | - configargparse==1.7 386 | - dash==2.16.0 387 | - dash-core-components==2.0.0 388 | - dash-html-components==2.0.0 389 | - dash-table==5.0.0 390 | - docker-pycreds==0.4.0 391 | - flask==3.0.2 392 | - gitdb==4.0.11 393 | - gitpython==3.1.40 394 | - hydra-core==1.3.2 395 | - ipywidgets==8.1.2 396 | - jsonschema==4.21.1 397 | - jsonschema-specifications==2023.12.1 398 | - jupyter-core==5.7.1 399 | - jupyterlab-widgets==3.0.10 400 | - nbformat==5.9.2 401 | - nest-asyncio==1.6.0 402 | - omegaconf==2.3.0 403 | - plotly==5.19.0 404 | - psutil==5.9.6 405 | - pyquaternion==0.9.9 406 | - referencing==0.33.0 407 | - retrying==1.3.4 408 | - rpds-py==0.18.0 409 | - scikit-learn==1.4.1.post1 410 | - sentry-sdk==1.38.0 411 | - setproctitle==1.3.3 412 | - smmap==5.0.1 413 | - tenacity==8.2.3 414 | - threadpoolctl==3.3.0 415 | - wandb==0.16.1 416 | - werkzeug==3.0.1 417 | - widgetsnbextension==4.0.10 418 | -------------------------------------------------------------------------------- /model/cross_reference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .customised_transformer.transformer import ( 3 | TransformerDecoderLayerCustomised, 4 | 
TransformerDecoderCustomised, 5 | ) 6 | from .regression_layer import RegressionLayer 7 | from utils.misc.image import jigsaw_to_image 8 | 9 | 10 | class CrossReferenceNet(torch.nn.Module): 11 | def __init__(self, cfg, dinov2_cfg): 12 | super().__init__() 13 | self.cfg = cfg 14 | self.dinov2_cfg = dinov2_cfg 15 | 16 | # set up input projection 17 | self.input_proj = torch.nn.Identity() 18 | 19 | # set up output final activation function 20 | self.final_activation_fn = RegressionLayer( 21 | metric_type=self.cfg.model.predict.metric.type, 22 | metric_min=self.cfg.model.predict.metric.min, 23 | metric_max=self.cfg.model.predict.metric.max, 24 | pow_factor=self.cfg.model.predict.metric.power_factor, 25 | ) 26 | 27 | # layers 28 | self.attn = TransformerDecoderCustomised( 29 | decoder_layer=TransformerDecoderLayerCustomised( 30 | d_model=self.dinov2_cfg.hidden_size, 31 | nhead=8, 32 | dim_feedforward=self.dinov2_cfg.hidden_size, 33 | dropout=0.0, 34 | batch_first=True, 35 | do_self_attn=self.cfg.model.decoder_do_self_attn, 36 | do_short_cut=self.cfg.model.decoder_do_short_cut, 37 | ), 38 | num_layers=2, 39 | ) 40 | 41 | # set up head out dimension 42 | out_size = cfg.model.patch_size**2 43 | 44 | # head 45 | self.head = torch.nn.Sequential( 46 | torch.nn.Linear(self.dinov2_cfg.hidden_size, self.dinov2_cfg.hidden_size), 47 | torch.nn.LeakyReLU(), 48 | torch.nn.Linear(self.dinov2_cfg.hidden_size, out_size), 49 | self.final_activation_fn, 50 | ) 51 | 52 | def forward( 53 | self, 54 | featmap_query, 55 | featmap_ref, 56 | memory_mask, 57 | dim_params, 58 | need_attn_weights, 59 | need_attn_weights_head_id, 60 | ): 61 | """ 62 | :param featmap_query: (B, num_patches, hidden_size) 63 | :param featmap_ref: (B, N_ref * num_patches, hidden_size) 64 | :param memory_mask: None 65 | :param dim_params: dict 66 | """ 67 | B = dim_params["B"] 68 | N_patch_h = dim_params["N_patch_h"] 69 | N_patch_w = dim_params["N_patch_w"] 70 | N_ref = dim_params["N_ref"] 71 | 72 | results = {} 73 | score_map, _, mha_weights = self.attn( 74 | tgt=featmap_query, 75 | memory=featmap_ref, 76 | memory_mask=memory_mask, 77 | need_weights=need_attn_weights, 78 | need_weights_head_id=need_attn_weights_head_id, 79 | ) # (B, num_patches_tgt, hidden_size), (B, num_patches_tgt, num_patches_mem) 80 | 81 | # return score map 82 | score_map = self.head(score_map) # (B, num_patches, num_ssim_pixels) 83 | 84 | # reshape to image size 85 | P = self.cfg.model.patch_size 86 | score_map = score_map.view(B, -1, P, P) 87 | score_map = jigsaw_to_image(score_map, grid_size=(N_patch_h, N_patch_w)) # (B, H, W) 88 | results["score_map"] = score_map 89 | 90 | # return attn weights 91 | if need_attn_weights: 92 | mha_weights = mha_weights.view(B, N_patch_h, N_patch_w, N_ref, N_patch_h, N_patch_w) 93 | results["attn_weights_map_mha"] = mha_weights 94 | return results 95 | -------------------------------------------------------------------------------- /model/customised_transformer/transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, Union 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | from torch.nn.modules.transformer import ( 7 | TransformerDecoder, 8 | _get_seq_len, 9 | _detect_is_causal_mask, 10 | _get_activation_fn, 11 | ) 12 | from torch.nn.modules.activation import MultiheadAttention 13 | from torch.nn.modules.dropout import Dropout 14 | from torch.nn.modules.linear import Linear 15 | from torch.nn.modules.normalization 
import LayerNorm 16 | 17 | 18 | # fmt: off 19 | # Zirui: This class is copied and modified from TransformerDecoderLayer in pytorch 2.1.2 20 | class TransformerDecoderLayerCustomised(torch.nn.Module): 21 | r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. 22 | This standard decoder layer is based on the paper "Attention Is All You Need". 23 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 24 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 25 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 26 | in a different way during application. 27 | 28 | Args: 29 | d_model: the number of expected features in the input (required). 30 | nhead: the number of heads in the multiheadattention models (required). 31 | dim_feedforward: the dimension of the feedforward network model (default=2048). 32 | dropout: the dropout value (default=0.1). 33 | activation: the activation function of the intermediate layer, can be a string 34 | ("relu" or "gelu") or a unary callable. Default: relu 35 | layer_norm_eps: the eps value in layer normalization components (default=1e-5). 36 | batch_first: If ``True``, then the input and output tensors are provided 37 | as (batch, seq, feature). Default: ``False`` (seq, batch, feature). 38 | norm_first: if ``True``, layer norm is done prior to self attention, multihead 39 | attention and feedforward operations, respectively. Otherwise it's done after. 40 | Default: ``False`` (after). 41 | bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive 42 | bias. Default: ``True``. 43 | 44 | Examples:: 45 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 46 | >>> memory = torch.rand(10, 32, 512) 47 | >>> tgt = torch.rand(20, 32, 512) 48 | >>> out = decoder_layer(tgt, memory) 49 | 50 | Alternatively, when ``batch_first`` is ``True``: 51 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True) 52 | >>> memory = torch.rand(32, 10, 512) 53 | >>> tgt = torch.rand(32, 20, 512) 54 | >>> out = decoder_layer(tgt, memory) 55 | """ 56 | __constants__ = ['norm_first'] 57 | 58 | def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, 59 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 60 | layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, 61 | bias: bool = True, device=None, dtype=None, do_self_attn=True, do_short_cut=True) -> None: 62 | factory_kwargs = {'device': device, 'dtype': dtype} 63 | super().__init__() 64 | self.do_self_attn = do_self_attn # Zirui: added 65 | self.do_short_cut = do_short_cut # Zirui: added 66 | 67 | if self.do_self_attn: 68 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, 69 | bias=bias, **factory_kwargs) 70 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, 71 | bias=bias, **factory_kwargs) 72 | # Implementation of Feedforward model 73 | self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) 74 | self.dropout = Dropout(dropout) 75 | self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) 76 | 77 | self.norm_first = norm_first 78 | self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) 79 | self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) 80 | 
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) 81 | self.dropout1 = Dropout(dropout) 82 | self.dropout2 = Dropout(dropout) 83 | self.dropout3 = Dropout(dropout) 84 | 85 | # Legacy string support for activation function. 86 | if isinstance(activation, str): 87 | self.activation = _get_activation_fn(activation) 88 | else: 89 | self.activation = activation 90 | 91 | def __setstate__(self, state): 92 | if 'activation' not in state: 93 | state['activation'] = F.relu 94 | super().__setstate__(state) 95 | 96 | def forward( 97 | self, 98 | tgt: Tensor, 99 | memory: Tensor, 100 | tgt_mask: Optional[Tensor] = None, 101 | memory_mask: Optional[Tensor] = None, 102 | tgt_key_padding_mask: Optional[Tensor] = None, 103 | memory_key_padding_mask: Optional[Tensor] = None, 104 | tgt_is_causal: bool = False, 105 | memory_is_causal: bool = False, 106 | need_weights=True, 107 | need_weights_head_id=0, 108 | ): 109 | r"""Pass the inputs (and mask) through the decoder layer. 110 | 111 | Args: 112 | tgt: the sequence to the decoder layer (required). 113 | memory: the sequence from the last layer of the encoder (required). 114 | tgt_mask: the mask for the tgt sequence (optional). 115 | memory_mask: the mask for the memory sequence (optional). 116 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 117 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 118 | tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. 119 | Default: ``False``. 120 | Warning: 121 | ``tgt_is_causal`` provides a hint that ``tgt_mask`` is 122 | the causal mask. Providing incorrect hints can result in 123 | incorrect execution, including forward and backward 124 | compatibility. 125 | memory_is_causal: If specified, applies a causal mask as 126 | ``memory mask``. 127 | Default: ``False``. 128 | Warning: 129 | ``memory_is_causal`` provides a hint that 130 | ``memory_mask`` is the causal mask. Providing incorrect 131 | hints can result in incorrect execution, including 132 | forward and backward compatibility. 133 | 134 | Shape: 135 | see the docs in Transformer class. 136 | """ 137 | # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf 138 | 139 | x = tgt 140 | if self.norm_first: 141 | if self.do_self_attn: 142 | sa_out, sa_weights = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal, need_weights) 143 | if self.do_short_cut: 144 | x = x + sa_out 145 | else: 146 | x = sa_out 147 | else: 148 | sa_weights = None 149 | 150 | mha_out, mha_weights = self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal, need_weights) 151 | if self.do_short_cut: 152 | x = x + mha_out 153 | else: 154 | x = mha_out 155 | 156 | x = x + self._ff_block(self.norm3(x)) 157 | else: 158 | if self.do_self_attn: 159 | sa_out, sa_weights = self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal, need_weights) 160 | if self.do_short_cut: 161 | x = self.norm1(x + sa_out) 162 | else: 163 | x = self.norm1(sa_out) 164 | else: 165 | sa_weights = None 166 | 167 | mha_out, mha_weights = self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal, need_weights) 168 | if self.do_short_cut: 169 | x = self.norm2(x + mha_out) 170 | else: 171 | x = self.norm2(mha_out) 172 | 173 | x = self.norm3(x + self._ff_block(x)) 174 | 175 | if sa_weights is not None: 176 | sa_weights = sa_weights[:, need_weights_head_id] # return attn weights of a specific head 177 | if mha_weights is not None: 178 | mha_weights = mha_weights[:, need_weights_head_id] # return attn weights of a specific head 179 | return x, sa_weights, mha_weights 180 | 181 | # self-attention block 182 | def _sa_block(self, x: Tensor, 183 | attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False, need_weights: bool = True): 184 | x, attn_weights = self.self_attn(x, x, x, 185 | attn_mask=attn_mask, 186 | key_padding_mask=key_padding_mask, 187 | is_causal=is_causal, 188 | need_weights=need_weights, 189 | average_attn_weights=False,) 190 | if need_weights: 191 | attn_weights = attn_weights.detach() # attn weights of all heads 192 | return self.dropout1(x), attn_weights 193 | 194 | # multihead attention block 195 | def _mha_block(self, x: Tensor, mem: Tensor, 196 | attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False, need_weights: bool = True): 197 | x, attn_weights = self.multihead_attn(x, mem, mem, 198 | attn_mask=attn_mask, 199 | key_padding_mask=key_padding_mask, 200 | is_causal=is_causal, 201 | need_weights=need_weights, 202 | average_attn_weights=False,) 203 | if need_weights: 204 | attn_weights = attn_weights.detach() # attn weights of all heads 205 | return self.dropout2(x), attn_weights 206 | 207 | # feed forward block 208 | def _ff_block(self, x: Tensor) -> Tensor: 209 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 210 | return self.dropout3(x) 211 | 212 | 213 | class TransformerDecoderCustomised(TransformerDecoder): 214 | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, 215 | memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, 216 | memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None, 217 | memory_is_causal: bool = False, need_weights: bool = True, need_weights_head_id: int = 0): 218 | r"""Pass the inputs (and mask) through the decoder layer in turn. 219 | 220 | Args: 221 | tgt: the sequence to the decoder (required). 222 | memory: the sequence from the last layer of the encoder (required). 223 | tgt_mask: the mask for the tgt sequence (optional). 
224 | memory_mask: the mask for the memory sequence (optional). 225 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 226 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 227 | tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. 228 | Default: ``None``; try to detect a causal mask. 229 | Warning: 230 | ``tgt_is_causal`` provides a hint that ``tgt_mask`` is 231 | the causal mask. Providing incorrect hints can result in 232 | incorrect execution, including forward and backward 233 | compatibility. 234 | memory_is_causal: If specified, applies a causal mask as 235 | ``memory mask``. 236 | Default: ``False``. 237 | Warning: 238 | ``memory_is_causal`` provides a hint that 239 | ``memory_mask`` is the causal mask. Providing incorrect 240 | hints can result in incorrect execution, including 241 | forward and backward compatibility. 242 | 243 | Shape: 244 | see the docs in Transformer class. 245 | """ 246 | output = tgt 247 | 248 | seq_len = _get_seq_len(tgt, self.layers[0].multihead_attn.batch_first) 249 | tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len) 250 | 251 | for mod in self.layers: 252 | output, sa_weights, mha_weights = mod( 253 | output, memory, tgt_mask=tgt_mask, 254 | memory_mask=memory_mask, 255 | tgt_key_padding_mask=tgt_key_padding_mask, 256 | memory_key_padding_mask=memory_key_padding_mask, 257 | tgt_is_causal=tgt_is_causal, 258 | memory_is_causal=memory_is_causal, 259 | need_weights=need_weights, 260 | need_weights_head_id=need_weights_head_id, 261 | ) 262 | 263 | if self.norm is not None: 264 | output = self.norm(output) 265 | 266 | # Only return the last layer's attention weights. 267 | # If need_weights is False, they are None. 268 | return output, sa_weights, mha_weights 269 | -------------------------------------------------------------------------------- /model/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MultiViewPosionalEmbeddings(torch.nn.Module): 5 | 6 | def __init__( 7 | self, 8 | positional_encoding_h, 9 | positional_encoding_w, 10 | interpolate_mode, 11 | req_grad, 12 | patch_size=14, 13 | hidden_size=384, 14 | ): 15 | """Apply positional encoding to input multi-view embeddings. 16 | Conceptually, posional encoding are in (pe_h, pe_w, C) shape. 17 | We can interpolate in 2D to adapt to different input image size, but 18 | no interpolation in the view dimension. 
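         For example (an illustrative case, assuming the default DINOv2 patch size of 14): a 560x560 input yields a 40x40 patch grid, so the stored (pe_h, pe_w, C) table is resized to 40x40 using ``interpolate_mode`` and then added to the embeddings of every view.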
19 | 20 | Shorthand: 21 | P: patch size 22 | C: hidden size 23 | N: number of 24 | pe: positional encoding 25 | mv: multi-view 26 | emb: embedding 27 | 28 | :param patch_size: DINOv2 patch size def 14 29 | :param hidden_size: DINOv2 hidden size def 384 (dinov2_small) 30 | """ 31 | super().__init__() 32 | self.P = patch_size 33 | self.C = hidden_size 34 | self.pe_h = positional_encoding_h 35 | self.pe_w = positional_encoding_w 36 | self.interpolate_mode = interpolate_mode 37 | self.PE = torch.nn.Parameter( 38 | torch.randn(1, self.pe_h, self.pe_w, self.C), 39 | req_grad, 40 | ) 41 | 42 | def forward(self, mv_emb, N_view, img_h, img_w): 43 | """ 44 | :param mv_emb: (B, N_patch, C), N_patch: N_view * emb_h * emb_w 45 | """ 46 | B = mv_emb.shape[0] 47 | emb_h = img_h // self.P 48 | emb_w = img_w // self.P 49 | 50 | # a short cut no need to interpolate 51 | if emb_h == self.pe_h and emb_w == self.pe_w: 52 | # no need to interpolate 53 | mv_emb = mv_emb.view(B, N_view, emb_h, emb_w, self.C) 54 | mv_emb = mv_emb + self.PE 55 | mv_emb = mv_emb.view(B, N_view * emb_h * emb_w, self.C) 56 | return mv_emb 57 | 58 | # 1. interpolate position embedding 59 | # we add a small number to avoid floating point error in the interpolation 60 | # see discussion at https://github.com/facebookresearch/dino/issues/8 61 | _PE = torch.nn.functional.interpolate( 62 | self.PE.permute(0, 3, 1, 2), 63 | scale_factor=( 64 | (emb_h + 1e-4) / self.pe_h, 65 | (emb_w + 1e-4) / self.pe_w, 66 | ), 67 | mode=self.interpolate_mode, 68 | align_corners=True, 69 | ) # (1, C, emb_h, emb_w) 70 | 71 | # 2. embed and reshape back 72 | mv_emb = mv_emb.view(B, N_view, emb_h, emb_w, self.C) 73 | mv_emb = mv_emb + _PE.permute(0, 2, 3, 1)[None] 74 | mv_emb = mv_emb.reshape(B, N_view * emb_h * emb_w, self.C) 75 | return mv_emb 76 | -------------------------------------------------------------------------------- /model/regression_layer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import sys 3 | from pathlib import Path 4 | import torch 5 | 6 | sys.path.append(str(Path(__file__).parents[1])) 7 | from utils.check_config import check_metric_prediction_config 8 | 9 | 10 | class RegressionLayer(torch.nn.Module): 11 | def __init__(self, metric_type, metric_min, metric_max, pow_factor="default"): 12 | """ 13 | Make a regression layer based on the metric configuration. 14 | Use power_factor to help predict very small numbers. 
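        Illustrative example: with metric_type="mse" and metric_min=0, the head output passes through a sigmoid and is then raised to the default power factor of 4, so a raw logit of 0 becomes 0.5 ** 4 = 0.0625, which makes very small target scores easier to regress.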
15 | """ 16 | super().__init__() 17 | 18 | check_metric_prediction_config(metric_type, metric_min, metric_max) 19 | self.metric_type = metric_type 20 | self.metric_min = metric_min 21 | self.metric_max = metric_max 22 | 23 | self.activation_fn = self._get_activation_fn() 24 | self.pow_fn = self._get_pow_fn(pow_factor) 25 | 26 | def forward(self, x): 27 | x = self.activation_fn(x) 28 | x = self.pow_fn(x) 29 | return x 30 | 31 | def _get_activation_fn(self): 32 | if self.metric_min == -1: 33 | activation_fn = torch.nn.Tanh() 34 | elif self.metric_min == 0: 35 | activation_fn = torch.nn.Sigmoid() 36 | else: 37 | raise ValueError(f"metric_min={self.metric_min} not supported") 38 | return activation_fn 39 | 40 | def _get_pow_fn(self, p): 41 | # define a lookup table for default power factor 42 | pow_default_table = { 43 | "ssim": 1, 44 | "mae": 2, 45 | "mse": 4, 46 | } 47 | 48 | # only apply power fn for a non-negative score value range 49 | if self.metric_min == 0: 50 | if p == "default": 51 | # use default power factor from the look up table 52 | p = pow_default_table[self.metric_type] 53 | else: 54 | pass # use the provided power factor 55 | else: 56 | p = 1 57 | 58 | if float(p) == 1.0: 59 | pow_fn = torch.nn.Identity() 60 | else: 61 | pow_fn = partial(torch.pow, exponent=p) 62 | return pow_fn 63 | 64 | 65 | if __name__ == "__main__": 66 | for metric_type in ["ssim", "mae", "mse"]: 67 | for metric_min in [-1, 0]: 68 | for p in ["some_typo", "default", 0.1, 1, 1.5, 5]: 69 | print(f"--------") 70 | print(f"metric_type: {metric_type}, metric_min: {metric_min}, pow_factor: {p}") 71 | try: 72 | l = RegressionLayer( 73 | metric_type=metric_type, 74 | metric_min=metric_min, 75 | metric_max=1, 76 | pow_factor=p, 77 | ) 78 | print(f"activation_fn: {l.activation_fn}") 79 | print(f"pow_fn: {l.pow_fn}") 80 | except Exception as e: 81 | print(f"Error: {e}") 82 | -------------------------------------------------------------------------------- /predict.sh: -------------------------------------------------------------------------------- 1 | # Note, in this example: 2 | # - query images are NVS RENDERED images from gaussian-splatting TEST split 3 | # - reference images are REAL CAPTURED images from gaussian-splatting TRAIN split 4 | 5 | ckpt_path=ckpt/CrossScore-v1.0.0.ckpt 6 | data_dir=datadir/processed_training_ready/gaussian/map-free-reloc/res_540 7 | 8 | for scene_name in s00076 s00231; do 9 | 10 | query_dir=$data_dir/$scene_name/test/ours_15000/renders 11 | reference_dir=$data_dir/$scene_name/train/ours_15000/gt 12 | 13 | python task/predict.py \ 14 | trainer.devices=[0] \ 15 | trainer.ckpt_path_to_load=ckpt/CrossScore-v1.0.0.ckpt \ 16 | data.dataset.query_dir=$query_dir \ 17 | data.dataset.reference_dir=$reference_dir \ 18 | alias=$scene_name 19 | done 20 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | "datadir/**", 4 | "log*/**" 5 | ], 6 | "ignore": [ 7 | "datadir/**", 8 | "log*/**" 9 | ] 10 | } -------------------------------------------------------------------------------- /task/core.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | import lightning 4 | import wandb 5 | from transformers import Dinov2Config, Dinov2Model 6 | from omegaconf import DictConfig, OmegaConf 7 | from lightning.pytorch.utilities import rank_zero_only 8 | from 
utils.evaluation.metric import abs2psnr, correlation 9 | from utils.evaluation.metric_logger import ( 10 | MetricLoggerScalar, 11 | MetricLoggerHistogram, 12 | MetricLoggerCorrelation, 13 | MetricLoggerImg, 14 | ) 15 | from utils.plot.batch_visualiser import BatchVisualiserFactory 16 | from utils.io.images import ImageNetMeanStd 17 | from utils.io.batch_writer import BatchWriter 18 | from utils.io.score_summariser import ( 19 | SummaryWriterPredictedOnline, 20 | SummaryWriterPredictedOnlineTestPrediction, 21 | ) 22 | from model.cross_reference import CrossReferenceNet 23 | from model.positional_encoding import MultiViewPosionalEmbeddings 24 | 25 | 26 | class CrossScoreNet(torch.nn.Module): 27 | def __init__(self, cfg): 28 | super().__init__() 29 | self.cfg = cfg 30 | 31 | # used in 1. denormalising images for visualisation 32 | # and 2. normalising images for training when required 33 | img_norm_stat = ImageNetMeanStd() 34 | self.register_buffer( 35 | "img_mean_std", torch.tensor([*img_norm_stat.mean, *img_norm_stat.std]) 36 | ) 37 | 38 | # backbone, freeze 39 | self.dinov2_cfg = Dinov2Config.from_pretrained(self.cfg.model.backbone.from_pretrained) 40 | self.backbone = Dinov2Model.from_pretrained(self.cfg.model.backbone.from_pretrained) 41 | for param in self.backbone.parameters(): 42 | param.requires_grad = False 43 | 44 | # positional encoding layer 45 | self.pos_enc_fn = MultiViewPosionalEmbeddings( 46 | positional_encoding_h=self.cfg.model.pos_enc.multi_view.h, 47 | positional_encoding_w=self.cfg.model.pos_enc.multi_view.w, 48 | interpolate_mode=self.cfg.model.pos_enc.multi_view.interpolate_mode, 49 | req_grad=self.cfg.model.pos_enc.multi_view.req_grad, 50 | patch_size=self.cfg.model.patch_size, 51 | hidden_size=self.dinov2_cfg.hidden_size, 52 | ) 53 | 54 | # cross reference predictor 55 | if self.cfg.model.do_reference_cross: 56 | self.ref_cross = CrossReferenceNet(cfg=self.cfg, dinov2_cfg=self.dinov2_cfg) 57 | 58 | def forward( 59 | self, 60 | query_img, 61 | ref_cross_imgs, 62 | need_attn_weights, 63 | need_attn_weights_head_id, 64 | norm_img, 65 | ): 66 | """ 67 | :param query_img: (B, 3, H, W) 68 | :param ref_cross_imgs: (B, N_ref_cross, 3, H, W) 69 | :param norm_img: bool, normalise an image with pixel value in [0, 1] with imagenet mean and std. 
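        :return: dict with ``score_map_ref_cross`` of shape (B, H, W) and ``attn_weights_map_ref_cross`` (per-patch cross-attention weights), populated when cross-reference prediction is enabled.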
70 | """ 71 | B = query_img.shape[0] 72 | H, W = query_img.shape[-2:] 73 | N_patch_h = H // self.cfg.model.patch_size 74 | N_patch_w = W // self.cfg.model.patch_size 75 | 76 | if norm_img: 77 | img_mean = self.img_mean_std[None, :3, None, None] 78 | img_std = self.img_mean_std[None, :3, None, None] 79 | query_img = (query_img - img_mean) / img_std 80 | if ref_cross_imgs is not None: 81 | ref_cross_imgs = (ref_cross_imgs - img_mean[:, None]) / img_std[:, None] 82 | 83 | featmaps = self.get_featmaps(query_img, ref_cross_imgs) 84 | results = {} 85 | 86 | # processing (and predicting) for query 87 | featmaps["query"] = self.pos_enc_fn(featmaps["query"], N_view=1, img_h=H, img_w=W) 88 | 89 | if self.cfg.model.do_reference_cross: 90 | N_ref_cross = ref_cross_imgs.shape[1] 91 | 92 | # (B, N_ref_cross*num_patches, hidden_size) 93 | featmaps["ref_cross"] = self.pos_enc_fn( 94 | featmaps["ref_cross"], 95 | N_view=N_ref_cross, 96 | img_h=H, 97 | img_w=W, 98 | ) 99 | 100 | # prediction 101 | dim_params = { 102 | "B": B, 103 | "N_patch_h": N_patch_h, 104 | "N_patch_w": N_patch_w, 105 | "N_ref": N_ref_cross, 106 | } 107 | results_ref_cross = self.ref_cross( 108 | featmaps["query"], 109 | featmaps["ref_cross"], 110 | None, 111 | dim_params, 112 | need_attn_weights, 113 | need_attn_weights_head_id, 114 | ) 115 | results["score_map_ref_cross"] = results_ref_cross["score_map"] 116 | results["attn_weights_map_ref_cross"] = results_ref_cross["attn_weights_map_mha"] 117 | return results 118 | 119 | @torch.no_grad() 120 | def get_featmaps(self, query_img, ref_cross_imgs): 121 | """ 122 | :param query_img: (B, 3, H, W) 123 | :param ref_cross: (B, N_ref_cross, 3, H, W) 124 | """ 125 | B = query_img.shape[0] 126 | H, W = query_img.shape[-2:] 127 | N_patch_h = H // self.cfg.model.patch_size 128 | N_patch_w = W // self.cfg.model.patch_size 129 | N_query = 1 130 | N_ref_cross = 0 if ref_cross_imgs is None else ref_cross_imgs.shape[1] 131 | N_all_imgs = N_query + N_ref_cross 132 | 133 | # concat all images to go through backbone for once 134 | all_imgs = [query_img.view(B, 1, 3, H, W)] 135 | if ref_cross_imgs is not None: 136 | all_imgs.append(ref_cross_imgs) 137 | all_imgs = torch.cat(all_imgs, dim=1) 138 | all_imgs = all_imgs.view(B * N_all_imgs, 3, H, W) 139 | 140 | # bbo: backbone output 141 | bbo_all = self.backbone(all_imgs) 142 | featmap_all = bbo_all.last_hidden_state[:, 1:] 143 | featmap_all = featmap_all.view(B, N_all_imgs, N_patch_h * N_patch_w, -1) 144 | 145 | # query 146 | featmap_query = featmap_all[:, 0] # (B, num_patches, hidden_size) 147 | N_patches = featmap_query.shape[1] 148 | hidden_size = featmap_query.shape[2] 149 | 150 | # cross ref 151 | if ref_cross_imgs is not None: 152 | featmap_ref_cross = featmap_all[:, -N_ref_cross:] 153 | featmap_ref_cross = featmap_ref_cross.reshape(B, N_ref_cross * N_patches, hidden_size) 154 | else: 155 | featmap_ref_cross = None 156 | 157 | featmaps = { 158 | "query": featmap_query, # (B, num_patches, hidden_size) 159 | "ref_cross": featmap_ref_cross, # (B, N_ref_cross*num_patches, hidden_size) 160 | } 161 | return featmaps 162 | 163 | 164 | class CrossScoreLightningModule(lightning.LightningModule): 165 | def __init__(self, cfg: DictConfig): 166 | super().__init__() 167 | self.cfg = cfg 168 | 169 | # write config to wandb 170 | self.save_hyperparameters(OmegaConf.to_container(self.cfg, resolve=True)) 171 | 172 | # init my network 173 | self.model = CrossScoreNet(cfg=self.cfg) 174 | 175 | # init visualiser 176 | self.visualiser = BatchVisualiserFactory(self.cfg, 
self.model.img_mean_std)() 177 | 178 | # init loss fn 179 | if self.cfg.model.loss.fn == "l1": 180 | self.loss_fn = torch.nn.L1Loss() 181 | self.to_psnr_fn = abs2psnr 182 | else: 183 | raise NotImplementedError 184 | 185 | # logging related names 186 | self.ref_mode_names = [] 187 | if self.cfg.model.do_reference_cross: 188 | self.ref_mode_names.append("ref_cross") 189 | 190 | def on_fit_start(self): 191 | # reset logging cache 192 | if self.global_rank == 0: 193 | self._reset_logging_cache_train() 194 | self._reset_logging_cache_validation() 195 | 196 | self.frame_score_summariser = SummaryWriterPredictedOnline( 197 | metric_type=self.cfg.model.predict.metric.type, 198 | metric_min=self.cfg.model.predict.metric.min, 199 | ) 200 | 201 | def on_test_start(self): 202 | Path(self.cfg.logger.test.out_dir, "vis").mkdir(parents=True, exist_ok=True) 203 | if self.cfg.logger.test.write.flag.batch: 204 | self.batch_writer = BatchWriter(self.cfg, "test", self.model.img_mean_std) 205 | else: 206 | self.batch_writer = None 207 | 208 | self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( 209 | metric_type=self.cfg.model.predict.metric.type, 210 | metric_min=self.cfg.model.predict.metric.min, 211 | dir_out=self.cfg.logger.test.out_dir, 212 | ) 213 | 214 | def on_predict_start(self): 215 | Path(self.cfg.logger.predict.out_dir, "vis").mkdir(parents=True, exist_ok=True) 216 | if self.cfg.logger.predict.write.flag.batch: 217 | self.batch_writer = BatchWriter(self.cfg, "predict", self.model.img_mean_std) 218 | else: 219 | self.batch_writer = None 220 | 221 | self.frame_score_summariser = SummaryWriterPredictedOnlineTestPrediction( 222 | metric_type=self.cfg.model.predict.metric.type, 223 | metric_min=self.cfg.model.predict.metric.min, 224 | dir_out=self.cfg.logger.predict.out_dir, 225 | ) 226 | 227 | def _reset_logging_cache_train(self): 228 | self.train_cache = { 229 | "loss": { 230 | k: MetricLoggerScalar(max_length=self.cfg.logger.cache_size.train.n_scalar) 231 | for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names 232 | }, 233 | "correlation": { 234 | k: MetricLoggerCorrelation(max_length=self.cfg.logger.cache_size.train.n_scalar) 235 | for k in self.ref_mode_names 236 | }, 237 | "map": { 238 | "score": { 239 | k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) 240 | for k in self.ref_mode_names 241 | }, 242 | "l1_diff": { 243 | k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) 244 | for k in self.ref_mode_names 245 | }, 246 | "delta": { 247 | k: MetricLoggerHistogram(max_length=self.cfg.logger.cache_size.train.n_scalar) 248 | for k in ["self", "cross"] 249 | }, 250 | }, 251 | } 252 | 253 | def _reset_logging_cache_validation(self): 254 | self.validation_cache = { 255 | "loss": { 256 | k: MetricLoggerScalar(max_length=None) 257 | for k in ["final", "reg_self", "reg_cross"] + self.ref_mode_names 258 | }, 259 | "correlation": { 260 | k: MetricLoggerCorrelation(max_length=None) for k in self.ref_mode_names 261 | }, 262 | "fig": {k: MetricLoggerImg(max_length=None) for k in ["batch"]}, 263 | } 264 | 265 | def _core_step(self, batch, batch_idx, skip_loss=False): 266 | outputs = self.model( 267 | query_img=batch["query/img"], # (B, C, H, W) 268 | ref_cross_imgs=batch.get("reference/cross/imgs", None), # (B, N_ref_cross, C, H, W) 269 | need_attn_weights=self.cfg.model.need_attn_weights, 270 | need_attn_weights_head_id=self.cfg.model.need_attn_weights_head_id, 271 | norm_img=False, 272 | ) 273 | 274 | if skip_loss: # only 
used in predict_step 275 | return outputs 276 | 277 | score_map = batch["query/score_map"] # (B, H, W) 278 | 279 | loss = [] 280 | # cross reference model predicts 281 | if self.cfg.model.do_reference_cross: 282 | score_map_cross = outputs["score_map_ref_cross"] # (B, H, W) 283 | l1_diff_map_cross = torch.abs(score_map_cross - score_map) # (B, H, W) 284 | if self.cfg.model.loss.fn == "l1": 285 | loss_cross = l1_diff_map_cross.mean() 286 | else: 287 | loss_cross = self.loss_fn(score_map_cross, score_map) 288 | outputs["loss_cross"] = loss_cross 289 | outputs["l1_diff_map_ref_cross"] = l1_diff_map_cross 290 | loss.append(loss_cross) 291 | 292 | loss = torch.stack(loss).sum() 293 | outputs["loss"] = loss 294 | return outputs 295 | 296 | def training_step(self, batch, batch_idx): 297 | outputs = self._core_step(batch, batch_idx) 298 | return outputs 299 | 300 | def validation_step(self, batch, batch_idx): 301 | outputs = self._core_step(batch, batch_idx) 302 | return outputs 303 | 304 | def test_step(self, batch, batch_idx): 305 | outputs = self._core_step(batch, batch_idx) 306 | return outputs 307 | 308 | def predict_step(self, batch, batch_idx): 309 | outputs = self._core_step(batch, batch_idx, skip_loss=True) 310 | return outputs 311 | 312 | @rank_zero_only 313 | def on_train_batch_end(self, outputs, batch, batch_idx): 314 | self.train_cache["loss"]["final"].update(outputs["loss"]) 315 | 316 | if self.cfg.model.do_reference_cross: 317 | self.train_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) 318 | self.train_cache["correlation"]["ref_cross"].update( 319 | outputs["score_map_ref_cross"], batch["query/score_map"] 320 | ) 321 | self.train_cache["map"]["score"]["ref_cross"].update(outputs["score_map_ref_cross"]) 322 | self.train_cache["map"]["l1_diff"]["ref_cross"].update(outputs["l1_diff_map_ref_cross"]) 323 | 324 | # logger vis batch 325 | if self.global_step % self.cfg.logger.vis_imgs_every_n_train_steps == 0: 326 | fig = self.visualiser.vis(batch, outputs) 327 | self.logger.experiment.log({"train_batch": fig}) 328 | 329 | # logger vis X batches statics 330 | if self.global_step % self.cfg.logger.vis_scalar_every_n_train_steps == 0: 331 | # log loss 332 | tmp_loss = self.train_cache["loss"]["final"].compute() 333 | self.log("train/loss", tmp_loss, prog_bar=True) 334 | 335 | if self.cfg.model.do_reference_cross: 336 | tmp_loss_cross = self.train_cache["loss"]["ref_cross"].compute() 337 | self.log("train/loss_cross", tmp_loss_cross) 338 | 339 | # log psnr 340 | if self.cfg.model.do_reference_cross: 341 | self.log("train/psnr_cross", self.to_psnr_fn(tmp_loss_cross)) 342 | 343 | # log correlation 344 | if self.cfg.model.do_reference_cross: 345 | self.log( 346 | "train/correlation_cross", 347 | self.train_cache["correlation"]["ref_cross"].compute(), 348 | ) 349 | 350 | # logger vis X batches histogram 351 | if self.global_step % self.cfg.logger.vis_histogram_every_n_train_steps == 0: 352 | if self.cfg.model.do_reference_cross: 353 | self.logger.experiment.log( 354 | { 355 | "train/score_histogram_cross": wandb.Histogram( 356 | np_histogram=self.train_cache["map"]["score"]["ref_cross"].compute() 357 | ), 358 | "train/l1_diff_histogram_cross": wandb.Histogram( 359 | np_histogram=self.train_cache["map"]["l1_diff"]["ref_cross"].compute() 360 | ), 361 | } 362 | ) 363 | 364 | def on_validation_batch_end(self, outputs, batch, batch_idx): 365 | self.validation_cache["loss"]["final"].update(outputs["loss"]) 366 | 367 | if self.cfg.model.do_reference_cross: 368 | 
self.validation_cache["loss"]["ref_cross"].update(outputs["loss_cross"]) 369 | self.validation_cache["correlation"]["ref_cross"].update( 370 | outputs["score_map_ref_cross"], batch["query/score_map"] 371 | ) 372 | 373 | self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) 374 | 375 | if batch_idx < self.cfg.logger.cache_size.validation.n_fig: 376 | fig = self.visualiser.vis(batch, outputs) 377 | self.validation_cache["fig"]["batch"].update(fig) 378 | 379 | def on_test_batch_end(self, outputs, batch, batch_idx): 380 | results = {"test/loss": outputs["loss"]} 381 | 382 | if self.cfg.model.do_reference_cross: 383 | corr = correlation(outputs["score_map_ref_cross"], batch["query/score_map"]) 384 | psnr = self.to_psnr_fn(outputs["loss_cross"]) 385 | results["test/loss_cross"] = outputs["loss_cross"] 386 | results["test/corr_cross"] = corr 387 | results["test/psnr_cross"] = psnr 388 | 389 | self.log_dict( 390 | results, 391 | on_step=self.cfg.logger.test.on_step, 392 | sync_dist=self.cfg.logger.test.sync_dist, 393 | ) 394 | 395 | self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) 396 | 397 | # write image to vis 398 | if ( 399 | self.cfg.logger.test.write.config.vis_img_every_n_steps > 0 400 | and batch_idx % self.cfg.logger.test.write.config.vis_img_every_n_steps == 0 401 | ): 402 | fig = self.visualiser.vis(batch, outputs) 403 | fig.image.save( 404 | Path( 405 | self.cfg.logger.test.out_dir, 406 | "vis", 407 | f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", 408 | ) 409 | ) 410 | 411 | if self.cfg.logger.test.write.flag.batch: 412 | self.batch_writer.write_out( 413 | batch_input=batch, 414 | batch_output=outputs, 415 | local_rank=self.local_rank, 416 | batch_idx=batch_idx, 417 | ) 418 | 419 | def on_predict_batch_end(self, outputs, batch, batch_idx): 420 | self.frame_score_summariser.update(batch_input=batch, batch_output=outputs) 421 | 422 | # write image to vis 423 | if ( 424 | self.cfg.logger.predict.write.config.vis_img_every_n_steps > 0 425 | and batch_idx % self.cfg.logger.predict.write.config.vis_img_every_n_steps == 0 426 | ): 427 | fig = self.visualiser.vis(batch, outputs) 428 | fig.image.save( 429 | Path( 430 | self.cfg.logger.predict.out_dir, 431 | "vis", 432 | f"r{self.local_rank}_B{str(batch_idx).zfill(4)}_b{0}.png", 433 | ) 434 | ) 435 | 436 | if self.cfg.logger.predict.write.flag.batch: 437 | self.batch_writer.write_out( 438 | batch_input=batch, 439 | batch_output=outputs, 440 | local_rank=self.local_rank, 441 | batch_idx=batch_idx, 442 | ) 443 | 444 | @rank_zero_only 445 | def on_train_epoch_end(self): 446 | self._reset_logging_cache_train() 447 | 448 | def on_validation_epoch_end(self): 449 | sync_dist = True 450 | self.log( 451 | "validation/loss", 452 | self.validation_cache["loss"]["final"].compute(), 453 | prog_bar=True, 454 | sync_dist=sync_dist, 455 | ) 456 | self.logger.experiment.log( 457 | {"validation_batch": self.validation_cache["fig"]["batch"].compute()}, 458 | ) 459 | 460 | if self.cfg.model.do_reference_cross: 461 | self.log( 462 | "validation/loss_cross", 463 | self.validation_cache["loss"]["ref_cross"].compute(), 464 | sync_dist=sync_dist, 465 | ) 466 | self.log( 467 | "validation/correlation_cross", 468 | self.validation_cache["correlation"]["ref_cross"].compute(), 469 | sync_dist=sync_dist, 470 | ) 471 | self.log( 472 | "validation/psnr_cross", 473 | self.to_psnr_fn(self.validation_cache["loss"]["ref_cross"].compute()), 474 | sync_dist=sync_dist, 475 | ) 476 | 477 | 
self._reset_logging_cache_validation() 478 | self.frame_score_summariser.reset() 479 | 480 | def on_test_epoch_end(self): 481 | self.frame_score_summariser.summarise() 482 | 483 | def on_predict_epoch_end(self): 484 | self.frame_score_summariser.summarise() 485 | 486 | def configure_optimizers(self): 487 | # how to use configure_optimizers: 488 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers 489 | 490 | # we freeze backbone and we only pass parameters that requires grad to optimizer: 491 | # https://discuss.pytorch.org/t/how-to-train-a-part-of-a-network/8923 492 | # https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor 493 | # https://discuss.pytorch.org/t/for-freezing-certain-layers-why-do-i-need-a-two-step-process/175289/2 494 | parameters = [p for p in self.model.parameters() if p.requires_grad] 495 | optimizer = torch.optim.AdamW( 496 | params=parameters, 497 | lr=self.cfg.trainer.optimizer.lr, 498 | ) 499 | lr_scheduler = torch.optim.lr_scheduler.StepLR( 500 | optimizer, 501 | step_size=self.cfg.trainer.lr_scheduler.step_size, 502 | gamma=self.cfg.trainer.lr_scheduler.gamma, 503 | ) 504 | 505 | results = { 506 | "optimizer": optimizer, 507 | "lr_scheduler": { 508 | "scheduler": lr_scheduler, 509 | "interval": self.cfg.trainer.lr_scheduler.step_interval, 510 | "frequency": 1, 511 | }, 512 | } 513 | return results 514 | -------------------------------------------------------------------------------- /task/predict.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import sys 3 | from pathlib import Path 4 | 5 | sys.path.append(str(Path(__file__).parents[1])) 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torchvision.transforms import v2 as T 10 | import lightning 11 | from lightning.pytorch.strategies import DDPStrategy 12 | import hydra 13 | from omegaconf import DictConfig, open_dict 14 | 15 | from core import CrossScoreLightningModule 16 | from dataloading.dataset.simple_reference import SimpleReference 17 | from dataloading.transformation.crop import CropperFactory 18 | from utils.io.images import ImageNetMeanStd 19 | 20 | 21 | @hydra.main(version_base="1.3", config_path="../config", config_name="default_predict") 22 | def predict(cfg: DictConfig): 23 | lightning.seed_everything(cfg.lightning.seed, workers=True) 24 | NUM_GPUS = len(cfg.trainer.devices) 25 | 26 | # double check batch size with user 27 | if ( 28 | cfg.this_main.force_batch_size is False 29 | and cfg.data.loader.validation.batch_size > 8 30 | and cfg.this_main.crop_mode is None 31 | ): 32 | tmp = input( 33 | "Test full image resolution in a large batch size " 34 | f"{cfg.data.loader.validation.batch_size}. 
" 35 | "Press Enter to continue, or enter a new batch size: " 36 | ) 37 | if tmp == "": 38 | pass 39 | elif tmp.isdigit(): 40 | new_batch_size = int(tmp) 41 | with open_dict(cfg): 42 | cfg.data.loader.validation.batch_size = new_batch_size 43 | print(f"Set batch size to {new_batch_size}") 44 | else: 45 | raise ValueError("Invalid input") 46 | 47 | if cfg.trainer.ckpt_path_to_load is None: 48 | # no ckpt, use current timestamp as log dir 49 | now = datetime.now().strftime("%Y%m%d_%H%M%S.%f") 50 | log_dir = Path("log") / now 51 | test_dir_name = "predict_empty_ckpt" 52 | else: 53 | # has ckpt, use the log dir of the ckpt path 54 | log_dir = Path(cfg.trainer.ckpt_path_to_load).parents[1] 55 | test_dir_name = "predict" 56 | log_dir = log_dir / test_dir_name 57 | log_dir.mkdir(parents=True, exist_ok=True) 58 | 59 | # add logdir to cfg so lightning model can write to it 60 | with open_dict(cfg): 61 | if cfg.logger.predict.out_dir is None: 62 | now = datetime.now().strftime("%Y%m%d_%H%M%S.%f") 63 | cfg.logger.predict.out_dir = f"{log_dir}/{now}" 64 | if cfg.alias != "": 65 | cfg.logger.predict.out_dir += f"_{cfg.alias}" 66 | 67 | # init dataset 68 | img_norm_stat = ImageNetMeanStd() 69 | transforms = { 70 | "img": T.Normalize( 71 | mean=img_norm_stat.mean, 72 | std=img_norm_stat.std, 73 | ), 74 | } 75 | 76 | if cfg.this_main.crop_mode == "dataset_default": 77 | transforms["query_crop"] = CropperFactory( 78 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 79 | same_on_batch=True, 80 | deterministic=True, 81 | ) 82 | transforms["reference_crop"] = CropperFactory( 83 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 84 | same_on_batch=False, 85 | deterministic=True, 86 | ) 87 | 88 | if cfg.this_main.resize_short_side > 0: 89 | transforms["resize"] = T.Resize( 90 | cfg.this_main.resize_short_side, 91 | interpolation=T.InterpolationMode.BILINEAR, 92 | antialias=True, 93 | ) 94 | 95 | dataset = { 96 | "predict": SimpleReference( 97 | query_dir=cfg.data.dataset.query_dir, 98 | reference_dir=cfg.data.dataset.reference_dir, 99 | transforms=transforms, 100 | neighbour_config=cfg.data.neighbour_config, 101 | return_item_paths=True, 102 | zero_reference=cfg.data.dataset.zero_reference, 103 | ), 104 | } 105 | 106 | dataloader_predict = DataLoader( 107 | dataset["predict"], 108 | batch_size=cfg.data.loader.validation.batch_size, 109 | shuffle=False, 110 | num_workers=cfg.data.loader.validation.num_workers, 111 | pin_memory=True, 112 | persistent_workers=False, 113 | ) 114 | 115 | # lightning model 116 | model = CrossScoreLightningModule(cfg) 117 | 118 | # DDP strategy 119 | if NUM_GPUS > 1: 120 | strategy = DDPStrategy(find_unused_parameters=False, static_graph=True) 121 | use_distributed_sampler = True 122 | else: 123 | strategy = "auto" 124 | use_distributed_sampler = False 125 | 126 | # lightning trainer 127 | trainer = lightning.Trainer( 128 | accelerator=cfg.trainer.accelerator, 129 | devices=cfg.trainer.devices, 130 | precision=cfg.trainer.precision, 131 | strategy=strategy, 132 | use_distributed_sampler=use_distributed_sampler, 133 | limit_test_batches=cfg.trainer.limit_test_batches, 134 | logger=False, 135 | ) 136 | 137 | trainer.predict( 138 | model, 139 | dataloader_predict, 140 | ckpt_path=cfg.trainer.ckpt_path_to_load, 141 | ) 142 | 143 | 144 | if __name__ == "__main__": 145 | with torch.no_grad(): 146 | predict() 147 | -------------------------------------------------------------------------------- /task/test.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).parents[1])) 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from torchvision.transforms import v2 as T 9 | import lightning 10 | from lightning.pytorch.strategies import DDPStrategy 11 | from lightning.pytorch.loggers import CSVLogger 12 | import hydra 13 | from omegaconf import DictConfig, open_dict 14 | 15 | from core import CrossScoreLightningModule 16 | from dataloading.data_manager import get_dataset 17 | from dataloading.transformation.crop import CropperFactory 18 | from utils.io.images import ImageNetMeanStd 19 | 20 | 21 | @hydra.main(version_base="1.3", config_path="../config", config_name="default_test") 22 | def test(cfg: DictConfig): 23 | lightning.seed_everything(cfg.lightning.seed, workers=True) 24 | NUM_GPUS = len(cfg.trainer.devices) 25 | 26 | if ( 27 | cfg.this_main.force_batch_size is False 28 | and cfg.data.loader.validation.batch_size > 8 29 | and cfg.this_main.crop_mode in [None, "integer_patches"] 30 | ): 31 | tmp = input( 32 | "Test full image resolution in a large batch size " 33 | f"{cfg.data.loader.validation.batch_size}. " 34 | "Press Enter to continue, or enter a new batch size: " 35 | ) 36 | if tmp == "": 37 | pass 38 | elif tmp.isdigit(): 39 | new_batch_size = int(tmp) 40 | with open_dict(cfg): 41 | cfg.data.loader.validation.batch_size = new_batch_size 42 | print(f"Set batch size to {new_batch_size}") 43 | else: 44 | raise ValueError("Invalid input") 45 | 46 | if cfg.trainer.ckpt_path_to_load is None: 47 | # no ckpt, use current timestamp as log dir 48 | import datetime 49 | 50 | now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S.%f") 51 | log_dir = Path("log") / now 52 | test_dir_name = "test_empty_ckpt" 53 | else: 54 | # has ckpt, use the log dir of the ckpt path 55 | log_dir = Path(cfg.trainer.ckpt_path_to_load).parents[1] 56 | test_dir_name = "test" 57 | log_dir.mkdir(parents=True, exist_ok=True) 58 | 59 | # init csv logger 60 | csv_logger = CSVLogger(save_dir=log_dir, name=test_dir_name) 61 | 62 | # add logdir to cfg so lightning model can write to it 63 | with open_dict(cfg): 64 | if cfg.logger.test.out_dir is None: 65 | cfg.logger.test.out_dir = csv_logger.log_dir + f"_{cfg.alias}" # the version_* dir 66 | 67 | # init dataset 68 | img_norm_stat = ImageNetMeanStd() 69 | transforms = { 70 | "img": T.Normalize( 71 | mean=img_norm_stat.mean, 72 | std=img_norm_stat.std, 73 | ), 74 | } 75 | 76 | if cfg.this_main.crop_mode == "dataset_default": 77 | transforms["query_crop"] = CropperFactory( 78 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 79 | same_on_batch=True, 80 | deterministic=True, 81 | ) 82 | transforms["reference_crop"] = CropperFactory( 83 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 84 | same_on_batch=False, 85 | deterministic=True, 86 | ) 87 | elif cfg.this_main.crop_mode == "integer_patches": 88 | transforms["crop_integer_patches"] = "adaptive" 89 | 90 | if cfg.this_main.resize_short_side > 0: 91 | transforms["resize"] = T.Resize( 92 | cfg.this_main.resize_short_side, 93 | interpolation=T.InterpolationMode.BILINEAR, 94 | antialias=True, 95 | ) 96 | 97 | dataset = { 98 | "test": get_dataset(cfg, transforms, cfg.this_main.data_split, return_item_paths=True) 99 | } 100 | dataloader_test = DataLoader( 101 | dataset["test"], 102 | batch_size=cfg.data.loader.validation.batch_size, 103 | 
shuffle=cfg.data.loader.validation.shuffle, 104 | num_workers=cfg.data.loader.validation.num_workers, 105 | pin_memory=True, 106 | persistent_workers=False, 107 | ) 108 | 109 | # lightning model 110 | model = CrossScoreLightningModule(cfg) 111 | 112 | # DDP strategy 113 | if NUM_GPUS > 1: 114 | strategy = DDPStrategy(find_unused_parameters=False, static_graph=True) 115 | use_distributed_sampler = True 116 | else: 117 | strategy = "auto" 118 | use_distributed_sampler = False 119 | 120 | # lightning trainer 121 | trainer = lightning.Trainer( 122 | accelerator=cfg.trainer.accelerator, 123 | devices=cfg.trainer.devices, 124 | precision=cfg.trainer.precision, 125 | strategy=strategy, 126 | use_distributed_sampler=use_distributed_sampler, 127 | limit_test_batches=cfg.trainer.limit_test_batches, 128 | logger=csv_logger, 129 | ) 130 | 131 | trainer.test( 132 | model, 133 | dataloader_test, 134 | ckpt_path=cfg.trainer.ckpt_path_to_load, 135 | ) 136 | 137 | 138 | if __name__ == "__main__": 139 | with torch.no_grad(): 140 | test() 141 | -------------------------------------------------------------------------------- /task/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from datetime import timedelta 4 | 5 | sys.path.append(str(Path(__file__).parents[1])) 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torchvision.transforms import v2 as T 10 | import lightning 11 | from lightning.pytorch.loggers.wandb import WandbLogger 12 | from lightning.pytorch.strategies import DDPStrategy 13 | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor 14 | from lightning.pytorch.utilities import rank_zero_only 15 | import hydra 16 | from hydra.core.hydra_config import HydraConfig 17 | from omegaconf import DictConfig 18 | 19 | from core import CrossScoreLightningModule 20 | from dataloading.data_manager import get_dataset 21 | from dataloading.transformation.crop import CropperFactory 22 | from utils.io.images import ImageNetMeanStd 23 | from utils.check_config import ConfigChecker 24 | 25 | 26 | @hydra.main(version_base="1.3", config_path="../config", config_name="default") 27 | def train(cfg: DictConfig): 28 | lightning.seed_everything(cfg.lightning.seed, workers=True) 29 | NUM_GPUS = len(cfg.trainer.devices) 30 | ConfigChecker(cfg).check_train_val() 31 | 32 | # use hydra timestamp as log dir 33 | log_dir = Path(HydraConfig.get().runtime.output_dir) 34 | log_dir.mkdir(parents=True, exist_ok=True) 35 | 36 | # init wandb logger # trick 1 for ddp 37 | wandb_logger = WandbLogger( 38 | project=cfg.project.name, 39 | save_dir=log_dir, 40 | log_model=False, # let lightning's checkpoint callback handle it 41 | ) 42 | 43 | @rank_zero_only 44 | def set_wandb_exp_name(): # trick 2 for ddp 45 | wandb_exp_id = wandb_logger.experiment.id # trick 2 for ddp 46 | if type(wandb_exp_id) == str: # trick 2 for ddp 47 | if cfg.alias is None or cfg.alias == "": 48 | wandb_logger.experiment.name = wandb_exp_id 49 | else: 50 | wandb_logger.experiment.name = wandb_exp_id + "_" + cfg.alias 51 | 52 | # sync wandb run name with id, in lightning wandb's "run" is renamed to "experiment" 53 | set_wandb_exp_name() # trick 2 for ddp 54 | 55 | # init dataset 56 | img_norm_stat = ImageNetMeanStd() 57 | transforms = { 58 | "query_crop": CropperFactory( 59 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 60 | same_on_batch=True, 61 | deterministic=cfg.trainer.overfit_batches > 0, 62 | 
), 63 | "reference_crop": CropperFactory( 64 | output_size=(cfg.data.transforms.crop_size, cfg.data.transforms.crop_size), 65 | same_on_batch=False, 66 | deterministic=cfg.trainer.overfit_batches > 0, 67 | ), 68 | "img": T.Normalize( 69 | mean=img_norm_stat.mean, 70 | std=img_norm_stat.std, 71 | ), 72 | } 73 | 74 | if cfg.this_main.resize_short_side > 0: 75 | transforms["resize"] = T.Resize( 76 | cfg.this_main.resize_short_side, 77 | interpolation=T.InterpolationMode.BILINEAR, 78 | antialias=True, 79 | ) 80 | 81 | dataset = { 82 | "train": get_dataset(cfg, transforms, "train"), 83 | "test": get_dataset(cfg, transforms, "test", return_item_paths=True), 84 | } 85 | 86 | dataloader_train = DataLoader( 87 | dataset["train"], 88 | batch_size=cfg.data.loader.train.batch_size, 89 | shuffle=cfg.data.loader.train.shuffle, 90 | num_workers=cfg.data.loader.train.num_workers, 91 | pin_memory=cfg.data.loader.train.pin_memory, 92 | persistent_workers=cfg.data.loader.train.persistent_workers, 93 | prefetch_factor=cfg.data.loader.train.prefetch_factor, 94 | ) 95 | dataloader_test = DataLoader( 96 | dataset["test"], 97 | batch_size=cfg.data.loader.validation.batch_size, 98 | shuffle=cfg.data.loader.validation.shuffle, 99 | num_workers=cfg.data.loader.validation.num_workers, 100 | pin_memory=cfg.data.loader.validation.pin_memory, 101 | persistent_workers=cfg.data.loader.validation.persistent_workers, 102 | prefetch_factor=cfg.data.loader.validation.prefetch_factor, 103 | ) 104 | 105 | # lightning model 106 | model = CrossScoreLightningModule(cfg) 107 | 108 | # DDP strategy 109 | if NUM_GPUS > 1: 110 | strategy = DDPStrategy(find_unused_parameters=False, static_graph=True) 111 | use_distributed_sampler = True 112 | else: 113 | strategy = "auto" 114 | use_distributed_sampler = False 115 | 116 | # checkpoint callback 117 | ckpt_dir = log_dir / "ckpt" 118 | if cfg.trainer.checkpointing.train_time_interval is not None: 119 | train_time_interval = timedelta(hours=float(cfg.trainer.checkpointing.train_time_interval)) 120 | else: 121 | train_time_interval = None 122 | checkpoint_callback = ModelCheckpoint( 123 | dirpath=ckpt_dir, 124 | every_n_epochs=cfg.trainer.checkpointing.every_n_epochs, 125 | every_n_train_steps=cfg.trainer.checkpointing.every_n_train_steps, 126 | train_time_interval=train_time_interval, 127 | save_last=cfg.trainer.checkpointing.save_last, 128 | save_top_k=cfg.trainer.checkpointing.save_top_k, 129 | ) 130 | 131 | # learning rate monitor 132 | lr_monitor = LearningRateMonitor() 133 | 134 | if cfg.trainer.do_profiling: 135 | from lightning.pytorch.profilers import PyTorchProfiler 136 | from datetime import datetime 137 | 138 | profiler = PyTorchProfiler( 139 | dirpath=log_dir / "profiler", 140 | filename=datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 141 | schedule=torch.profiler.schedule(wait=10, warmup=2, active=10, repeat=1), 142 | ) 143 | else: 144 | profiler = None 145 | 146 | # lightning trainer 147 | trainer = lightning.Trainer( 148 | accelerator=cfg.trainer.accelerator, 149 | devices=cfg.trainer.devices, 150 | precision=cfg.trainer.precision, 151 | max_epochs=cfg.trainer.max_epochs, 152 | max_steps=cfg.trainer.max_steps, 153 | strategy=strategy, 154 | use_distributed_sampler=use_distributed_sampler, 155 | limit_train_batches=cfg.trainer.limit_train_batches, 156 | limit_val_batches=cfg.trainer.limit_val_batches, 157 | num_sanity_val_steps=cfg.trainer.num_sanity_val_steps, 158 | overfit_batches=cfg.trainer.overfit_batches, 159 | logger=wandb_logger, 160 | 
log_every_n_steps=cfg.trainer.log_every_n_steps, 161 | callbacks=[checkpoint_callback, lr_monitor], 162 | profiler=profiler, 163 | ) 164 | 165 | trainer.fit( 166 | model, 167 | dataloader_train, 168 | dataloader_test, 169 | ckpt_path=cfg.trainer.ckpt_path_to_load, 170 | ) 171 | 172 | 173 | if __name__ == "__main__": 174 | train() 175 | -------------------------------------------------------------------------------- /utils/check_config.py: -------------------------------------------------------------------------------- 1 | def check_metric_prediction_config( 2 | metric_type, 3 | metric_min, 4 | metric_max, 5 | ): 6 | valid_max = False 7 | valid_min = False 8 | valid_type = False 9 | 10 | if metric_type in ["ssim", "mse", "mae"]: 11 | valid_type = True 12 | 13 | if metric_max == 1: 14 | valid_max = True 15 | 16 | if metric_type == "ssim": 17 | if metric_min in [-1, 0]: 18 | valid_min = True 19 | elif metric_type in ["mse", "mae"]: 20 | if metric_min == 0: 21 | valid_min = True 22 | 23 | if not valid_type: 24 | raise ValueError(f"Invalid metric type {metric_type}") 25 | 26 | valid_range = valid_min and valid_max 27 | if not valid_range: 28 | raise ValueError(f"Invalid metric range {metric_min} to {metric_max} for {metric_type}") 29 | 30 | 31 | def check_reference_type(do_reference_cross): 32 | if do_reference_cross: 33 | ref_type = "cross" 34 | else: 35 | raise ValueError("Reference type must be 'cross'") 36 | return ref_type 37 | 38 | 39 | class ConfigChecker: 40 | """ 41 | Check if a config object is valid for 42 | - train/val/test/predict steps that correspond to the lightning module; 43 | - dataloader creation. 44 | """ 45 | 46 | def __init__(self, cfg): 47 | self.cfg = cfg 48 | 49 | def _check_common_lightning(self): 50 | check_reference_type(self.cfg.model.do_reference_cross) 51 | check_metric_prediction_config( 52 | self.cfg.model.predict.metric.type, 53 | self.cfg.model.predict.metric.min, 54 | self.cfg.model.predict.metric.max, 55 | ) 56 | 57 | def check_train_val(self): 58 | self._check_common_lightning() 59 | 60 | def check_test(self): 61 | self._check_common_lightning() 62 | 63 | def check_predict(self): 64 | self._check_common_lightning() 65 | 66 | def check_dataset(self): 67 | check_metric_prediction_config( 68 | self.cfg.model.predict.metric.type, 69 | self.cfg.model.predict.metric.min, 70 | self.cfg.model.predict.metric.max, 71 | ) 72 | -------------------------------------------------------------------------------- /utils/data_processing/split_gaussian_processed.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from pathlib import Path 5 | from pprint import pprint 6 | import numpy as np 7 | 8 | 9 | def split_list_by_ratio(list_input, ratio_dict): 10 | # Check if the sum of ratios is close to 1 11 | if not 0.999 < sum(ratio_dict.values()) < 1.001: 12 | raise ValueError("The sum of the ratios must be close to 1") 13 | 14 | total_length = len(list_input) 15 | lengths = {k: int(v * total_length) for k, v in ratio_dict.items()} 16 | 17 | # Adjust the last split to include any rounding difference 18 | last_split_name = list(ratio_dict.keys())[-1] 19 | lengths[last_split_name] = total_length - sum(lengths.values()) + lengths[last_split_name] 20 | 21 | # Split the list 22 | split_lists = {} 23 | start = 0 24 | for split_name, length in lengths.items(): 25 | split_lists[split_name] = list_input[start : start + length] 26 | start += length 27 | 28 | split_lists = {k: v.tolist() for k, v in 
split_lists.items()} 29 | return split_lists 30 | 31 | 32 | if __name__ == "__main__": 33 | np.random.seed(1234) 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--data_path", 37 | type=str, 38 | default="~/projects/mview/storage/scratch_dataset/gaussian-splatting-processed/RealEstate200/res_540", 39 | ) 40 | parser.add_argument("--min_seq_len", type=int, default=2, help="keep seq with len(imgs) >= 2") 41 | parser.add_argument("--min_psnr", type=float, default=10.0, help="keep seq with psnr >= 10.0") 42 | parser.add_argument( 43 | "--split_ratio", 44 | nargs="+", 45 | type=float, 46 | default=[0.8, 0.1, 0.1], 47 | help="train/val/test split ratio", 48 | ) 49 | args = parser.parse_args() 50 | 51 | data_path = Path(args.data_path).expanduser() 52 | log_files = sorted([f for f in os.listdir(data_path) if f.endswith(".log")]) 53 | 54 | # get low psnr scenes from log files 55 | scene_all = [] 56 | scene_low_psnr = {} 57 | for log_f in log_files: 58 | with open(data_path / log_f, "r") as f: 59 | lines = f.readlines() 60 | for line in lines: 61 | # assume scene name printed a few lines before PSNR 62 | if "Output folder" in line: 63 | scene_name = line.split("Output folder: ")[1].split("/")[-1] 64 | scene_name = scene_name.removesuffix("\n") 65 | elif "[ITER 7000] Evaluating train" in line: 66 | psnr = line.split("PSNR ")[1] 67 | psnr = psnr.removesuffix("\n") 68 | psnr = float(psnr) 69 | 70 | scene_all.append(scene_name) 71 | if psnr < args.min_psnr: 72 | scene_low_psnr[scene_name] = psnr 73 | else: 74 | pass 75 | 76 | # get low seq length scenes from data folders 77 | gaussian_splits = ["train", "test"] 78 | scene_low_length = {} 79 | for scene_name in scene_all: 80 | for gs_split in gaussian_splits: 81 | tmp_dir = data_path / scene_name / gs_split / "ours_1000" / "gt" 82 | num_img = len(os.listdir(tmp_dir)) 83 | if num_img < args.min_seq_len: 84 | scene_low_length[scene_name] = num_img 85 | 86 | num_scene_total_after_gaussian = len(scene_all) 87 | num_scene_low_psnr = len(scene_low_psnr) 88 | num_scene_low_length = len(scene_low_length) 89 | 90 | # filter out low psnr scenes 91 | scene_all = [s for s in scene_all if s not in scene_low_psnr.keys()] 92 | num_scene_total_filtered_low_psnr = len(scene_all) 93 | 94 | # filter out low seq length scenes 95 | scene_all = [s for s in scene_all if s not in scene_low_length.keys()] 96 | num_scene_total_filtered_low_length = len(scene_all) 97 | 98 | # split train/val/test 99 | scene_all = np.random.permutation(scene_all) 100 | num_scene_after_all_filtering = len(scene_all) 101 | ratio = { 102 | "train": args.split_ratio[0], 103 | "val": args.split_ratio[1], 104 | "test": args.split_ratio[2], 105 | } 106 | scene_split_info = split_list_by_ratio(scene_all, ratio) 107 | num_scene_train = len(scene_split_info["train"]) 108 | num_scene_val = len(scene_split_info["val"]) 109 | num_scene_test = len(scene_split_info["test"]) 110 | num_scene_after_split = sum([len(v) for v in scene_split_info.values()]) 111 | assert num_scene_after_split == num_scene_after_all_filtering 112 | 113 | # save to json 114 | stats = { 115 | "min_psnr": args.min_psnr, 116 | "min_seq_len": args.min_seq_len, 117 | "split_ratio": args.split_ratio, 118 | "num_scene_total_after_gaussian": num_scene_total_after_gaussian, 119 | "num_scene_low_psnr": num_scene_low_psnr, 120 | "num_scene_low_length": num_scene_low_length, 121 | "num_scene_total_filtered_low_psnr": num_scene_total_filtered_low_psnr, 122 | "num_scene_total_filtered_low_length": 
num_scene_total_filtered_low_length, 123 | "num_scene_after_all_filtering": num_scene_after_all_filtering, 124 | "num_scene_train": num_scene_train, 125 | "num_scene_val": num_scene_val, 126 | "num_scene_test": num_scene_test, 127 | "num_scene_after_split": num_scene_after_split, 128 | } 129 | 130 | pprint(stats, sort_dicts=False) 131 | out_dict = {"stats": stats, **scene_split_info} 132 | 133 | with open(data_path / "split.json", "w") as f: 134 | json.dump(out_dict, f, indent=2) 135 | -------------------------------------------------------------------------------- /utils/evaluation/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def psnr(a, b, return_map=False): 6 | mse_map = torch.nn.functional.mse_loss(a, b, reduction="none") 7 | psnr_map = -10 * torch.log10(mse_map) 8 | if return_map: 9 | return psnr_map 10 | else: 11 | return psnr_map.mean() 12 | 13 | 14 | def mse2psnr(a): 15 | return -10 * torch.log10(a) 16 | 17 | 18 | def abs2psnr(a): 19 | return -10 * torch.log10(a.pow(2)) 20 | 21 | 22 | def psnr2mse(a): 23 | return 10 ** (-a / 10) 24 | 25 | 26 | def correlation(a, b): 27 | x = torch.stack([a.flatten(), b.flatten()], dim=0) # (2, N) 28 | corr = x.corrcoef() # (2, 2) 29 | corr = corr[0, 1] # only this one is meaningful 30 | return corr 31 | -------------------------------------------------------------------------------- /utils/evaluation/metric_logger.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import numpy as np 4 | from .metric import correlation 5 | 6 | 7 | class MetricLogger(ABC): 8 | def __init__(self, max_length): 9 | self.storage = [] 10 | self.max_length = max_length 11 | 12 | @torch.no_grad() 13 | def update(self, x): 14 | if self.max_length is not None and len(self) >= self.max_length: 15 | self.reset() 16 | self.storage.append(x) 17 | 18 | def reset(self): 19 | self.storage.clear() 20 | 21 | def __len__(self): 22 | return len(self.storage) 23 | 24 | @abstractmethod 25 | def compute(self): 26 | raise NotImplementedError 27 | 28 | 29 | class MetricLoggerScalar(MetricLogger): 30 | @torch.no_grad() 31 | def compute(self, aggregation_fn=torch.mean): 32 | tmp = torch.stack(self.storage) 33 | result = aggregation_fn(tmp) 34 | return result 35 | 36 | 37 | class MetricLoggerHistogram(MetricLogger): 38 | @torch.no_grad() 39 | def compute(self, bins=10, range=None): 40 | tmp = torch.cat(self.storage).cpu().numpy() 41 | result = np.histogram(tmp, bins=bins, range=range) 42 | return result 43 | 44 | 45 | class MetricLoggerCorrelation(MetricLoggerScalar): 46 | @torch.no_grad() 47 | def update(self, a, b): 48 | corr = correlation(a, b) 49 | super().update(corr) 50 | 51 | 52 | class MetricLoggerImg(MetricLogger): 53 | @torch.no_grad() 54 | def compute(self): 55 | return self.storage 56 | -------------------------------------------------------------------------------- /utils/evaluation/summarise_score_gt.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import sys 3 | from pathlib import Path 4 | 5 | sys.path.append(str(Path(__file__).parents[2])) 6 | from utils.io.score_summariser import SummaryWriterGroundTruth 7 | 8 | 9 | def parse_args(): 10 | parser = ArgumentParser(description="Summarise the ground truth results.") 11 | parser.add_argument( 12 | "--dir_in", 13 | type=str, 14 | 
default="datadir/processed_training_ready/gaussian/map-free-reloc/res_540", 15 | help="The ground truth data dir that contains scene dirs.", 16 | ) 17 | parser.add_argument( 18 | "--dir_out", 19 | type=str, 20 | default="~/projects/mview/storage/scratch_dataset/score_summary", 21 | help="The output directory to save the summarised results.", 22 | ) 23 | parser.add_argument( 24 | "--fast_debug", 25 | type=int, 26 | default=-1, 27 | help="num batch to load for debug. Set to -1 to disable", 28 | ) 29 | parser.add_argument("-n", "--num_workers", type=int, default=16) 30 | parser.add_argument("-f", "--force", type=eval, default=False, choices=[True, False]) 31 | return parser.parse_args() 32 | 33 | 34 | if __name__ == "__main__": 35 | args = parse_args() 36 | summariser = SummaryWriterGroundTruth( 37 | dir_in=args.dir_in, 38 | dir_out=args.dir_out, 39 | num_workers=args.num_workers, 40 | fast_debug=args.fast_debug, 41 | force=args.force, 42 | ) 43 | summariser.write_csv() 44 | -------------------------------------------------------------------------------- /utils/io/batch_writer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from PIL import Image 4 | import numpy as np 5 | from utils.io.images import metric_map_write, u8 6 | from utils.misc.image import gray2rgb, attn2rgb, de_norm_img 7 | 8 | 9 | def get_vrange(predict_metric_type, predict_metric_min, predict_metric_max): 10 | # Using uint16 when write in gray, normalise to the intrinsic range 11 | # based on score type, regardless of the model prediction range 12 | if predict_metric_type == "ssim": 13 | vrange_intrinsic = [-1, 1] 14 | elif predict_metric_type in ["mse", "mae"]: 15 | vrange_intrinsic = [0, 1] 16 | else: 17 | raise ValueError(f"metric_type {predict_metric_type} not supported") 18 | 19 | # RGB for visualization only, normalise to the model prediction range 20 | vrange_vis = [predict_metric_min, predict_metric_max] 21 | return vrange_intrinsic, vrange_vis 22 | 23 | 24 | class BatchWriter: 25 | """Write batch outputs to disk.""" 26 | 27 | def __init__(self, cfg, phase: str, img_mean_std): 28 | if phase not in ["test", "predict"]: 29 | raise ValueError( 30 | f"Phase {phase} not supported. Has to be a Lightening phase test/predict." 
31 | ) 32 | self.cfg = cfg 33 | self.phase = phase 34 | self.img_mean_std = img_mean_std 35 | 36 | self.out_dir = Path(self.cfg.logger[phase].out_dir) 37 | self.write_config = self.cfg.logger[phase].write.config 38 | self.write_flag = self.cfg.logger[phase].write.flag 39 | 40 | # overwrite the flag for attn_weights if the model does not have it 41 | self.write_flag.attn_weights = ( 42 | self.write_flag.attn_weights and self.cfg.model.need_attn_weights 43 | ) 44 | 45 | self.predict_metric_type = self.cfg.model.predict.metric.type 46 | self.predict_metric_min = self.cfg.model.predict.metric.min 47 | self.predict_metric_max = self.cfg.model.predict.metric.max 48 | 49 | self.vrange_intrinsic, self.vrange_vis = get_vrange( 50 | self.predict_metric_type, self.predict_metric_min, self.predict_metric_max 51 | ) 52 | 53 | # prepare out_dirs, create them if required 54 | self.out_dir_dict = {"batch": Path(self.out_dir, "batch")} 55 | if self.write_flag["batch"]: 56 | for k in self.write_flag.keys(): 57 | if k not in ["batch", "score_map_prediction"]: 58 | if self.write_flag[k]: 59 | self.out_dir_dict[k] = Path(self.out_dir_dict["batch"], k) 60 | self.out_dir_dict[k].mkdir(parents=True, exist_ok=True) 61 | 62 | def write_out(self, batch_input, batch_output, local_rank, batch_idx): 63 | if self.write_flag["score_map_prediction"]: 64 | self._write_score_map_prediction( 65 | self.out_dir_dict["batch"], 66 | batch_input, 67 | batch_output, 68 | local_rank, 69 | batch_idx, 70 | ) 71 | 72 | if self.write_flag["score_map_gt"]: 73 | self._write_score_map_gt( 74 | self.out_dir_dict["score_map_gt"], 75 | batch_input, 76 | local_rank, 77 | batch_idx, 78 | ) 79 | 80 | if self.write_flag["item_path_json"]: 81 | self._write_item_path_json( 82 | self.out_dir_dict["item_path_json"], 83 | batch_input, 84 | local_rank, 85 | batch_idx, 86 | ) 87 | 88 | if self.write_flag["image_query"]: 89 | self._write_query_image( 90 | self.out_dir_dict["image_query"], 91 | batch_input, 92 | local_rank, 93 | batch_idx, 94 | ) 95 | 96 | if self.write_flag["image_reference"]: 97 | self._write_reference_image( 98 | self.out_dir_dict["image_reference"], 99 | batch_input, 100 | local_rank, 101 | batch_idx, 102 | ) 103 | 104 | if self.write_flag["attn_weights"]: 105 | self._write_attn_weights( 106 | self.out_dir_dict["attn_weights"], 107 | batch_input, 108 | batch_output, 109 | local_rank, 110 | batch_idx, 111 | check_patch_mode="centre", 112 | ) 113 | 114 | def _write_score_map_prediction( 115 | self, out_dir, batch_input, batch_output, local_rank, batch_idx 116 | ): 117 | query_img_paths = [ 118 | str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") 119 | for p in batch_input["item_paths"]["query/img"] 120 | ] 121 | score_map_type_list = [k for k in batch_output.keys() if k.startswith("score_map")] 122 | for score_map_type in score_map_type_list: 123 | tmp_out_dir = Path(out_dir, score_map_type) 124 | tmp_out_dir.mkdir(parents=True, exist_ok=True) 125 | 126 | if len(query_img_paths) != len(batch_output[score_map_type]): 127 | raise ValueError("num of query images and score maps are not equal") 128 | 129 | for b, (query_img_p, score_map) in enumerate( 130 | zip(query_img_paths, batch_output[score_map_type]) 131 | ): 132 | tmp_out_name = f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}.png" 133 | self._write_a_score_map_with_colour_mode( 134 | out_path=tmp_out_dir / tmp_out_name, score_map=score_map.cpu().numpy() 135 | ) 136 | 137 | def _write_score_map_gt(self, out_dir, batch_input, local_rank, batch_idx): 138 | 
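        # Mirrors _write_score_map_prediction above, but takes the ground-truth map
        # from batch_input["query/score_map"]; output files follow the same
        # r{rank}_B{batch_idx}_b{sample_idx}_{query_image}.png naming scheme.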
query_img_paths = [ 139 | str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") 140 | for p in batch_input["item_paths"]["query/img"] 141 | ] 142 | 143 | if len(query_img_paths) != len(batch_input["query/score_map"]): 144 | raise ValueError("num of query images and score maps are not equal") 145 | 146 | for b, (query_img_p, score_map) in enumerate( 147 | zip(query_img_paths, batch_input["query/score_map"]) 148 | ): 149 | tmp_out_name = f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}.png" 150 | self._write_a_score_map_with_colour_mode( 151 | out_path=out_dir / tmp_out_name, score_map=score_map.cpu().numpy() 152 | ) 153 | 154 | def _write_item_path_json(self, out_dir, batch_input, local_rank, batch_idx): 155 | out_path = out_dir / f"r{local_rank}_B{str(batch_idx).zfill(4)}.json" 156 | 157 | # get a deep copy to avoid infering other writing functions 158 | item_paths = batch_input["item_paths"].copy() 159 | for ref_type in ["reference/cross/imgs"]: 160 | if len(item_paths[ref_type]) > 0: 161 | # transpose ref paths to (N_ref, B) 162 | item_paths[ref_type] = np.array(item_paths[ref_type]).T.tolist() 163 | 164 | with open(out_path, "w") as f: 165 | json.dump(item_paths, f, indent=2) 166 | 167 | def _write_query_image(self, out_dir, batch_input, local_rank, batch_idx): 168 | query_img_paths = [ 169 | str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") 170 | for p in batch_input["item_paths"]["query/img"] 171 | ] 172 | 173 | for b, (query_img_p, image_query) in enumerate( 174 | zip(query_img_paths, batch_input["query/img"]) 175 | ): 176 | tmp_out_name = f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}.png" 177 | tmp_out_path = Path(out_dir, tmp_out_name) 178 | image_query = de_norm_img(image_query.permute(1, 2, 0), self.img_mean_std) 179 | image_query = u8(image_query.cpu().numpy()) 180 | Image.fromarray(image_query).save(tmp_out_path) 181 | 182 | def _write_reference_image(self, out_dir, batch_input, local_rank, batch_idx): 183 | query_img_paths = [ 184 | str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") 185 | for p in batch_input["item_paths"]["query/img"] 186 | ] 187 | for ref_type in ["reference/cross/imgs"]: 188 | if len(batch_input["item_paths"][ref_type]) > 0: 189 | ref_img_paths = np.array(batch_input["item_paths"][ref_type]).T # (B, N_ref) 190 | 191 | # create a subfolder for each query image to store its ref images 192 | for b, query_img_p in enumerate(query_img_paths): 193 | tmp_out_dir = Path( 194 | out_dir, 195 | f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}", 196 | ref_type.split("/")[1], 197 | ) 198 | tmp_out_dir.mkdir(parents=True, exist_ok=True) 199 | tmp_ref_img_paths = [ 200 | str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") 201 | for p in ref_img_paths[b] # (N_ref, ) 202 | ] 203 | ref_imgs = batch_input[ref_type][b] # (N_ref, C, H, W) 204 | for ref_idx, (ref_img_p, ref_img) in enumerate( 205 | zip(tmp_ref_img_paths, ref_imgs) 206 | ): 207 | tmp_out_name = f"ref{ref_idx:02}_{ref_img_p}.png" 208 | tmp_out_path = Path(tmp_out_dir, tmp_out_name) 209 | ref_img = de_norm_img(ref_img.permute(1, 2, 0), self.img_mean_std) 210 | ref_img = u8(ref_img.cpu().numpy()) 211 | Image.fromarray(ref_img).save(tmp_out_path) 212 | 213 | def _write_attn_weights( 214 | self, out_dir, batch_input, batch_output, local_rank, batch_idx, check_patch_mode 215 | ): 216 | query_img_paths = [ 217 | str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") 218 | for p in batch_input["item_paths"]["query/img"] 
219 | ] 220 | for ref_type in ["reference/cross/imgs"]: 221 | if len(batch_input["item_paths"][ref_type]) > 0: 222 | ref_img_paths = np.array(batch_input["item_paths"][ref_type]).T # (B, N_ref) 223 | ref_type_short = ref_type.split("/")[1] 224 | 225 | # create a subfolder for each query image to store its ref images 226 | for b, query_img_p in enumerate(query_img_paths): 227 | tmp_out_dir = Path( 228 | out_dir, 229 | f"r{local_rank}_B{batch_idx:04}_b{b:03}_{query_img_p}", 230 | ref_type_short, 231 | ) 232 | tmp_out_dir.mkdir(parents=True, exist_ok=True) 233 | tmp_ref_img_paths = [ 234 | str(Path(*Path(p).parts[-5:])).replace("/", "_").replace(".png", "") 235 | for p in ref_img_paths[b] # (N_ref, ) 236 | ] 237 | 238 | # get attn maps of the centre patch in the query image 239 | # (B, H, W, N_ref, H, W) 240 | attn_weights_map = batch_output[f"attn_weights_map_ref_{ref_type_short}"] 241 | tmp_h, tmp_w = attn_weights_map.shape[1:3] 242 | 243 | if check_patch_mode == "centre": 244 | query_patch = (tmp_h // 2, tmp_w // 2) 245 | elif check_patch_mode == "random": 246 | query_patch = ( 247 | np.random.randint(0, tmp_h), 248 | np.random.randint(0, tmp_w), 249 | ) 250 | else: 251 | raise ValueError(f"Unknown check_patch_mode: {check_patch_mode}") 252 | attn_weights_map = attn_weights_map[b][query_patch] # (N_ref, H, W) 253 | 254 | # write attn maps 255 | for ref_idx, (ref_img_p, attn_m) in enumerate( 256 | zip(tmp_ref_img_paths, attn_weights_map) 257 | ): 258 | tmp_out_name = f"ref{ref_idx:02}_{ref_img_p}.png" 259 | tmp_out_path = Path(tmp_out_dir, tmp_out_name) 260 | attn_m = attn2rgb(attn_m.cpu().numpy()) # (H, W, 3) 261 | Image.fromarray(attn_m).save(tmp_out_path) 262 | 263 | def _write_a_score_map_with_colour_mode(self, out_path, score_map): 264 | if self.write_config.score_map_colour_mode == "gray": 265 | metric_map_write(out_path, score_map, self.vrange_intrinsic) 266 | elif self.write_config.score_map_colour_mode == "rgb": 267 | rgb = gray2rgb(score_map, self.vrange_vis) 268 | Image.fromarray(rgb).save(out_path) 269 | else: 270 | raise ValueError(f"colour_mode {self.write_config.score_map_colour_mode} not supported") 271 | -------------------------------------------------------------------------------- /utils/io/images.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | from dataclasses import dataclass 4 | from PIL import Image 5 | from typing import List 6 | 7 | 8 | @dataclass 9 | class ImageNetMeanStd: 10 | mean = (0.485, 0.456, 0.406) 11 | std = (0.229, 0.224, 0.225) 12 | 13 | 14 | def f32(img): 15 | img = img.astype(np.float32) 16 | img = img / 255.0 17 | return img 18 | 19 | 20 | def u8(img): 21 | img = img * 255.0 22 | img = img.astype(np.uint8) 23 | return img 24 | 25 | 26 | def image_read(p): 27 | img = np.array(Image.open(p)) 28 | img = f32(img) 29 | return img 30 | 31 | 32 | def metric_map_read(p, vrange: List[int]): 33 | """Read metric maps and convert to float. 34 | Note: 35 | - when read/write int32 to png, it acutally reads/writes uint16 but looks like int32. 36 | - uint16 has range [0, 65535] 37 | """ 38 | m = np.array(Image.open(p)) # HW np.int32 39 | m = m.astype(np.float32) 40 | if vrange == [0, 1]: 41 | m = m / 65535 42 | elif vrange == [-1, 1]: 43 | m = m / 32767 - 1 44 | else: 45 | raise ValueError("Invalid range for metric map reading. 
Must be '[0,1]' or '[-1,1]'") 46 | return m # HW np.float32 47 | 48 | 49 | def metric_map_write(p, m, vrange: List[int]): 50 | """Convert float metric maps to integer and write to png. 51 | Note: 52 | - when read/write int32 to png, it acutally reads/writes uint16 but looks like int32. 53 | - uint16 has range [0, 65535] 54 | """ 55 | if vrange == [0, 1]: 56 | m = m * 65535 # [0,1] -> [0, 65535] 57 | elif vrange == [-1, 1]: 58 | m = (m + 1) * 32767 # [-1,1] -> [0, 2] -> [0, 65534] 59 | else: 60 | raise ValueError("Invalid range for metric map writing. Must be '[0,1]' or '[-1,1]'") 61 | m = m.astype(np.int32) 62 | # set compression level 0 for even faster writing 63 | imageio.imwrite(p, m) 64 | -------------------------------------------------------------------------------- /utils/io/score_summariser.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | from glob import glob 4 | 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | import numpy as np 8 | import pandas as pd 9 | from pandas import DataFrame 10 | from tqdm import tqdm 11 | 12 | from utils.io.images import metric_map_read 13 | from utils.evaluation.metric import mse2psnr 14 | 15 | 16 | class ScoreReader(Dataset): 17 | def __init__(self, score_map_dir_list): 18 | # get all paths to read 19 | read_score_types = ["ssim", "mae"] 20 | self.read_paths_all = {k: [] for k in read_score_types} 21 | for score_read_type in read_score_types: 22 | for score_map_dir in score_map_dir_list: 23 | tmp_dir = os.path.join(score_map_dir, score_read_type) 24 | tmp_paths = [os.path.join(tmp_dir, n) for n in sorted(os.listdir(tmp_dir))] 25 | self.read_paths_all[score_read_type].extend(tmp_paths) 26 | 27 | # (N_frames, 2) 28 | self.read_paths_all = np.stack([self.read_paths_all[k] for k in read_score_types], axis=1) 29 | 30 | def __len__(self): 31 | return len(self.read_paths_all) 32 | 33 | def __getitem__(self, idx): 34 | path_ssim, path_mae = self.read_paths_all[idx] 35 | ssim_map = metric_map_read(path_ssim, vrange=[-1, 1]) 36 | mae_map = metric_map_read(path_mae, vrange=[0, 1]) 37 | mse_map = np.square(mae_map) 38 | 39 | score_ssim_n11 = ssim_map.mean() 40 | score_ssim_01 = ssim_map.clip(0, 1).mean() 41 | score_mae = mae_map.mean() 42 | score_mse = mse_map.mean() 43 | score_psnr = mse2psnr(torch.tensor([score_mse])).numpy() 44 | 45 | results = { 46 | "ssim_-1_1": score_ssim_n11, 47 | "ssim_0_1": score_ssim_01, 48 | "mae": score_mae, 49 | "mse": score_mse, 50 | "psnr": score_psnr, 51 | "path_ssim": path_ssim, 52 | } 53 | return results 54 | 55 | 56 | class SummaryWriterGroundTruth: 57 | """ 58 | Load the ground truth results from disk and summarise the results. 
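    Assumes dir_in follows a .../<rendering_method>/<dataset_type>/<resolution>
    layout (e.g. .../gaussian/map-free-reloc/res_540) and writes one CSV per
    rendering method to <dir_out>/<dataset_type>/<rendering_method>.csv.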
59 | """ 60 | 61 | def __init__(self, dir_in, dir_out, num_workers, fast_debug, force): 62 | self.dir_in = Path(dir_in).expanduser() 63 | self.dir_out = Path(dir_out).expanduser() 64 | self.num_workers = num_workers 65 | self.fast_debug = fast_debug 66 | self.force = force 67 | 68 | self.dataset_type = self.dir_in.parent.name 69 | self.rendering_method = self.dir_in.parents[1].name 70 | self.csv_dir = self.dir_out / self.dataset_type 71 | self.csv_path = self.csv_dir / f"{self.rendering_method}.csv" 72 | self.csv_dir.mkdir(parents=True, exist_ok=True) 73 | self.rows = [] 74 | self.columns = [ 75 | "scene_name", 76 | "rendered_dir", 77 | "image_name", 78 | "gt_ssim_-1_1", 79 | "gt_ssim_0_1", 80 | "gt_mae", 81 | "gt_mse", 82 | "gt_psnr", 83 | ] 84 | 85 | def write_csv(self): 86 | write = self._check_write(self.csv_path, self.force) 87 | if write: 88 | rows = self._load_per_frame_score() 89 | df = DataFrame(data=rows, columns=self.columns) 90 | df.to_csv(self.csv_path, index=False, float_format="%.4f") 91 | 92 | def _check_write(self, csv_path, force): 93 | if csv_path.exists(): 94 | if force: 95 | csv_path.unlink() 96 | print(f"Write to csv {csv_path} (OVERWRITE)") 97 | write = True 98 | else: 99 | print(f"Write to csv {csv_path} (SKIP)") 100 | write = False 101 | else: 102 | print(f"Write to csv {csv_path} (NORMAL)") 103 | write = True 104 | return write 105 | 106 | def _load_per_frame_score(self): 107 | # use glob to find all dir named "metric_map" in the scene dirs 108 | score_map_dir_list = sorted(glob(str(self.dir_in / "**/metric_map"), recursive=True)) 109 | score_reader = ScoreReader(score_map_dir_list) 110 | score_loader = DataLoader( 111 | dataset=score_reader, 112 | batch_size=16, 113 | shuffle=False, 114 | num_workers=self.num_workers, 115 | ) 116 | 117 | # process score maps to csv rows, each row contains the following columns 118 | rows = [] 119 | for i, data in enumerate(tqdm(score_loader, desc=f"Loading gt scores", dynamic_ncols=True)): 120 | for j in range(len(data["path_ssim"])): 121 | path_ssim = data["path_ssim"][j] 122 | scene_name = path_ssim.split("/")[-6] 123 | rendered_dir = os.path.join(*path_ssim.split("/")[:-3]) 124 | image_name = path_ssim.split("/")[-1] 125 | image_name = image_name.replace("frame_", "") 126 | tmp_row = [ 127 | scene_name, 128 | rendered_dir, 129 | image_name, 130 | data["ssim_-1_1"][j].item(), 131 | data["ssim_0_1"][j].item(), 132 | data["mae"][j].item(), 133 | data["mse"][j].item(), 134 | data["psnr"][j].item(), 135 | ] 136 | rows.append(tmp_row) 137 | if self.fast_debug > 0 and i >= self.fast_debug: 138 | break 139 | return rows 140 | 141 | 142 | class SummaryWriterPredictedOnline: 143 | """ 144 | Used in at the fit/test/predict phase with Lightning Module. 
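    update() appends one row per query frame (scene_name, rendered_dir, image_name,
    predicted score); summarise() groups the rows by dataset type and rendering method.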
145 | """ 146 | 147 | def __init__(self, metric_type, metric_min): 148 | metric_type_str = self._get_metric_type_str(metric_type, metric_min) 149 | self.columns = [ 150 | "scene_name", 151 | "rendered_dir", 152 | "image_name", 153 | f"pred_{metric_type_str}", 154 | ] 155 | self.reset() 156 | 157 | def _get_metric_type_str(self, metric_type, metric_min): 158 | if metric_type == "ssim": 159 | if metric_min == -1: 160 | metric_str = f"{metric_type}_-1_1" 161 | elif metric_min == 0: 162 | metric_str = f"{metric_type}_0_1" 163 | else: 164 | metric_str = f"{metric_type}" 165 | return metric_str 166 | 167 | def reset(self): 168 | self.rows = DataFrame(columns=self.columns) 169 | 170 | def update(self, batch_input, batch_output): 171 | """Store per frame scores to rows for each batch.""" 172 | query_img_paths = batch_input["item_paths"]["query/img"] # (B,) 173 | ref_types = [t for t in batch_output.keys() if t.startswith("score_map")] 174 | 175 | if len(ref_types) != 1: 176 | raise ValueError(f"Expect exactly one ref_type: self/cross, but got {ref_types}.") 177 | 178 | rows_batch = [] 179 | for ref_type in ref_types: 180 | score_maps = batch_output[ref_type] # (B, H, W) 181 | scores = score_maps.mean(dim=[-1, -2]) # (B,) 182 | 183 | scene_names = [p.split("/")[-5] for p in query_img_paths] 184 | 185 | rendered_dirs = [os.path.join(*p.split("/")[:-2]) for p in query_img_paths] 186 | image_names = [p.split("/")[-1] for p in query_img_paths] 187 | image_names = [n.replace("frame_", "") for n in image_names] 188 | 189 | for i in range(len(scene_names)): 190 | rows_batch.append( 191 | [scene_names[i], rendered_dirs[i], image_names[i], scores[i].item()] 192 | ) 193 | 194 | # concat rows to panda dataframe 195 | self.rows = pd.concat([self.rows, DataFrame(rows_batch, columns=self.columns)]) 196 | 197 | def summarise(self): 198 | """Organise rows using dataset type and rendering method and sort them.""" 199 | 200 | # get unique rendering methods 201 | rendering_method_list = self.rows["rendered_dir"].apply(lambda x: x.split("/")[-6]).unique() 202 | 203 | # get unique dataset types 204 | dataset_type_list = self.rows["rendered_dir"].apply(lambda x: x.split("/")[-5]).unique() 205 | 206 | # organise rows using dataset type and rendering method 207 | self.summary = {} 208 | for dataset_type in dataset_type_list: 209 | self.summary[dataset_type] = {} 210 | for rendering_method in rendering_method_list: 211 | # get rows with the same dataset type and rendering method 212 | tmp_rows = self.rows[ 213 | (self.rows["rendered_dir"].str.contains(rendering_method)) 214 | & (self.rows["rendered_dir"].str.contains(dataset_type)) 215 | ] 216 | 217 | # sort rows by scene_name, rendered_dir, image_name 218 | tmp_rows = tmp_rows.sort_values(by=["scene_name", "rendered_dir", "image_name"]) 219 | 220 | self.summary[dataset_type][rendering_method] = tmp_rows 221 | 222 | def __len__(self): 223 | return len(self.rows) 224 | 225 | def __repr__(self): 226 | return self.rows.__repr__() 227 | 228 | 229 | class SummaryWriterPredictedOnlineTestPrediction(SummaryWriterPredictedOnline): 230 | """ 231 | Used in at the end of validation/test/predict phase. 232 | Summarising with online predicted results to avoid reading from disks. 
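    summarise() additionally writes one CSV per (dataset type, rendering method)
    pair under <dir_out>/score_summary/.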
233 | """ 234 | 235 | def __init__(self, metric_type, metric_min, dir_out): 236 | super().__init__(metric_type, metric_min) 237 | self.csv_dir = Path(dir_out).expanduser() / "score_summary" 238 | self.csv_dir.mkdir(parents=True, exist_ok=True) 239 | self.cache_csv_path = self.csv_dir / f"summarise_cache.csv" 240 | 241 | def summarise(self): 242 | super().summarise() 243 | 244 | # write to csv files 245 | for dataset_type, dataset_summary in self.summary.items(): 246 | for rendering_method, rows in dataset_summary.items(): 247 | tmp_csv_dir = self.csv_dir / dataset_type 248 | tmp_csv_dir.mkdir(parents=True, exist_ok=True) 249 | tmp_csv_path = tmp_csv_dir / f"{rendering_method}.csv" 250 | rows.to_csv(tmp_csv_path, index=False, float_format="%.4f") 251 | 252 | 253 | class SummaryReader: 254 | @staticmethod 255 | def read_summary(summary_dir, dataset, method_list, scene_list, split_list, iter_list): 256 | summary_dir = Path(summary_dir).expanduser() 257 | summary_dir = summary_dir / dataset 258 | 259 | methods_available = [f.stem for f in summary_dir.iterdir() if f.is_file()] 260 | methods_to_read = [] 261 | if method_list != [""]: 262 | for m in method_list: 263 | if m in methods_available: 264 | methods_to_read.append(m) 265 | else: 266 | raise ValueError(f"{m} is not available in {summary_dir}") 267 | else: 268 | methods_to_read = methods_available 269 | 270 | summary_files = [summary_dir / f"{m}.csv" for m in methods_to_read] 271 | 272 | # read csv files, create a new colume as 0th column for method_name 273 | summary = pd.concat( 274 | [pd.read_csv(f).assign(method_name=m) for f, m in zip(summary_files, methods_to_read)] 275 | ) 276 | 277 | # filter with scene list using summary's scene_name column 278 | if scene_list != [""]: 279 | summary = summary[summary["scene_name"].isin(scene_list)] 280 | 281 | # filter split using rendered_dir column 282 | if split_list != [""]: 283 | new_s = [] 284 | for split in split_list: 285 | tmp_s = summary[summary["rendered_dir"].str.split("/").str[-2] == split] 286 | new_s.append(tmp_s) 287 | summary = pd.concat(new_s) 288 | 289 | # filter iter using rendered_dir column, the last part of the path should be EXACTLY "ours_{iter}" 290 | if len(iter_list) > 0: 291 | new_s = [] 292 | for i in iter_list: 293 | tmp_s = summary[summary["rendered_dir"].str.endswith(f"ours_{i}")] 294 | new_s.append(tmp_s) 295 | summary = pd.concat(new_s) 296 | 297 | # sort by scene_name, rendered_dir, image_name, method_name 298 | summary = summary.sort_values(["scene_name", "rendered_dir", "image_name", "method_name"]) 299 | 300 | # reset index 301 | summary = summary.reset_index(drop=True) 302 | return summary 303 | 304 | @staticmethod 305 | def check_summary_gt_prediction_rows(summary_gt, summary_prediction): 306 | # these tow dataframes should have the same length 307 | # and the columns [rendered_dir, image_name] should be identical 308 | if len(summary_gt) != len(summary_prediction): 309 | raise ValueError("Summary GT and prediction have different length") 310 | 311 | if not summary_gt["rendered_dir"].equals(summary_prediction["rendered_dir"]): 312 | raise ValueError("Summary GT and prediction have different rendered_dir") 313 | 314 | if not summary_gt["image_name"].equals(summary_prediction["image_name"]): 315 | raise ValueError("Summary GT and prediction have different image_name") 316 | -------------------------------------------------------------------------------- /utils/misc/image.py: -------------------------------------------------------------------------------- 1 | 
import matplotlib.pyplot as plt 2 | import numpy as np 3 | import matplotlib.cm as cm 4 | from PIL import Image, ImageDraw, ImageFont 5 | from utils.io.images import u8 6 | 7 | 8 | def jigsaw_to_image(x, grid_size): 9 | """ 10 | :param x: (B, N_patch_h * N_patch_w, patch_size_h, patch_size_w) 11 | :param grid_size: a tuple: (N_patch_h, N_patch_w) 12 | :return: (B, H, W) 13 | """ 14 | batch_size, num_patches, jigsaw_h, jigsaw_w = x.size() 15 | assert num_patches == grid_size[0] * grid_size[1] 16 | x_image = x.view(batch_size, grid_size[0], grid_size[1], jigsaw_h, jigsaw_w) 17 | output_h = grid_size[0] * jigsaw_h 18 | output_w = grid_size[1] * jigsaw_w 19 | x_image = x_image.permute(0, 1, 3, 2, 4).contiguous() 20 | x_image = x_image.view(batch_size, output_h, output_w) 21 | return x_image 22 | 23 | 24 | def de_norm_img(img, mean_std): 25 | """De-normalize images that are normalized by mean and std in ImageNet-style. 26 | :param img: (H, W, 3) 27 | :param mean_std: (6, ) 28 | """ 29 | mean, std = mean_std[:3], mean_std[3:] 30 | img = img * std[None, None] 31 | img = img + mean[None, None] 32 | return img 33 | 34 | 35 | def gray2rgb(img, vrange, cmap="turbo"): 36 | """ 37 | Args: 38 | img: HW, numpy.float32 39 | vrange: (min, max), float 40 | cmap: str 41 | """ 42 | vmin, vmax = vrange 43 | norm_op = plt.Normalize(vmin=vmin, vmax=vmax) 44 | colormap = cm.get_cmap(cmap) 45 | 46 | img = norm_op(img) 47 | img = colormap(img) 48 | rgb_image = u8(img[:, :, :3]) 49 | return rgb_image 50 | 51 | 52 | def attn2rgb(attn_map, cmap="turbo"): 53 | """Visualise attention map in rgb. 54 | The attn_map is softmaxed so we need to use log to make it more visible. 55 | Args: 56 | attn_map: HW, numpy.float32 57 | cmap: str 58 | """ 59 | eps = 1e-8 # to avoid log(0) 60 | attn_map = attn_map.clip(0, 1) 61 | attn_map = attn_map + eps # (1e-8, 1 + 1e-8) 62 | attn_map = attn_map.clip(0, 1) # (1e-8, 1) 63 | # invert softmax (exp'd) attn weights 64 | attn_map = np.log(attn_map) # (np.log(eps), 0) 65 | attn_map = attn_map - np.log(eps) # (0, -np.log(eps)) 66 | 67 | # some norm_op and colormap 68 | norm_op = plt.Normalize(vmin=0, vmax=-np.log(eps)) 69 | colormap = cm.get_cmap(cmap) 70 | attn_map = norm_op(attn_map) 71 | attn_map = colormap(attn_map) 72 | rgb_image = u8(attn_map[:, :, :3]) 73 | return rgb_image 74 | 75 | 76 | def img_add_text( 77 | img_rgb, 78 | text, 79 | text_position=(20, 20), 80 | text_colour=(255, 255, 255), 81 | font_size=50, 82 | font_path="/usr/share/fonts/truetype/dejavu/DejaVuSansMono-Bold.ttf", 83 | ): 84 | img = Image.fromarray(img_rgb) 85 | font = ImageFont.truetype(font_path, font_size) 86 | draw = ImageDraw.Draw(img) 87 | draw.text(text_position, text, text_colour, font=font) 88 | img = np.array(img) 89 | return img 90 | -------------------------------------------------------------------------------- /utils/neighbour/sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import numpy as np 3 | 4 | 5 | class SampleBase(ABC): 6 | def __init__(self, N_sample): 7 | self.N_sample = N_sample 8 | 9 | @abstractmethod 10 | def sample(self): 11 | pass 12 | 13 | 14 | class SamplerRandom(SampleBase): 15 | def __init__(self, N_sample, deterministic): 16 | self.deterministic = deterministic 17 | super().__init__(N_sample) 18 | 19 | def sample(self, query, ref_list): 20 | num_ref = len(ref_list) 21 | if self.N_sample > num_ref: 22 | # pad empty_image placeholders if ref list < N_sample 23 | num_empty = self.N_sample - num_ref 
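            # e.g. N_sample=4 with ref_list=["a", "b"] pads two "empty_image"
            # placeholders, then shuffles the padded list as a whole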
24 | placeholder = ["empty_image"] * num_empty 25 | result = ref_list + placeholder 26 | result = np.random.permutation(result).tolist() 27 | else: 28 | result = [] 29 | 30 | if self.deterministic: 31 | samples = ref_list[: self.N_sample] 32 | else: 33 | samples = np.random.choice(ref_list, self.N_sample, replace=False).tolist() 34 | result.extend(samples) 35 | return result 36 | 37 | 38 | class SamplerFactory: 39 | def __init__( 40 | self, 41 | strategy_name, 42 | N_sample, 43 | deterministic, 44 | **kwargs, 45 | ): 46 | self.N_sample = N_sample 47 | self.deterministic = deterministic 48 | 49 | if strategy_name == "random": 50 | self.sampler = SamplerRandom( 51 | N_sample=self.N_sample, 52 | deterministic=self.deterministic, 53 | ) 54 | else: 55 | raise NotImplementedError 56 | 57 | def __call__(self, query, ref_list): 58 | return self.sampler.sample(query, ref_list) 59 | -------------------------------------------------------------------------------- /utils/plot/batch_visualiser.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | import torch 5 | import wandb 6 | import matplotlib.pyplot as plt 7 | from matplotlib.patches import Rectangle 8 | from torchvision.utils import make_grid 9 | from PIL import Image 10 | from utils.io.images import u8 11 | from utils.misc.image import de_norm_img 12 | from utils.check_config import check_reference_type 13 | 14 | 15 | class BatchVisualiserBase(ABC): 16 | """Organise batch data and preview in a single image.""" 17 | 18 | def __init__(self, cfg, img_mean_std, ref_type): 19 | self.cfg = cfg 20 | self.img_mean_std = img_mean_std.cpu().numpy() 21 | self.ref_type = ref_type 22 | self.metric_type = cfg.model.predict.metric.type 23 | self.metric_min = cfg.model.predict.metric.min 24 | self.metric_max = cfg.model.predict.metric.max 25 | 26 | @abstractmethod 27 | def vis(self, **kwargs): 28 | raise NotImplementedError 29 | 30 | 31 | class BatchVisualiserRef(BatchVisualiserBase): 32 | @torch.no_grad() 33 | def vis(self, batch_input, batch_output, vis_id=0): 34 | ref_type = self.ref_type 35 | 36 | fig = plt.figure(figsize=(6, 6.5), layout="constrained") 37 | # plt.style.use("dark_background") 38 | # fig.set_facecolor("gray") 39 | 40 | # arrange subfigure 41 | N_ref_imgs = batch_input[f"reference/{ref_type}/imgs"].shape[1] 42 | h_ratios_subfig = [3, np.ceil(N_ref_imgs / 5)] 43 | fig_top, fig_bottom = fig.subfigures(2, 1, height_ratios=h_ratios_subfig) 44 | 45 | # arrange subplot for top fig 46 | inner = [ 47 | ["score_map/gt", f"score_map/ref_{ref_type}"], 48 | ] 49 | 50 | mosaic_top = [ 51 | ["query", inner], 52 | ] 53 | width_ratio_top = [3, 2] 54 | ax_dict_top = fig_top.subplot_mosaic( 55 | mosaic_top, 56 | empty_sentinel=None, 57 | width_ratios=width_ratio_top, 58 | ) 59 | 60 | if "item_paths" in batch_input.keys(): 61 | # 1. anonymize the path 62 | # 2. split to two rows 63 | # 3. 
add query path to title 64 | query_path = batch_input["item_paths"]["query/img"][vis_id] 65 | query_path = query_path.replace(str(Path("~").expanduser()), "~/") 66 | query_path = query_path.split("/") 67 | query_path.insert(len(query_path) // 2, "\n") 68 | query_path = Path(*query_path) 69 | fig_top.suptitle(f"{query_path}", fontsize=7) 70 | 71 | # arrange subplot for bottom fig 72 | mosaic_bottom = [[f"ref/{ref_type}/imgs"]] 73 | height_ratios = [1] 74 | ax_dict_bottom = fig_bottom.subplot_mosaic( 75 | mosaic_bottom, 76 | empty_sentinel=None, 77 | height_ratios=height_ratios, 78 | ) 79 | 80 | # getting batch input output data 81 | img_dict = {"query": batch_input["query/img"][vis_id]} # 3HW 82 | img_dict[f"ref/{ref_type}/imgs"] = make_grid( 83 | batch_input[f"reference/{ref_type}/imgs"][vis_id], nrow=5 84 | ) # 3HW 85 | img_dict["score_map/gt"] = batch_input[f"query/score_map"][vis_id] # HW 86 | img_dict[f"score_map/ref_{ref_type}"] = batch_output[f"score_map_ref_{ref_type}"][ 87 | vis_id 88 | ] # HW 89 | 90 | # plot top fig 91 | for k, v in img_dict.items(): 92 | if k not in ax_dict_top.keys(): 93 | continue 94 | if k.startswith("score_map"): 95 | tmp_img = v.cpu().numpy() 96 | ax_dict_top[k].imshow( 97 | tmp_img, 98 | vmin=self.metric_min, 99 | vmax=self.metric_max, 100 | cmap="turbo", 101 | ) 102 | title_txt = k.replace("score", self.metric_type) 103 | title_txt = title_txt.replace("_map", "").replace("/", " ").replace("_", " ") 104 | title_txt = title_txt + f"\n{tmp_img.mean():.3f}" 105 | if k == f"score_map/ref_{ref_type}": 106 | if f"l1_diff_map_ref_{ref_type}" in batch_output.keys(): 107 | loss = batch_output[f"l1_diff_map_ref_{ref_type}"][vis_id].mean().item() 108 | title_txt = title_txt + f" ∆{loss:.2f}" 109 | ax_dict_top[k].set_title(title_txt, fontsize=8) 110 | elif k.startswith("delta_map"): 111 | tmp_img = v.cpu().numpy() 112 | # ax_dict_top[k].imshow(tmp_img, vmin=-0.5, vmax=0.5, cmap="turbo") 113 | ax_dict_top[k].imshow(tmp_img, vmin=-0.5, vmax=0.5, cmap="bwr") 114 | ax_dict_top[k].set_title(k.replace("/", "\n"), fontsize=8) 115 | else: 116 | tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 117 | tmp_img = de_norm_img(tmp_img, self.img_mean_std) 118 | tmp_img = tmp_img.clip(0, 1) 119 | tmp_img = u8(tmp_img) 120 | ax_dict_top[k].imshow(tmp_img) 121 | ax_dict_top[k].set_title(k) 122 | ax_dict_top[k].set_xticks([]) 123 | ax_dict_top[k].set_yticks([]) 124 | 125 | # add a colorbar, turn off ticks 126 | ax_colorbar = fig_top.colorbar( 127 | plt.cm.ScalarMappable(cmap="turbo"), ax=ax_dict_top["query"], fraction=0.046, pad=0.04 128 | ) 129 | ax_colorbar.ax.set_yticklabels([]) 130 | 131 | # plot bottom fig 132 | for k, v in img_dict.items(): 133 | if k not in ax_dict_bottom.keys(): 134 | continue 135 | if k.startswith("attn_weights_map"): 136 | # tmp_img = v.permute(1, 2, 0) # HW3 137 | tmp_img = v[0] # HW 138 | tmp_img = tmp_img.clamp(0, 1) 139 | tmp_img = tmp_img.cpu().numpy() 140 | ax_dict_bottom[k].imshow(tmp_img, vmin=0, cmap="turbo") 141 | ax_dict_bottom[k].set_title(k) 142 | else: 143 | tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 144 | tmp_img = de_norm_img(tmp_img, self.img_mean_std) 145 | tmp_img = tmp_img.clip(0, 1) 146 | tmp_img = u8(tmp_img) 147 | ax_dict_bottom[k].imshow(tmp_img) 148 | ax_dict_bottom[k].set_title(k) 149 | 150 | ax_dict_bottom[k].set_xticks([]) 151 | ax_dict_bottom[k].set_yticks([]) 152 | 153 | # output 154 | fig.canvas.draw() 155 | out_img = Image.frombytes( 156 | "RGBA", 157 | fig.canvas.get_width_height(), 158 | fig.canvas.buffer_rgba(), 159 | 
).convert("RGB") 160 | 161 | plt.close() 162 | out_img = wandb.Image(out_img, file_type="jpg") 163 | return out_img 164 | 165 | 166 | class BatchVisualiserRefAttnMap(BatchVisualiserBase): 167 | def __init__(self, cfg, img_mean_std, ref_type, check_patch_mode): 168 | super().__init__(cfg, img_mean_std, ref_type) 169 | self.check_patch_mode = check_patch_mode 170 | 171 | @torch.no_grad() 172 | def vis(self, batch_input, batch_output, vis_id=0): 173 | fig = plt.figure(figsize=(6, 6.5), layout="constrained") 174 | # plt.style.use("dark_background") 175 | fig.set_facecolor("gray") 176 | 177 | has_reference_cross = "reference/cross/imgs" in batch_input.keys() 178 | has_attn_weights_cross = batch_output.get("attn_weights_map_ref_cross", None) is not None 179 | P = patch_size = 14 180 | 181 | # arrange subfigure 182 | h_ratios_subfig = [2, 1] 183 | 184 | fig_top, fig_bottom = fig.subfigures(2, 1, height_ratios=h_ratios_subfig) 185 | 186 | # arrange subplot for top fig 187 | inner = [] 188 | if has_reference_cross: 189 | inner.append(["score_map/gt1", "score_map/ref_cross", "score_map/diff_cross"]) 190 | inner = np.array(inner).T.tolist() # vertical layout 191 | 192 | mosaic_top = [ 193 | ["query", inner], 194 | ] 195 | width_ratio_top = [4, 1] 196 | 197 | ax_dict_top = fig_top.subplot_mosaic( 198 | mosaic_top, 199 | empty_sentinel=None, 200 | width_ratios=width_ratio_top, 201 | ) 202 | 203 | # arrange subplot for bottom fig 204 | mosaic_bottom = [] 205 | if has_reference_cross: 206 | mosaic_bottom.append(["ref/cross/imgs"]) 207 | if has_attn_weights_cross: 208 | mosaic_bottom.append(["attn_weights_map_ref_cross"]) 209 | 210 | height_ratios = [1] * len(mosaic_bottom) 211 | 212 | ax_dict_bottom = fig_bottom.subplot_mosaic( 213 | mosaic_bottom, 214 | empty_sentinel=None, 215 | height_ratios=height_ratios, 216 | ) 217 | 218 | # getting batch input output data 219 | img_dict = {"query": batch_input["query/img"][vis_id]} # 3HW 220 | if has_reference_cross: 221 | img_dict["ref/cross/imgs"] = make_grid( 222 | batch_input["reference/cross/imgs"][vis_id], nrow=5 223 | ) # 3HW 224 | img_dict["score_map/gt1"] = batch_input["query/score_map"][vis_id] 225 | img_dict["score_map/ref_cross"] = batch_output["score_map_ref_cross"][vis_id] 226 | if "l1_diff_map_ref_cross" in batch_output.keys(): 227 | img_dict["score_map/diff_cross"] = batch_output["l1_diff_map_ref_cross"][vis_id] 228 | 229 | if has_attn_weights_cross: 230 | attn_weights_map_ref_cross = batch_output["attn_weights_map_ref_cross"] # BHWNHW 231 | tmp_h, tmp_w = attn_weights_map_ref_cross.shape[1:3] 232 | if self.check_patch_mode == "centre": 233 | query_patch_cross = (tmp_h // 2, tmp_w // 2) 234 | elif self.check_patch_mode == "random": 235 | query_patch_cross = (np.random.randint(0, tmp_h), np.random.randint(0, tmp_w)) 236 | else: 237 | raise ValueError(f"Unknown check_patch_mode: {self.check_patch_mode}") 238 | 239 | # NHW 240 | attn_weights_map_ref_cross = attn_weights_map_ref_cross[vis_id][query_patch_cross] 241 | img_dict["attn_weights_map_ref_cross"] = make_grid( 242 | attn_weights_map_ref_cross[:, None], nrow=5 # make grid expects N3HW 243 | ) 244 | 245 | # plot top fig 246 | for k, v in img_dict.items(): 247 | if k not in ax_dict_top.keys(): 248 | continue 249 | if k.startswith("score_map"): 250 | tmp_img = v.cpu().numpy() 251 | ax_dict_top[k].imshow(tmp_img, vmin=0, vmax=1, cmap="turbo") 252 | ax_dict_top[k].set_title(k.replace("/", "\n"), fontsize=8) 253 | else: 254 | tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 255 | tmp_img = 
de_norm_img(tmp_img, self.img_mean_std) 256 | tmp_img = tmp_img.clip(0, 1) 257 | tmp_img = u8(tmp_img) 258 | ax_dict_top[k].imshow(tmp_img) 259 | ax_dict_top[k].set_title(k) 260 | 261 | if has_attn_weights_cross and k == "query": 262 | # draw a rectangle patch 263 | query_pixel = (query_patch_cross[0] * P, query_patch_cross[1] * P) 264 | rect = Rectangle( 265 | (query_pixel[1], query_pixel[0]), 266 | P, 267 | P, 268 | linewidth=2, 269 | edgecolor="magenta", 270 | facecolor="none", 271 | ) 272 | ax_dict_top[k].add_patch(rect) 273 | ax_dict_top[k].set_xticks([]) 274 | ax_dict_top[k].set_yticks([]) 275 | 276 | # plot bottom fig 277 | for k, v in img_dict.items(): 278 | if k not in ax_dict_bottom.keys(): 279 | continue 280 | if k.startswith("attn_weights_map"): 281 | num_stabler = 1e-8 # to avoid log(0) 282 | tmp_img = v[0] # HW 283 | tmp_img = tmp_img.clamp(0, 1) 284 | tmp_img = tmp_img.cpu().numpy() 285 | # invert softmax (exp'd) attn weights 286 | tmp_img = np.log(tmp_img + num_stabler) - np.log(num_stabler) 287 | ax_dict_bottom[k].imshow(tmp_img, vmax=-np.log(num_stabler), cmap="turbo") 288 | ax_dict_bottom[k].set_title(k) 289 | else: 290 | tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 291 | tmp_img = de_norm_img(tmp_img, self.img_mean_std) 292 | tmp_img = tmp_img.clip(0, 1) 293 | tmp_img = u8(tmp_img) 294 | ax_dict_bottom[k].imshow(tmp_img) 295 | ax_dict_bottom[k].set_title(k) 296 | 297 | ax_dict_bottom[k].set_xticks([]) 298 | ax_dict_bottom[k].set_yticks([]) 299 | 300 | # output 301 | fig.canvas.draw() 302 | out_img = Image.frombytes( 303 | "RGBA", 304 | fig.canvas.get_width_height(), 305 | fig.canvas.buffer_rgba(), 306 | ).convert("RGB") 307 | 308 | plt.close() 309 | out_img = wandb.Image(out_img, file_type="jpg") 310 | return out_img 311 | 312 | 313 | class BatchVisualiserRefFree(BatchVisualiserBase): 314 | @torch.no_grad() 315 | def vis(self, batch_input, batch_output, vis_id=0): 316 | fig = plt.figure(figsize=(6.5, 5), layout="constrained") 317 | # plt.style.use("dark_background") 318 | # fig.set_facecolor("gray") 319 | 320 | ref_type = self.ref_type 321 | inner = [["score_map/gt", f"score_map/ref_{ref_type}", "score_map/diff"]] 322 | inner = np.array(inner).T.tolist() 323 | 324 | mosaic = [ 325 | ["query", inner], 326 | ] 327 | width_ratio = [5, 2] 328 | ax_dict = fig.subplot_mosaic( 329 | mosaic, 330 | empty_sentinel=None, 331 | width_ratios=width_ratio, 332 | ) 333 | 334 | if "item_paths" in batch_input.keys(): 335 | # 1. anonymize the path 336 | # 2. split to two rows 337 | # 3. 
add query path to title 338 | query_path = batch_input["item_paths"]["query/img"][vis_id] 339 | query_path = query_path.replace(str(Path("~").expanduser()), "~/") 340 | query_path = query_path.split("/") 341 | query_path.insert(len(query_path) // 2, "\n") 342 | query_path = Path(*query_path) 343 | fig.suptitle(f"{query_path}", fontsize=7) 344 | 345 | # getting batch input output data 346 | img_dict = {"query": batch_input["query/img"][vis_id]} # 3HW 347 | img_dict["score_map/gt"] = batch_input[f"query/score_map"][vis_id] # HW 348 | 349 | # plot top fig 350 | for k, v in img_dict.items(): 351 | if k not in ax_dict.keys(): 352 | continue 353 | if k.startswith("score_map"): 354 | tmp_img = v.cpu().numpy() 355 | ax_dict[k].imshow( 356 | tmp_img, 357 | vmin=self.metric_min, 358 | vmax=self.metric_max, 359 | cmap="turbo", 360 | ) 361 | title_txt = k.replace("score", self.metric_type) 362 | title_txt = title_txt.replace("_map", "").replace("/", " ").replace("_", " ") 363 | title_txt = title_txt + f"\n{tmp_img.mean():.3f}" 364 | if k == f"score_map/ref_{ref_type}": 365 | loss = batch_output[f"l1_diff_map_ref_{ref_type}"][vis_id].mean().item() 366 | title_txt = title_txt + f" ∆{loss:.2f}" 367 | ax_dict[k].set_title(title_txt, fontsize=8) 368 | else: 369 | tmp_img = v.permute(1, 2, 0).cpu().numpy() # HW3 370 | tmp_img = de_norm_img(tmp_img, self.img_mean_std) 371 | tmp_img = tmp_img.clip(0, 1) 372 | tmp_img = u8(tmp_img) 373 | ax_dict[k].imshow(tmp_img) 374 | ax_dict[k].set_title(k) 375 | ax_dict[k].set_xticks([]) 376 | ax_dict[k].set_yticks([]) 377 | 378 | # add a colorbar, turn off ticks 379 | ax_colorbar = fig.colorbar( 380 | plt.cm.ScalarMappable(cmap="turbo"), ax=ax_dict["query"], fraction=0.046, pad=0.04 381 | ) 382 | ax_colorbar.ax.set_yticklabels([]) 383 | 384 | # output 385 | fig.canvas.draw() 386 | out_img = Image.frombytes( 387 | "RGBA", 388 | fig.canvas.get_width_height(), 389 | fig.canvas.buffer_rgba(), 390 | ).convert("RGB") 391 | 392 | plt.close() 393 | out_img = wandb.Image(out_img, file_type="jpg") 394 | return out_img 395 | 396 | 397 | class BatchVisualiserFactory: 398 | def __init__(self, cfg, img_mean_std): 399 | self.cfg = cfg 400 | self.img_mean_std = img_mean_std 401 | self.ref_type = check_reference_type(self.cfg.model.do_reference_cross) 402 | 403 | if self.ref_type in ["cross"]: 404 | if self.cfg.model.need_attn_weights: 405 | self.visualiser = BatchVisualiserRefAttnMap( 406 | cfg, self.img_mean_std, self.ref_type, check_patch_mode="centre" 407 | ) 408 | else: 409 | self.visualiser = BatchVisualiserRef(cfg, self.img_mean_std, self.ref_type) 410 | else: 411 | raise NotImplementedError 412 | 413 | def __call__(self): 414 | return self.visualiser 415 | --------------------------------------------------------------------------------
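
A minimal wiring sketch for `BatchVisualiserFactory` above, assuming the experiment config (`cfg`) and the dataloader's normalisation statistics (`img_mean_std`) are already available. The function name, the wandb log key, and the import path are illustrative assumptions; the repository's actual call site (in the task/training code) is not shown here.

```python
# Hypothetical usage sketch, not the repository's own call site.
import wandb

from utils.plot.batch_visualiser import BatchVisualiserFactory  # assumed import path


def log_batch_visualisation(cfg, img_mean_std, batch_input, batch_output, step):
    # The factory selects BatchVisualiserRef or BatchVisualiserRefAttnMap based on
    # cfg.model.do_reference_cross and cfg.model.need_attn_weights (see above);
    # calling the factory instance returns the selected visualiser.
    visualiser = BatchVisualiserFactory(cfg, img_mean_std)()

    # batch_input / batch_output follow the keys used in this file, e.g.
    # "query/img", "reference/cross/imgs", "query/score_map", "score_map_ref_cross".
    out_img = visualiser.vis(batch_input, batch_output, vis_id=0)  # returns a wandb.Image

    wandb.log({"val/visualisation": out_img}, step=step)  # illustrative key
```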