├── License.txt ├── README.md ├── assets ├── acid_long.mp4 ├── acid_preview.gif ├── acid_short.mp4 ├── firstpage.jpg ├── geofree_variants.png ├── realestate_long.mp4 ├── realestate_preview.gif ├── realestate_short.mp4 └── rooms_scenic_01_wkr.jpg ├── configs ├── acid │ ├── acid_13x23_expl_emb.yaml │ ├── acid_13x23_expl_feat.yaml │ ├── acid_13x23_expl_img.yaml │ ├── acid_13x23_hybrid.yaml │ ├── acid_13x23_impl_catdepth.yaml │ ├── acid_13x23_impl_depth.yaml │ └── acid_13x23_impl_nodepth.yaml └── realestate │ ├── realestate_13x23_expl_emb.yaml │ ├── realestate_13x23_expl_feat.yaml │ ├── realestate_13x23_expl_img.yaml │ ├── realestate_13x23_hybrid.yaml │ ├── realestate_13x23_impl_catdepth.yaml │ ├── realestate_13x23_impl_depth.yaml │ └── realestate_13x23_impl_nodepth.yaml ├── data ├── acid_custom_frames.txt ├── acid_train_sequences.txt ├── realestate_custom_frames.txt └── realestate_train_sequences.txt ├── environment.yaml ├── geofree ├── __init__.py ├── data │ ├── __init__.py │ ├── acid.py │ ├── read_write_model.py │ └── realestate.py ├── examples │ ├── __init__.py │ ├── artist.jpg │ └── beach.jpg ├── lr_scheduler.py ├── main.py ├── models │ ├── __init__.py │ ├── transformers │ │ ├── __init__.py │ │ ├── geogpt.py │ │ ├── net2net.py │ │ └── warpgpt.py │ └── vqgan.py ├── modules │ ├── __init__.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ └── model.py │ ├── losses │ │ ├── __init__.py │ │ └── vqperceptual.py │ ├── transformer │ │ ├── __init__.py │ │ ├── mingpt.py │ │ └── warper.py │ ├── util.py │ ├── vqvae │ │ ├── __init__.py │ │ └── quantize.py │ └── warp │ │ ├── __init__.py │ │ └── midas.py └── util.py ├── scripts ├── braindance.ipynb ├── braindance.py ├── database.py ├── download_vqmodels.py ├── sparse_from_realestate_format.py ├── sparsify_acid.sh ├── sparsify_realestate.sh └── strip_ckpt.py └── setup.py /License.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Robin Rombach and Patrick Esser and Björn Ommer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 19 | OR OTHER DEALINGS IN THE SOFTWARE./ 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Geometry-Free View Synthesis: Transformers and no 3D Priors 2 | ![teaser](assets/firstpage.jpg) 3 | 4 | [**Geometry-Free View Synthesis: Transformers and no 3D Priors**](https://compvis.github.io/geometry-free-view-synthesis/)
5 | [Robin Rombach](https://github.com/rromb)\*, 6 | [Patrick Esser](https://github.com/pesser)\*, 7 | [Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
8 | \* equal contribution 9 | 10 | [arXiv](https://arxiv.org/abs/2104.07652) | [BibTeX](#bibtex) | [Colab](https://colab.research.google.com/github/CompVis/geometry-free-view-synthesis/blob/master/scripts/braindance.ipynb) 11 | 12 | ### Interactive Scene Exploration Results 13 | 14 | [RealEstate10K](https://google.github.io/realestate10k/):
15 | ![realestate](assets/realestate_preview.gif)
16 | Videos: [short (2min)](assets/realestate_short.mp4) / [long (12min)](assets/realestate_long.mp4) 17 | 18 | [ACID](https://infinite-nature.github.io/):
19 | ![acid](assets/acid_preview.gif)
20 | Videos: [short (2min)](assets/acid_short.mp4) / [long (9min)](assets/acid_long.mp4)
21 | 
22 | ### Demo
23 | 
24 | For a quickstart, you can try the [Colab
25 | demo](https://colab.research.google.com/github/CompVis/geometry-free-view-synthesis/blob/master/scripts/braindance.ipynb),
26 | but for a smoother experience we recommend installing the local demo as
27 | described below.
28 | 
29 | #### Installation
30 | 
31 | The demo requires building a PyTorch extension. If you have a sane development
32 | environment with PyTorch, g++ and nvcc, you can simply
33 | 
34 | ```
35 | pip install git+https://github.com/CompVis/geometry-free-view-synthesis#egg=geometry-free-view-synthesis
36 | ```
37 | 
38 | If you run into problems and have a GPU with compute capability below 8, you
39 | can also use the provided conda environment:
40 | 
41 | ```
42 | git clone https://github.com/CompVis/geometry-free-view-synthesis
43 | conda env create -f geometry-free-view-synthesis/environment.yaml
44 | conda activate geofree
45 | pip install geometry-free-view-synthesis/
46 | ```
47 | 
48 | #### Running
49 | 
50 | After [installation](#installation), running
51 | 
52 | ```
53 | braindance.py
54 | ```
55 | 
56 | will start the demo on [a sample scene](http://walledoffhotel.com/rooms.html).
57 | Explore the scene interactively using the `WASD` keys to move and `arrow keys` to
58 | look around. Once positioned, hit the `space bar` to render the novel view with
59 | GeoGPT.
60 | 
61 | You can move again with the `WASD` keys. Mouse control can be activated with the `m`
62 | key. Run `braindance.py <path to image or directory>` to run the
63 | demo on your own images. By default, it uses the `re_impl_nodepth` model (trained on
64 | RealEstate without explicit transformation and no depth input), which can be
65 | changed with the `--model` flag. The corresponding checkpoints will be
66 | downloaded the first time they are required. Specify an output path using
67 | `--video path/to/vid.mp4` to record a video.
68 | 
69 | ```
70 | > braindance.py -h
71 | usage: braindance.py [-h] [--model {re_impl_nodepth,re_impl_depth,ac_impl_nodepth,ac_impl_depth}] [--video [VIDEO]] [path]
72 | 
73 | What's up, BD-maniacs?
74 | 
75 | key(s)                action
76 | =====================================
77 | wasd                  move around
78 | arrows                look around
79 | m                     enable looking with mouse
80 | space                 render with transformer
81 | q                     quit
82 | 
83 | positional arguments:
84 |   path                  path to image or directory from which to select image. Default example is used if not specified.
85 | 
86 | optional arguments:
87 |   -h, --help            show this help message and exit
88 |   --model {re_impl_nodepth,re_impl_depth,ac_impl_nodepth,ac_impl_depth}
89 |                         pretrained model to use.
90 |   --video [VIDEO]       path to write video recording to. (no recording if unspecified).
91 | ```
92 | 
93 | ## Training
94 | 
95 | ### Data Preparation
96 | 
97 | We support training on [RealEstate10K](https://google.github.io/realestate10k/)
98 | and [ACID](https://infinite-nature.github.io/). Both come in the same [format as
99 | described here](https://google.github.io/realestate10k/download.html) and the
100 | preparation is the same for both of them. You will need to have
101 | [`colmap`](https://github.com/colmap/colmap) installed and available on your
102 | `$PATH`.
103 | 
104 | We assume that you have extracted the `.txt` files of the dataset you want to
105 | prepare into `$TXT_ROOT`, e.g. for RealEstate:
106 | 
107 | ```
108 | > tree $TXT_ROOT
109 | ├── test
110 | │   ├── 000c3ab189999a83.txt
111 | │   ├── ...
112 | │   └── fff9864727c42c80.txt
113 | └── train
114 |     ├── 0000cc6d8b108390.txt
115 |     ├── ...
116 |     └── ffffe622a4de5489.txt
117 | ```
118 | 
119 | and that you have downloaded the frames (we downloaded them in resolution `640
120 | x 360`) into `$IMG_ROOT`, e.g. for RealEstate:
121 | 
122 | ```
123 | > tree $IMG_ROOT
124 | ├── test
125 | │   ├── 000c3ab189999a83
126 | │   │   ├── 45979267.png
127 | │   │   ├── ...
128 | │   │   └── 55255200.png
129 | │   ├── ...
130 | │   ├── 0017ce4c6a39d122
131 | │   │   ├── 40874000.png
132 | │   │   ├── ...
133 | │   │   └── 48482000.png
134 | ├── train
135 | │   ├── ...
136 | ```
137 | 
138 | To prepare the `$SPLIT` split of the dataset (`$SPLIT` being one of `train`,
139 | `test` for RealEstate and `train`, `test`, `validation` for ACID) in
140 | `$SPA_ROOT`, run the following within the `scripts` directory:
141 | 
142 | ```
143 | python sparse_from_realestate_format.py --txt_src ${TXT_ROOT}/${SPLIT} --img_src ${IMG_ROOT}/${SPLIT} --spa_dst ${SPA_ROOT}/${SPLIT}
144 | ```
145 | 
146 | You can also simply set `TXT_ROOT`, `IMG_ROOT` and `SPA_ROOT` as environment
147 | variables and run `./sparsify_realestate.sh` or `./sparsify_acid.sh`. Take a
148 | look into the sources to run with multiple workers in parallel.
149 | 
150 | Finally, symlink `$SPA_ROOT` to `data/realestate_sparse`/`data/acid_sparse`.
151 | 
152 | ### First Stage Models
153 | As described in [our paper](https://arxiv.org/abs/2104.07652), we train the transformer models in
154 | a compressed, discrete latent space of pretrained VQGANs. These pretrained models can be conveniently
155 | downloaded by running
156 | ```
157 | python scripts/download_vqmodels.py
158 | ```
159 | which will also create symlinks ensuring that the paths specified in the training configs (see `configs/*`) exist.
160 | In case some of the models have already been downloaded, the script will only create the symlinks.
161 | 
162 | For training custom first stage models, we refer to the [taming transformers
163 | repository](https://github.com/CompVis/taming-transformers).
164 | 
165 | ### Running the Training
166 | After both the preparation of the data and the first stage models are done,
167 | the experiments on ACID and RealEstate10K as described in our paper can be reproduced by running
168 | ```
169 | python geofree/main.py --base configs/<dataset>/<dataset>_13x23_<experiment>.yaml -t --gpus 0,
170 | ```
171 | where `<dataset>` is one of `realestate`/`acid` and `<experiment>` is one of
172 | `expl_img`/`expl_feat`/`expl_emb`/`impl_catdepth`/`impl_depth`/`impl_nodepth`/`hybrid`.
173 | These abbreviations correspond to the experiments listed in the following table (see also Fig. 2 in the main paper).
174 | 
175 | ![variants](assets/geofree_variants.png)
176 | 
177 | Note that each experiment was conducted on a GPU with 40 GB VRAM.
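
As a concrete example, a full RealEstate10K `impl_nodepth` run might look like the following sketch, which only combines the commands from the sections above; the `/path/to/...` locations are placeholders, and launching the sparsify script from within `scripts/` (as for `sparse_from_realestate_format.py`) is an assumption:

```
# data preparation (see "Data Preparation" above)
export TXT_ROOT=/path/to/realestate/txt
export IMG_ROOT=/path/to/realestate/frames
export SPA_ROOT=/path/to/realestate/sparse
(cd scripts && ./sparsify_realestate.sh)
ln -s $SPA_ROOT data/realestate_sparse

# pretrained first stage (VQGAN) models and config symlinks
python scripts/download_vqmodels.py

# train the transformer
python geofree/main.py --base configs/realestate/realestate_13x23_impl_nodepth.yaml -t --gpus 0,
```

The `13x23` in the config names is the spatial size of the VQGAN latent grid, i.e. 13 × 23 = 299 tokens per image; this is where the `n_unmasked: 299` conditioning embeddings and block sizes such as `597 = 299 + 299 - 1` in the transformer configs come from. If less than 40 GB of VRAM is available, lowering `batch_size` and raising `accumulate_grad_batches` in the chosen config (see the comments under `data` and `lightning` in `configs/*`) is the corresponding adjustment.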
178 | 179 | ## BibTeX 180 | 181 | ``` 182 | @misc{rombach2021geometryfree, 183 | title={Geometry-Free View Synthesis: Transformers and no 3D Priors}, 184 | author={Robin Rombach and Patrick Esser and Björn Ommer}, 185 | year={2021}, 186 | eprint={2104.07652}, 187 | archivePrefix={arXiv}, 188 | primaryClass={cs.CV} 189 | } 190 | ``` 191 | -------------------------------------------------------------------------------- /assets/acid_long.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/acid_long.mp4 -------------------------------------------------------------------------------- /assets/acid_preview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/acid_preview.gif -------------------------------------------------------------------------------- /assets/acid_short.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/acid_short.mp4 -------------------------------------------------------------------------------- /assets/firstpage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/firstpage.jpg -------------------------------------------------------------------------------- /assets/geofree_variants.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/geofree_variants.png -------------------------------------------------------------------------------- /assets/realestate_long.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/realestate_long.mp4 -------------------------------------------------------------------------------- /assets/realestate_preview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/realestate_preview.gif -------------------------------------------------------------------------------- /assets/realestate_short.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/realestate_short.mp4 -------------------------------------------------------------------------------- /assets/rooms_scenic_01_wkr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/rooms_scenic_01_wkr.jpg -------------------------------------------------------------------------------- /configs/acid/acid_13x23_expl_emb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: 
geofree.models.transformers.warpgpt.WarpTransformer 4 | params: 5 | plot_cond_stage: True 6 | monitor: "val/loss" 7 | 8 | use_scheduler: True 9 | scheduler_config: 10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 11 | params: 12 | verbosity_interval: 0 # 0 or negative to disable 13 | warm_up_steps: 5000 14 | max_decay_steps: 500001 15 | lr_start: 2.5e-6 16 | lr_max: 1.5e-4 17 | lr_min: 1.0e-8 18 | 19 | transformer_config: 20 | target: geofree.modules.transformer.mingpt.WarpGPT 21 | params: 22 | vocab_size: 16384 23 | block_size: 597 # conditioning + 299 - 1 24 | n_unmasked: 299 # 299 cond embeddings 25 | n_layer: 32 26 | n_head: 16 27 | n_embd: 1024 28 | warper_config: 29 | target: geofree.modules.transformer.warper.ConvWarper 30 | params: 31 | size: [13, 23] 32 | 33 | first_stage_config: 34 | target: geofree.models.vqgan.VQModel 35 | params: 36 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt" 37 | embed_dim: 256 38 | n_embed: 16384 39 | ddconfig: 40 | double_z: False 41 | z_channels: 256 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 47 | num_res_blocks: 2 48 | attn_resolutions: [ 16 ] 49 | dropout: 0.0 50 | lossconfig: 51 | target: geofree.modules.losses.vqperceptual.DummyLoss 52 | 53 | cond_stage_config: "__is_first_stage__" 54 | 55 | data: 56 | target: geofree.main.DataModuleFromConfig 57 | params: 58 | # bs 8 and accumulate_grad_batches 2 for 34gb vram 59 | batch_size: 8 60 | num_workers: 16 61 | train: 62 | target: geofree.data.acid.ACIDSparseTrain 63 | params: 64 | size: 65 | - 208 66 | - 368 67 | 68 | validation: 69 | target: geofree.data.acid.ACIDCustomTest 70 | params: 71 | size: 72 | - 208 73 | - 368 74 | 75 | lightning: 76 | trainer: 77 | accumulate_grad_batches: 2 78 | benchmark: True 79 | -------------------------------------------------------------------------------- /configs/acid/acid_13x23_expl_feat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.net2net.WarpingFeatureTransformer 4 | params: 5 | plot_cond_stage: True 6 | monitor: "val/loss" 7 | 8 | use_scheduler: True 9 | scheduler_config: 10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 11 | params: 12 | verbosity_interval: 0 # 0 or negative to disable 13 | warm_up_steps: 5000 14 | max_decay_steps: 500001 15 | lr_start: 2.5e-6 16 | lr_max: 1.5e-4 17 | lr_min: 1.0e-8 18 | 19 | transformer_config: 20 | target: geofree.modules.transformer.mingpt.GPT 21 | params: 22 | vocab_size: 16384 23 | block_size: 597 # conditioning + 299 - 1 24 | n_unmasked: 299 # 299 cond embeddings 25 | n_layer: 32 26 | n_head: 16 27 | n_embd: 1024 28 | 29 | first_stage_key: 30 | x: "dst_img" 31 | 32 | cond_stage_key: 33 | c: "src_img" 34 | points: "src_points" 35 | R: "R_rel" 36 | t: "t_rel" 37 | K: "K" 38 | K_inv: "K_inv" 39 | 40 | first_stage_config: 41 | target: geofree.models.vqgan.VQModel 42 | params: 43 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt" 44 | embed_dim: 256 45 | n_embed: 16384 46 | ddconfig: 47 | double_z: False 48 | z_channels: 256 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 54 | num_res_blocks: 2 55 | attn_resolutions: [ 16 ] 56 | dropout: 0.0 57 | lossconfig: 58 | target: geofree.modules.losses.vqperceptual.DummyLoss 59 | 60 | cond_stage_config: "__is_first_stage__" 61 | 62 | data: 63 | target: 
geofree.main.DataModuleFromConfig 64 | params: 65 | # bs 8 and accumulate_grad_batches 2 for 34gb vram 66 | batch_size: 8 67 | num_workers: 16 68 | train: 69 | target: geofree.data.acid.ACIDSparseTrain 70 | params: 71 | size: 72 | - 208 73 | - 368 74 | 75 | validation: 76 | target: geofree.data.acid.ACIDCustomTest 77 | params: 78 | size: 79 | - 208 80 | - 368 81 | 82 | lightning: 83 | trainer: 84 | accumulate_grad_batches: 2 85 | benchmark: True 86 | -------------------------------------------------------------------------------- /configs/acid/acid_13x23_expl_img.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.net2net.WarpingTransformer 4 | params: 5 | plot_cond_stage: True 6 | monitor: "val/loss" 7 | 8 | use_scheduler: True 9 | scheduler_config: 10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 11 | params: 12 | verbosity_interval: 0 # 0 or negative to disable 13 | warm_up_steps: 5000 14 | max_decay_steps: 500001 15 | lr_start: 2.5e-6 16 | lr_max: 1.5e-4 17 | lr_min: 1.0e-8 18 | 19 | transformer_config: 20 | target: geofree.modules.transformer.mingpt.GPT 21 | params: 22 | vocab_size: 16384 23 | block_size: 597 # conditioning + 299 - 1 24 | n_unmasked: 299 # 299 cond embeddings 25 | n_layer: 32 26 | n_head: 16 27 | n_embd: 1024 28 | 29 | first_stage_config: 30 | target: geofree.models.vqgan.VQModel 31 | params: 32 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt" 33 | embed_dim: 256 34 | n_embed: 16384 35 | ddconfig: 36 | double_z: False 37 | z_channels: 256 38 | resolution: 256 39 | in_channels: 3 40 | out_ch: 3 41 | ch: 128 42 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 43 | num_res_blocks: 2 44 | attn_resolutions: [ 16 ] 45 | dropout: 0.0 46 | lossconfig: 47 | target: geofree.modules.losses.vqperceptual.DummyLoss 48 | 49 | data: 50 | target: geofree.main.DataModuleFromConfig 51 | params: 52 | # bs 8 and accumulate_grad_batches 2 for 34gb vram 53 | batch_size: 8 54 | num_workers: 16 55 | train: 56 | target: geofree.data.acid.ACIDSparseTrain 57 | params: 58 | size: 59 | - 208 60 | - 368 61 | 62 | validation: 63 | target: geofree.data.acid.ACIDCustomTest 64 | params: 65 | size: 66 | - 208 67 | - 368 68 | 69 | lightning: 70 | trainer: 71 | accumulate_grad_batches: 2 72 | benchmark: True 73 | -------------------------------------------------------------------------------- /configs/acid/acid_13x23_hybrid.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.geogpt.WarpGeoTransformer 4 | params: 5 | merge_channels: 512 # channels of cond vq + depth vq 6 | 7 | plot_cond_stage: True 8 | monitor: "val/loss" 9 | 10 | use_scheduler: True 11 | scheduler_config: 12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 # 0 or negative to disable 15 | warm_up_steps: 5000 16 | max_decay_steps: 500001 17 | lr_start: 2.5e-6 18 | lr_max: 1.5e-4 19 | lr_min: 1.0e-8 20 | 21 | transformer_config: 22 | target: geofree.modules.transformer.mingpt.WarpGPT 23 | params: 24 | vocab_size: 16384 25 | block_size: 627 # conditioning + 299 - 1 26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings 27 | n_layer: 32 28 | n_head: 16 29 | n_embd: 1024 30 | warper_config: 31 | target: geofree.modules.transformer.warper.ConvWarper 32 | params: 33 | start_idx: 30 # do not warp camera embeddings 
34 | size: [13, 23] 35 | 36 | first_stage_key: 37 | x: "dst_img" 38 | 39 | cond_stage_key: 40 | c: "src_img" 41 | 42 | emb_stage_key: 43 | points: "src_points" 44 | R: "R_rel" 45 | t: "t_rel" 46 | K: "K" 47 | K_inv: "K_inv" 48 | 49 | first_stage_config: 50 | target: geofree.models.vqgan.VQModel 51 | params: 52 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt" 53 | embed_dim: 256 54 | n_embed: 16384 55 | ddconfig: 56 | double_z: False 57 | z_channels: 256 58 | resolution: 256 59 | in_channels: 3 60 | out_ch: 3 61 | ch: 128 62 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 63 | num_res_blocks: 2 64 | attn_resolutions: [ 16 ] 65 | dropout: 0.0 66 | lossconfig: 67 | target: geofree.modules.losses.vqperceptual.DummyLoss 68 | 69 | cond_stage_config: "__is_first_stage__" 70 | 71 | depth_stage_config: 72 | target: geofree.models.vqgan.VQModel 73 | params: 74 | ckpt_path: "pretrained_models/acid_depth_stage/last.ckpt" 75 | embed_dim: 256 76 | n_embed: 1024 77 | ddconfig: 78 | double_z: false 79 | z_channels: 256 80 | resolution: 256 81 | in_channels: 1 82 | out_ch: 1 83 | ch: 128 84 | ch_mult: 85 | - 1 86 | - 1 87 | - 2 88 | - 2 89 | - 4 90 | num_res_blocks: 2 91 | attn_resolutions: 92 | - 16 93 | dropout: 0.0 94 | lossconfig: 95 | target: geofree.modules.losses.vqperceptual.DummyLoss 96 | 97 | emb_stage_config: 98 | target: geofree.modules.util.MultiEmbedder 99 | params: 100 | keys: 101 | - "R" 102 | - "t" 103 | - "K" 104 | - "K_inv" 105 | n_positions: 30 106 | n_channels: 1 107 | n_embed: 1024 108 | bias: False 109 | 110 | data: 111 | target: geofree.main.DataModuleFromConfig 112 | params: 113 | # bs 8 and accumulate_grad_batches 2 for 36gb vram 114 | batch_size: 8 115 | num_workers: 16 116 | train: 117 | target: geofree.data.acid.ACIDSparseTrain 118 | params: 119 | size: 120 | - 208 121 | - 368 122 | 123 | validation: 124 | target: geofree.data.acid.ACIDCustomTest 125 | params: 126 | size: 127 | - 208 128 | - 368 129 | 130 | lightning: 131 | trainer: 132 | accumulate_grad_batches: 2 133 | benchmark: True 134 | -------------------------------------------------------------------------------- /configs/acid/acid_13x23_impl_catdepth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.geogpt.GeoTransformer 4 | params: 5 | merge_channels: null 6 | 7 | plot_cond_stage: True 8 | monitor: "val/loss" 9 | 10 | use_scheduler: True 11 | scheduler_config: 12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 # 0 or negative to disable 15 | warm_up_steps: 5000 16 | max_decay_steps: 500001 17 | lr_start: 2.5e-6 18 | lr_max: 1.5e-4 19 | lr_min: 1.0e-8 20 | 21 | transformer_config: 22 | target: geofree.modules.transformer.mingpt.GPT 23 | params: 24 | vocab_size: 16384 25 | block_size: 926 # conditioning + 299 - 1 26 | n_unmasked: 628 # 30 camera embeddings + 299 depth and 299 cond embeddings 27 | n_layer: 32 28 | n_head: 16 29 | n_embd: 1024 30 | 31 | first_stage_key: 32 | x: "dst_img" 33 | 34 | cond_stage_key: 35 | c: "src_img" 36 | 37 | emb_stage_key: 38 | points: "src_points" 39 | R: "R_rel" 40 | t: "t_rel" 41 | K: "K" 42 | K_inv: "K_inv" 43 | 44 | first_stage_config: 45 | target: geofree.models.vqgan.VQModel 46 | params: 47 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt" 48 | embed_dim: 256 49 | n_embed: 16384 50 | ddconfig: 51 | double_z: False 52 | z_channels: 256 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 
56 | ch: 128 57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 58 | num_res_blocks: 2 59 | attn_resolutions: [ 16 ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: geofree.modules.losses.vqperceptual.DummyLoss 63 | 64 | cond_stage_config: "__is_first_stage__" 65 | 66 | depth_stage_config: 67 | target: geofree.models.vqgan.VQModel 68 | params: 69 | ckpt_path: "pretrained_models/acid_depth_stage/last.ckpt" 70 | embed_dim: 256 71 | n_embed: 1024 72 | ddconfig: 73 | double_z: false 74 | z_channels: 256 75 | resolution: 256 76 | in_channels: 1 77 | out_ch: 1 78 | ch: 128 79 | ch_mult: 80 | - 1 81 | - 1 82 | - 2 83 | - 2 84 | - 4 85 | num_res_blocks: 2 86 | attn_resolutions: 87 | - 16 88 | dropout: 0.0 89 | lossconfig: 90 | target: geofree.modules.losses.vqperceptual.DummyLoss 91 | 92 | emb_stage_config: 93 | target: geofree.modules.util.MultiEmbedder 94 | params: 95 | keys: 96 | - "R" 97 | - "t" 98 | - "K" 99 | - "K_inv" 100 | n_positions: 30 101 | n_channels: 1 102 | n_embed: 1024 103 | bias: False 104 | 105 | data: 106 | target: geofree.main.DataModuleFromConfig 107 | params: 108 | # bs 4 and accumulate_grad_batches 4 for 34gb vram 109 | batch_size: 4 110 | num_workers: 8 111 | train: 112 | target: geofree.data.acid.ACIDSparseTrain 113 | params: 114 | size: 115 | - 208 116 | - 368 117 | 118 | validation: 119 | target: geofree.data.acid.ACIDCustomTest 120 | params: 121 | size: 122 | - 208 123 | - 368 124 | 125 | lightning: 126 | trainer: 127 | accumulate_grad_batches: 4 128 | benchmark: True 129 | -------------------------------------------------------------------------------- /configs/acid/acid_13x23_impl_depth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.geogpt.GeoTransformer 4 | params: 5 | merge_channels: 512 # channels of cond vq + depth vq 6 | 7 | plot_cond_stage: True 8 | monitor: "val/loss" 9 | 10 | use_scheduler: True 11 | scheduler_config: 12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 # 0 or negative to disable 15 | warm_up_steps: 5000 16 | max_decay_steps: 500001 17 | lr_start: 2.5e-6 18 | lr_max: 1.5e-4 19 | lr_min: 1.0e-8 20 | 21 | transformer_config: 22 | target: geofree.modules.transformer.mingpt.GPT 23 | params: 24 | vocab_size: 16384 25 | block_size: 627 # conditioning + 299 - 1 26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings 27 | n_layer: 32 28 | n_head: 16 29 | n_embd: 1024 30 | 31 | first_stage_key: 32 | x: "dst_img" 33 | 34 | cond_stage_key: 35 | c: "src_img" 36 | 37 | emb_stage_key: 38 | points: "src_points" 39 | R: "R_rel" 40 | t: "t_rel" 41 | K: "K" 42 | K_inv: "K_inv" 43 | 44 | first_stage_config: 45 | target: geofree.models.vqgan.VQModel 46 | params: 47 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt" 48 | embed_dim: 256 49 | n_embed: 16384 50 | ddconfig: 51 | double_z: False 52 | z_channels: 256 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 58 | num_res_blocks: 2 59 | attn_resolutions: [ 16 ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: geofree.modules.losses.vqperceptual.DummyLoss 63 | 64 | cond_stage_config: "__is_first_stage__" 65 | 66 | depth_stage_config: 67 | target: geofree.models.vqgan.VQModel 68 | params: 69 | ckpt_path: "pretrained_models/acid_depth_stage/last.ckpt" 70 | embed_dim: 256 71 | n_embed: 1024 72 | ddconfig: 73 | double_z: false 74 | 
z_channels: 256 75 | resolution: 256 76 | in_channels: 1 77 | out_ch: 1 78 | ch: 128 79 | ch_mult: 80 | - 1 81 | - 1 82 | - 2 83 | - 2 84 | - 4 85 | num_res_blocks: 2 86 | attn_resolutions: 87 | - 16 88 | dropout: 0.0 89 | lossconfig: 90 | target: geofree.modules.losses.vqperceptual.DummyLoss 91 | 92 | emb_stage_config: 93 | target: geofree.modules.util.MultiEmbedder 94 | params: 95 | keys: 96 | - "R" 97 | - "t" 98 | - "K" 99 | - "K_inv" 100 | n_positions: 30 101 | n_channels: 1 102 | n_embed: 1024 103 | bias: False 104 | 105 | data: 106 | target: geofree.main.DataModuleFromConfig 107 | params: 108 | # bs 8 and accumulate_grad_batches 2 for 36gb vram 109 | batch_size: 8 110 | num_workers: 16 111 | train: 112 | target: geofree.data.acid.ACIDSparseTrain 113 | params: 114 | size: 115 | - 208 116 | - 368 117 | 118 | validation: 119 | target: geofree.data.acid.ACIDCustomTest 120 | params: 121 | size: 122 | - 208 123 | - 368 124 | 125 | lightning: 126 | trainer: 127 | accumulate_grad_batches: 2 128 | benchmark: True 129 | -------------------------------------------------------------------------------- /configs/acid/acid_13x23_impl_nodepth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.geogpt.GeoTransformer 4 | params: 5 | use_depth: False # depth is not provided to transformer but only used to rescale t 6 | 7 | plot_cond_stage: True 8 | monitor: "val/loss" 9 | 10 | use_scheduler: True 11 | scheduler_config: 12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 # 0 or negative to disable 15 | warm_up_steps: 5000 16 | max_decay_steps: 500001 17 | lr_start: 2.5e-6 18 | lr_max: 1.5e-4 19 | lr_min: 1.0e-8 20 | 21 | transformer_config: 22 | target: geofree.modules.transformer.mingpt.GPT 23 | params: 24 | vocab_size: 16384 25 | block_size: 627 # conditioning + 299 - 1 26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings 27 | n_layer: 32 28 | n_head: 16 29 | n_embd: 1024 30 | 31 | first_stage_key: 32 | x: "dst_img" 33 | 34 | cond_stage_key: 35 | c: "src_img" 36 | 37 | emb_stage_key: 38 | points: "src_points" 39 | R: "R_rel" 40 | t: "t_rel" 41 | K: "K" 42 | K_inv: "K_inv" 43 | 44 | first_stage_config: 45 | target: geofree.models.vqgan.VQModel 46 | params: 47 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt" 48 | embed_dim: 256 49 | n_embed: 16384 50 | ddconfig: 51 | double_z: False 52 | z_channels: 256 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 58 | num_res_blocks: 2 59 | attn_resolutions: [ 16 ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: geofree.modules.losses.vqperceptual.DummyLoss 63 | 64 | cond_stage_config: "__is_first_stage__" 65 | 66 | depth_stage_config: 67 | target: geofree.models.vqgan.VQModel 68 | params: 69 | ckpt_path: "pretrained_models/acid_depth_stage/last.ckpt" 70 | embed_dim: 256 71 | n_embed: 1024 72 | ddconfig: 73 | double_z: false 74 | z_channels: 256 75 | resolution: 256 76 | in_channels: 1 77 | out_ch: 1 78 | ch: 128 79 | ch_mult: 80 | - 1 81 | - 1 82 | - 2 83 | - 2 84 | - 4 85 | num_res_blocks: 2 86 | attn_resolutions: 87 | - 16 88 | dropout: 0.0 89 | lossconfig: 90 | target: geofree.modules.losses.vqperceptual.DummyLoss 91 | 92 | emb_stage_config: 93 | target: geofree.modules.util.MultiEmbedder 94 | params: 95 | keys: 96 | - "R" 97 | - "t" 98 | - "K" 99 | - "K_inv" 100 | n_positions: 30 
101 | n_channels: 1 102 | n_embed: 1024 103 | bias: False 104 | 105 | data: 106 | target: geofree.main.DataModuleFromConfig 107 | params: 108 | # bs 8 and accumulate_grad_batches 2 for 36gb vram 109 | batch_size: 8 110 | num_workers: 16 111 | train: 112 | target: geofree.data.acid.ACIDSparseTrain 113 | params: 114 | size: 115 | - 208 116 | - 368 117 | 118 | validation: 119 | target: geofree.data.acid.ACIDCustomTest 120 | params: 121 | size: 122 | - 208 123 | - 368 124 | 125 | lightning: 126 | trainer: 127 | accumulate_grad_batches: 2 128 | benchmark: True 129 | -------------------------------------------------------------------------------- /configs/realestate/realestate_13x23_expl_emb.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.warpgpt.WarpTransformer 4 | params: 5 | plot_cond_stage: True 6 | monitor: "val/loss" 7 | 8 | use_scheduler: True 9 | scheduler_config: 10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 11 | params: 12 | verbosity_interval: 0 # 0 or negative to disable 13 | warm_up_steps: 5000 14 | max_decay_steps: 500001 15 | lr_start: 2.5e-6 16 | lr_max: 1.5e-4 17 | lr_min: 1.0e-8 18 | 19 | transformer_config: 20 | target: geofree.modules.transformer.mingpt.WarpGPT 21 | params: 22 | vocab_size: 16384 23 | block_size: 597 # conditioning + 299 - 1 24 | n_unmasked: 299 # 299 cond embeddings 25 | n_layer: 32 26 | n_head: 16 27 | n_embd: 1024 28 | warper_config: 29 | target: geofree.modules.transformer.warper.ConvWarper 30 | params: 31 | size: [13, 23] 32 | 33 | first_stage_config: 34 | target: geofree.models.vqgan.VQModel 35 | params: 36 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt" 37 | embed_dim: 256 38 | n_embed: 16384 39 | ddconfig: 40 | double_z: False 41 | z_channels: 256 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 47 | num_res_blocks: 2 48 | attn_resolutions: [ 16 ] 49 | dropout: 0.0 50 | lossconfig: 51 | target: geofree.modules.losses.vqperceptual.DummyLoss 52 | 53 | cond_stage_config: "__is_first_stage__" 54 | 55 | data: 56 | target: geofree.main.DataModuleFromConfig 57 | params: 58 | # bs 8 and accumulate_grad_batches 2 for 34gb vram 59 | batch_size: 8 60 | num_workers: 16 61 | train: 62 | target: geofree.data.realestate.RealEstate10KSparseTrain 63 | params: 64 | size: 65 | - 208 66 | - 368 67 | 68 | validation: 69 | target: geofree.data.realestate.RealEstate10KCustomTest 70 | params: 71 | size: 72 | - 208 73 | - 368 74 | 75 | lightning: 76 | trainer: 77 | accumulate_grad_batches: 2 78 | benchmark: True 79 | -------------------------------------------------------------------------------- /configs/realestate/realestate_13x23_expl_feat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.net2net.WarpingFeatureTransformer 4 | params: 5 | plot_cond_stage: True 6 | monitor: "val/loss" 7 | 8 | use_scheduler: True 9 | scheduler_config: 10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 11 | params: 12 | verbosity_interval: 0 # 0 or negative to disable 13 | warm_up_steps: 5000 14 | max_decay_steps: 500001 15 | lr_start: 2.5e-6 16 | lr_max: 1.5e-4 17 | lr_min: 1.0e-8 18 | 19 | transformer_config: 20 | target: geofree.modules.transformer.mingpt.GPT 21 | params: 22 | vocab_size: 16384 23 | block_size: 597 # conditioning + 299 
- 1 24 | n_unmasked: 299 # 299 cond embeddings 25 | n_layer: 32 26 | n_head: 16 27 | n_embd: 1024 28 | 29 | first_stage_key: 30 | x: "dst_img" 31 | 32 | cond_stage_key: 33 | c: "src_img" 34 | points: "src_points" 35 | R: "R_rel" 36 | t: "t_rel" 37 | K: "K" 38 | K_inv: "K_inv" 39 | 40 | first_stage_config: 41 | target: geofree.models.vqgan.VQModel 42 | params: 43 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt" 44 | embed_dim: 256 45 | n_embed: 16384 46 | ddconfig: 47 | double_z: False 48 | z_channels: 256 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 54 | num_res_blocks: 2 55 | attn_resolutions: [ 16 ] 56 | dropout: 0.0 57 | lossconfig: 58 | target: geofree.modules.losses.vqperceptual.DummyLoss 59 | 60 | cond_stage_config: "__is_first_stage__" 61 | 62 | data: 63 | target: geofree.main.DataModuleFromConfig 64 | params: 65 | # bs 8 and accumulate_grad_batches 2 for 34gb vram 66 | batch_size: 8 67 | num_workers: 16 68 | train: 69 | target: geofree.data.realestate.RealEstate10KSparseTrain 70 | params: 71 | size: 72 | - 208 73 | - 368 74 | 75 | validation: 76 | target: geofree.data.realestate.RealEstate10KCustomTest 77 | params: 78 | size: 79 | - 208 80 | - 368 81 | 82 | lightning: 83 | trainer: 84 | accumulate_grad_batches: 2 85 | benchmark: True 86 | -------------------------------------------------------------------------------- /configs/realestate/realestate_13x23_expl_img.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.net2net.WarpingTransformer 4 | params: 5 | plot_cond_stage: True 6 | monitor: "val/loss" 7 | 8 | use_scheduler: True 9 | scheduler_config: 10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 11 | params: 12 | verbosity_interval: 0 # 0 or negative to disable 13 | warm_up_steps: 5000 14 | max_decay_steps: 500001 15 | lr_start: 2.5e-6 16 | lr_max: 1.5e-4 17 | lr_min: 1.0e-8 18 | 19 | transformer_config: 20 | target: geofree.modules.transformer.mingpt.GPT 21 | params: 22 | vocab_size: 16384 23 | block_size: 597 # conditioning + 299 - 1 24 | n_unmasked: 299 # 299 cond embeddings 25 | n_layer: 32 26 | n_head: 16 27 | n_embd: 1024 28 | 29 | first_stage_config: 30 | target: geofree.models.vqgan.VQModel 31 | params: 32 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt" 33 | embed_dim: 256 34 | n_embed: 16384 35 | ddconfig: 36 | double_z: False 37 | z_channels: 256 38 | resolution: 256 39 | in_channels: 3 40 | out_ch: 3 41 | ch: 128 42 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 43 | num_res_blocks: 2 44 | attn_resolutions: [ 16 ] 45 | dropout: 0.0 46 | lossconfig: 47 | target: geofree.modules.losses.vqperceptual.DummyLoss 48 | 49 | data: 50 | target: geofree.main.DataModuleFromConfig 51 | params: 52 | # bs 8 and accumulate_grad_batches 2 for 34gb vram 53 | batch_size: 8 54 | num_workers: 16 55 | train: 56 | target: geofree.data.realestate.RealEstate10KSparseTrain 57 | params: 58 | size: 59 | - 208 60 | - 368 61 | 62 | validation: 63 | target: geofree.data.realestate.RealEstate10KCustomTest 64 | params: 65 | size: 66 | - 208 67 | - 368 68 | 69 | lightning: 70 | trainer: 71 | accumulate_grad_batches: 2 72 | benchmark: True 73 | -------------------------------------------------------------------------------- /configs/realestate/realestate_13x23_hybrid.yaml: -------------------------------------------------------------------------------- 1 | 
model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.geogpt.WarpGeoTransformer 4 | params: 5 | merge_channels: 512 # channels of cond vq + depth vq 6 | 7 | plot_cond_stage: True 8 | monitor: "val/loss" 9 | 10 | use_scheduler: True 11 | scheduler_config: 12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 # 0 or negative to disable 15 | warm_up_steps: 5000 16 | max_decay_steps: 500001 17 | lr_start: 2.5e-6 18 | lr_max: 1.5e-4 19 | lr_min: 1.0e-8 20 | 21 | transformer_config: 22 | target: geofree.modules.transformer.mingpt.WarpGPT 23 | params: 24 | vocab_size: 16384 25 | block_size: 627 # conditioning + 299 - 1 26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings 27 | n_layer: 32 28 | n_head: 16 29 | n_embd: 1024 30 | warper_config: 31 | target: geofree.modules.transformer.warper.ConvWarper 32 | params: 33 | start_idx: 30 # do not warp camera embeddings 34 | size: [13, 23] 35 | 36 | first_stage_key: 37 | x: "dst_img" 38 | 39 | cond_stage_key: 40 | c: "src_img" 41 | 42 | emb_stage_key: 43 | points: "src_points" 44 | R: "R_rel" 45 | t: "t_rel" 46 | K: "K" 47 | K_inv: "K_inv" 48 | 49 | first_stage_config: 50 | target: geofree.models.vqgan.VQModel 51 | params: 52 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt" 53 | embed_dim: 256 54 | n_embed: 16384 55 | ddconfig: 56 | double_z: False 57 | z_channels: 256 58 | resolution: 256 59 | in_channels: 3 60 | out_ch: 3 61 | ch: 128 62 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 63 | num_res_blocks: 2 64 | attn_resolutions: [ 16 ] 65 | dropout: 0.0 66 | lossconfig: 67 | target: geofree.modules.losses.vqperceptual.DummyLoss 68 | 69 | cond_stage_config: "__is_first_stage__" 70 | 71 | depth_stage_config: 72 | target: geofree.models.vqgan.VQModel 73 | params: 74 | ckpt_path: "pretrained_models/realestate_depth_stage/last.ckpt" 75 | embed_dim: 256 76 | n_embed: 1024 77 | ddconfig: 78 | double_z: false 79 | z_channels: 256 80 | resolution: 256 81 | in_channels: 1 82 | out_ch: 1 83 | ch: 128 84 | ch_mult: 85 | - 1 86 | - 1 87 | - 2 88 | - 2 89 | - 4 90 | num_res_blocks: 2 91 | attn_resolutions: 92 | - 16 93 | dropout: 0.0 94 | lossconfig: 95 | target: geofree.modules.losses.vqperceptual.DummyLoss 96 | 97 | emb_stage_config: 98 | target: geofree.modules.util.MultiEmbedder 99 | params: 100 | keys: 101 | - "R" 102 | - "t" 103 | - "K" 104 | - "K_inv" 105 | n_positions: 30 106 | n_channels: 1 107 | n_embed: 1024 108 | bias: False 109 | 110 | data: 111 | target: geofree.main.DataModuleFromConfig 112 | params: 113 | # bs 8 and accumulate_grad_batches 2 for 36gb vram 114 | batch_size: 8 115 | num_workers: 16 116 | train: 117 | target: geofree.data.realestate.RealEstate10KSparseTrain 118 | params: 119 | size: 120 | - 208 121 | - 368 122 | 123 | validation: 124 | target: geofree.data.realestate.RealEstate10KCustomTest 125 | params: 126 | size: 127 | - 208 128 | - 368 129 | 130 | lightning: 131 | trainer: 132 | accumulate_grad_batches: 2 133 | benchmark: True 134 | -------------------------------------------------------------------------------- /configs/realestate/realestate_13x23_impl_catdepth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.geogpt.GeoTransformer 4 | params: 5 | merge_channels: null 6 | 7 | plot_cond_stage: True 8 | monitor: "val/loss" 9 | 10 | use_scheduler: True 11 | scheduler_config: 12 | target: 
geofree.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 # 0 or negative to disable 15 | warm_up_steps: 5000 16 | max_decay_steps: 500001 17 | lr_start: 2.5e-6 18 | lr_max: 1.5e-4 19 | lr_min: 1.0e-8 20 | 21 | transformer_config: 22 | target: geofree.modules.transformer.mingpt.GPT 23 | params: 24 | vocab_size: 16384 25 | block_size: 926 # conditioning + 299 - 1 26 | n_unmasked: 628 # 30 camera embeddings + 299 depth and 299 cond embeddings 27 | n_layer: 32 28 | n_head: 16 29 | n_embd: 1024 30 | 31 | first_stage_key: 32 | x: "dst_img" 33 | 34 | cond_stage_key: 35 | c: "src_img" 36 | 37 | emb_stage_key: 38 | points: "src_points" 39 | R: "R_rel" 40 | t: "t_rel" 41 | K: "K" 42 | K_inv: "K_inv" 43 | 44 | first_stage_config: 45 | target: geofree.models.vqgan.VQModel 46 | params: 47 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt" 48 | embed_dim: 256 49 | n_embed: 16384 50 | ddconfig: 51 | double_z: False 52 | z_channels: 256 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 58 | num_res_blocks: 2 59 | attn_resolutions: [ 16 ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: geofree.modules.losses.vqperceptual.DummyLoss 63 | 64 | cond_stage_config: "__is_first_stage__" 65 | 66 | depth_stage_config: 67 | target: geofree.models.vqgan.VQModel 68 | params: 69 | ckpt_path: "pretrained_models/realestate_depth_stage/last.ckpt" 70 | embed_dim: 256 71 | n_embed: 1024 72 | ddconfig: 73 | double_z: false 74 | z_channels: 256 75 | resolution: 256 76 | in_channels: 1 77 | out_ch: 1 78 | ch: 128 79 | ch_mult: 80 | - 1 81 | - 1 82 | - 2 83 | - 2 84 | - 4 85 | num_res_blocks: 2 86 | attn_resolutions: 87 | - 16 88 | dropout: 0.0 89 | lossconfig: 90 | target: geofree.modules.losses.vqperceptual.DummyLoss 91 | 92 | emb_stage_config: 93 | target: geofree.modules.util.MultiEmbedder 94 | params: 95 | keys: 96 | - "R" 97 | - "t" 98 | - "K" 99 | - "K_inv" 100 | n_positions: 30 101 | n_channels: 1 102 | n_embed: 1024 103 | bias: False 104 | 105 | data: 106 | target: geofree.main.DataModuleFromConfig 107 | params: 108 | # bs 4 and accumulate_grad_batches 4 for 34gb vram 109 | batch_size: 4 110 | num_workers: 8 111 | train: 112 | target: geofree.data.realestate.RealEstate10KSparseTrain 113 | params: 114 | size: 115 | - 208 116 | - 368 117 | 118 | validation: 119 | target: geofree.data.realestate.RealEstate10KCustomTest 120 | params: 121 | size: 122 | - 208 123 | - 368 124 | 125 | lightning: 126 | trainer: 127 | accumulate_grad_batches: 4 128 | benchmark: True 129 | -------------------------------------------------------------------------------- /configs/realestate/realestate_13x23_impl_depth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.geogpt.GeoTransformer 4 | params: 5 | merge_channels: 512 # channels of cond vq + depth vq 6 | 7 | plot_cond_stage: True 8 | monitor: "val/loss" 9 | 10 | use_scheduler: True 11 | scheduler_config: 12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 # 0 or negative to disable 15 | warm_up_steps: 5000 16 | max_decay_steps: 500001 17 | lr_start: 2.5e-6 18 | lr_max: 1.5e-4 19 | lr_min: 1.0e-8 20 | 21 | transformer_config: 22 | target: geofree.modules.transformer.mingpt.GPT 23 | params: 24 | vocab_size: 16384 25 | block_size: 627 # conditioning + 299 - 1 26 | n_unmasked: 329 # 30 camera embeddings + 299 
merged cond and depth embeddings 27 | n_layer: 32 28 | n_head: 16 29 | n_embd: 1024 30 | 31 | first_stage_key: 32 | x: "dst_img" 33 | 34 | cond_stage_key: 35 | c: "src_img" 36 | 37 | emb_stage_key: 38 | points: "src_points" 39 | R: "R_rel" 40 | t: "t_rel" 41 | K: "K" 42 | K_inv: "K_inv" 43 | 44 | first_stage_config: 45 | target: geofree.models.vqgan.VQModel 46 | params: 47 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt" 48 | embed_dim: 256 49 | n_embed: 16384 50 | ddconfig: 51 | double_z: False 52 | z_channels: 256 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 58 | num_res_blocks: 2 59 | attn_resolutions: [ 16 ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: geofree.modules.losses.vqperceptual.DummyLoss 63 | 64 | cond_stage_config: "__is_first_stage__" 65 | 66 | depth_stage_config: 67 | target: geofree.models.vqgan.VQModel 68 | params: 69 | ckpt_path: "pretrained_models/realestate_depth_stage/last.ckpt" 70 | embed_dim: 256 71 | n_embed: 1024 72 | ddconfig: 73 | double_z: false 74 | z_channels: 256 75 | resolution: 256 76 | in_channels: 1 77 | out_ch: 1 78 | ch: 128 79 | ch_mult: 80 | - 1 81 | - 1 82 | - 2 83 | - 2 84 | - 4 85 | num_res_blocks: 2 86 | attn_resolutions: 87 | - 16 88 | dropout: 0.0 89 | lossconfig: 90 | target: geofree.modules.losses.vqperceptual.DummyLoss 91 | 92 | emb_stage_config: 93 | target: geofree.modules.util.MultiEmbedder 94 | params: 95 | keys: 96 | - "R" 97 | - "t" 98 | - "K" 99 | - "K_inv" 100 | n_positions: 30 101 | n_channels: 1 102 | n_embed: 1024 103 | bias: False 104 | 105 | data: 106 | target: geofree.main.DataModuleFromConfig 107 | params: 108 | # bs 8 and accumulate_grad_batches 2 for 36gb vram 109 | batch_size: 8 110 | num_workers: 16 111 | train: 112 | target: geofree.data.realestate.RealEstate10KSparseTrain 113 | params: 114 | size: 115 | - 208 116 | - 368 117 | 118 | validation: 119 | target: geofree.data.realestate.RealEstate10KCustomTest 120 | params: 121 | size: 122 | - 208 123 | - 368 124 | 125 | lightning: 126 | trainer: 127 | accumulate_grad_batches: 2 128 | benchmark: True 129 | -------------------------------------------------------------------------------- /configs/realestate/realestate_13x23_impl_nodepth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0625 3 | target: geofree.models.transformers.geogpt.GeoTransformer 4 | params: 5 | use_depth: False # depth is not provided to transformer but only used to rescale t 6 | 7 | plot_cond_stage: True 8 | monitor: "val/loss" 9 | 10 | use_scheduler: True 11 | scheduler_config: 12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 # 0 or negative to disable 15 | warm_up_steps: 5000 16 | max_decay_steps: 500001 17 | lr_start: 2.5e-6 18 | lr_max: 1.5e-4 19 | lr_min: 1.0e-8 20 | 21 | transformer_config: 22 | target: geofree.modules.transformer.mingpt.GPT 23 | params: 24 | vocab_size: 16384 25 | block_size: 627 # conditioning + 299 - 1 26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings 27 | n_layer: 32 28 | n_head: 16 29 | n_embd: 1024 30 | 31 | first_stage_key: 32 | x: "dst_img" 33 | 34 | cond_stage_key: 35 | c: "src_img" 36 | 37 | emb_stage_key: 38 | points: "src_points" 39 | R: "R_rel" 40 | t: "t_rel" 41 | K: "K" 42 | K_inv: "K_inv" 43 | 44 | first_stage_config: 45 | target: geofree.models.vqgan.VQModel 46 | params: 47 | ckpt_path: 
"pretrained_models/realestate_first_stage/last.ckpt" 48 | embed_dim: 256 49 | n_embed: 16384 50 | ddconfig: 51 | double_z: False 52 | z_channels: 256 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1 58 | num_res_blocks: 2 59 | attn_resolutions: [ 16 ] 60 | dropout: 0.0 61 | lossconfig: 62 | target: geofree.modules.losses.vqperceptual.DummyLoss 63 | 64 | cond_stage_config: "__is_first_stage__" 65 | 66 | depth_stage_config: 67 | target: geofree.models.vqgan.VQModel 68 | params: 69 | ckpt_path: "pretrained_models/realestate_depth_stage/last.ckpt" 70 | embed_dim: 256 71 | n_embed: 1024 72 | ddconfig: 73 | double_z: false 74 | z_channels: 256 75 | resolution: 256 76 | in_channels: 1 77 | out_ch: 1 78 | ch: 128 79 | ch_mult: 80 | - 1 81 | - 1 82 | - 2 83 | - 2 84 | - 4 85 | num_res_blocks: 2 86 | attn_resolutions: 87 | - 16 88 | dropout: 0.0 89 | lossconfig: 90 | target: geofree.modules.losses.vqperceptual.DummyLoss 91 | 92 | emb_stage_config: 93 | target: geofree.modules.util.MultiEmbedder 94 | params: 95 | keys: 96 | - "R" 97 | - "t" 98 | - "K" 99 | - "K_inv" 100 | n_positions: 30 101 | n_channels: 1 102 | n_embed: 1024 103 | bias: False 104 | 105 | data: 106 | target: geofree.main.DataModuleFromConfig 107 | params: 108 | # bs 8 and accumulate_grad_batches 2 for 36gb vram 109 | batch_size: 8 110 | num_workers: 16 111 | train: 112 | target: geofree.data.realestate.RealEstate10KSparseTrain 113 | params: 114 | size: 115 | - 208 116 | - 368 117 | 118 | validation: 119 | target: geofree.data.realestate.RealEstate10KCustomTest 120 | params: 121 | size: 122 | - 208 123 | - 368 124 | 125 | lightning: 126 | trainer: 127 | accumulate_grad_batches: 2 128 | benchmark: True 129 | -------------------------------------------------------------------------------- /data/acid_custom_frames.txt: -------------------------------------------------------------------------------- 1 | d0126a869518014c,27933333.png,30600000.png,33366667.png 2 | d080086d1e293422,1096195000.png,1092558000.png,1087886000.png 3 | d1512dba43b6ebd8,188555033.png,192625767.png,197163633.png 4 | d179d50f84f4b83f,217880000.png,224280000.png,231080000.png 5 | d19c9fe809c57c92,22822800.png,25692333.png,28962267.png 6 | d3738aa86d5eda8e,80583333.png,74125000.png,66916667.png 7 | d3835cc526e338bc,71871800.png,75842433.png,79846433.png 8 | d47a3786e17ee61c,47747000.png,53220000.png,58359000.png 9 | d60884a98e4608c7,157790967.png,164531033.png,171738233.png 10 | d674021d935b083e,270671000.png,279946000.png,289389000.png 11 | d68495d47bf40a86,254633000.png,258133000.png,260633000.png 12 | d6bdf9d5236d4b9e,131965167.png,143309833.png,155488667.png 13 | d6cf2995d0ee0f30,509308000.png,518318000.png,527260000.png 14 | d78eccc35066d590,21271250.png,26276250.png,31698333.png 15 | d7a050a81422bdcb,1168167000.png,1183916067.png,1199464933.png 16 | d7e53e00d33b4935,105005000.png,102503000.png,100033000.png 17 | d7fca566a7941495,4456218433.png,4440269167.png,4424686933.png 18 | d82270f5d9180bb6,323657000.png,328395000.png,332799000.png 19 | d8409fae7feda328,42542500.png,50383667.png,56598208.png 20 | d884a63dd53b6919,162863000.png,178378000.png,193794000.png 21 | d90a8043e1ebfeae,1234200000.png,1226392000.png,1216315000.png 22 | d93af9b55cfac3d6,71571000.png,79546000.png,87220000.png 23 | d94b14b1000175be,620586000.png,625291000.png,629896000.png 24 | d994ce5969d83614,876475600.png,891957733.png,907940367.png 25 | da346e30050c6c00,176376000.png,181114000.png,185452000.png 26 | 
daa02bdc0b99c679,66153844.png,69720278.png,73286711.png 27 | daafe4cf1afae0ac,54687967.png,64697967.png,75108367.png 28 | daea6ab83805891a,183483300.png,196496300.png,209576033.png 29 | db18bbcefe49f79f,142709233.png,145945800.png,149482667.png 30 | db5184b7af9fd1c2,47200000.png,50233333.png,53200000.png 31 | db593b5b5fa8324a,141908000.png,144377000.png,146646000.png 32 | db9b1d074c875cfe,59226000.png,65632000.png,71972000.png 33 | dbd581a742408adf,307015000.png,315190000.png,323156000.png 34 | dc4002fc40c73bcc,188922000.png,180814000.png,173973000.png 35 | dc5f1c50fbd89f2a,201068000.png,202436000.png,203903000.png 36 | dcc8e9a3ec4195e1,30800000.png,27866667.png,24733333.png 37 | dd0c70b67523bfe3,700967000.png,716382000.png,731831000.png 38 | dd3339c318949f11,80046000.png,80981000.png,81882000.png 39 | de813f957e093bd9,31322958.png,35076708.png,38747042.png 40 | ded5b92ff8802e34,4318747767.png,4303432467.png,4287583300.png 41 | e02ba6089a8e6996,319919000.png,313447000.png,307374000.png 42 | e0b669c5a9af8487,74008000.png,76676000.png,79179000.png 43 | e0e9f1803a42d9a6,2860024000.png,2865562000.png,2871235000.png 44 | e0fb6bc955dc2827,71866667.png,70533333.png,69133333.png 45 | e0ff100eaf2ff841,18585000.png,25092000.png,31564000.png 46 | e196d71034cf55dd,170003000.png,175675000.png,180881000.png 47 | e22ba3ffad552f90,33767067.png,48281567.png,61027633.png 48 | e27a6391e3fced31,198656000.png,212587000.png,224516000.png 49 | e28fb40231cc8e31,27894000.png,31164000.png,34468000.png 50 | e33da2efe5ad29ac,1386251533.png,1401166433.png,1416748667.png 51 | e3475cb70a67e1aa,190490300.png,196129267.png,201467933.png 52 | e355f7aaf7d08fc3,164289122.png,165915756.png,167500667.png 53 | e36cd137792f2711,72439033.png,84084000.png,99999900.png 54 | e3792317ae266d83,14047367.png,18885533.png,24057367.png 55 | e3c9b29aaacf1fba,58758000.png,63230000.png,67601000.png 56 | e3d4f7fbbcd3fc3b,12554208.png,13596917.png,14764750.png 57 | e3ef559c3137aa36,2161392000.png,2169934000.png,2178009000.png 58 | e40594059ec1a2a8,69904000.png,73040000.png,77577000.png 59 | e52993cdbcf75b65,109818044.png,120537089.png,131840044.png 60 | e5310a4dc228ec5b,36803433.png,41408033.png,45745700.png 61 | e614b4670c543a72,114966667.png,119566667.png,124466667.png 62 | e668f185a5ba7c38,395028000.png,401801000.png,408875000.png 63 | e70e7383f837f2d4,1274573300.png,1277542933.png,1280579300.png 64 | e71fbf373f35269e,71471467.png,75508833.png,79145800.png 65 | e7443b934b5f4b0a,103603500.png,109209100.png,113980533.png 66 | e74ca33c36a51184,665397462.png,666332665.png,667267869.png 67 | e75b4dd9975c6f71,31200000.png,32733333.png,34500000.png 68 | e7a51eee636a0ba5,110510000.png,114081000.png,117817000.png 69 | e831e23cd75ab6c6,100000000.png,110010000.png,121321000.png 70 | e8426e39152cce71,106205000.png,109042000.png,111645000.png 71 | e86c6167b214bb99,34902000.png,27794000.png,20254000.png 72 | e88f65da150600b1,100033333.png,103000000.png,106133333.png 73 | e8a2758fe5401c9a,738911156.png,741482966.png,744255177.png 74 | e93d2814054ac90c,151584767.png,154287467.png,156789967.png 75 | e9828e82fe988c70,2544709000.png,2538836000.png,2532663000.png 76 | e9d627308856bc62,45278000.png,34268000.png,31398000.png 77 | e9eb29c8c46a8dfe,154554400.png,159792967.png,165431933.png 78 | ea02488ece6242e8,36469767.png,31231200.png,26292933.png 79 | ea36b0ce652cd8b4,155955000.png,158458000.png,161128000.png 80 | eab2bd6815838f8e,2293758000.png,2277141000.png,2261926000.png 81 | eabe884be6f7e517,158066667.png,161200000.png,164300000.png 82 | 
ebb0c513c19ea84f,144042000.png,152875000.png,163208000.png 83 | ecec6359aa6acbcf,213313100.png,218785233.png,223957067.png 84 | ed01ba5660fd718a,304666667.png,308166667.png,311708333.png 85 | ed0e643389fb2c0f,869469000.png,880880000.png,892792000.png 86 | ed29587ce3e9c577,441308000.png,457291000.png,472506000.png 87 | ed3c89742143039a,44477767.png,47547500.png,50784067.png 88 | edeb8a878709b716,72640000.png,68168000.png,63663000.png 89 | edf5fca1c9f4cb6c,269168900.png,271638033.png,274173900.png 90 | ee09e048af8deba6,17751067.png,30730700.png,42642600.png 91 | ee32b574f99ff150,155155000.png,152277125.png,149649500.png 92 | ee4b7adccd9ecd5c,202667000.png,207125000.png,211917000.png 93 | ee69da79dbad7bb3,70894789.png,73412467.png,76209889.png 94 | eea426ba67142e9b,149416000.png,150651000.png,151852000.png 95 | ef0753c88ee52b6f,105480000.png,109400000.png,113400000.png 96 | ef1f6e769ebc1c7a,208700000.png,211266667.png,214000000.png 97 | efaf202f7c09123c,56347958.png,60018292.png,63313250.png 98 | f02d43a4dabe789d,85518000.png,92425000.png,99799000.png 99 | f16b9ff7838b81c6,52733000.png,55100000.png,57166000.png 100 | f1beb221d42d1b84,359459100.png,356022333.png,352485467.png 101 | f23cd2826eff1565,30230200.png,31064367.png,31798433.png 102 | f2a34c7254940f53,494827000.png,487854000.png,478811000.png 103 | f2aa477766df229c,189166667.png,190900000.png,192533333.png 104 | f31b3bb307daf84a,51250000.png,53083000.png,54833000.png 105 | f3218100d129be59,62729000.png,66800000.png,69269000.png 106 | f33cad5fe9bd65f9,14916667.png,12916667.png,10916667.png 107 | f4421fc161eb69f3,89189100.png,94194100.png,99999900.png 108 | f4ce75753de06c58,66366667.png,75800000.png,86066667.png 109 | f4e0e97f02bcb1e3,245946000.png,247881000.png,249850000.png 110 | f51d8789db1f722c,179500000.png,173800000.png,167900000.png 111 | f58db3b7a2ce9cc0,1125724600.png,1128393933.png,1132064267.png 112 | f5ae92dd3d7e87e0,604303000.png,618385000.png,634434000.png 113 | f5ef3487cdcee3fb,33433000.png,36003000.png,38639000.png 114 | f6996b1f345e96b4,153954000.png,159960000.png,166065000.png 115 | f77008b83e767032,1734933000.png,1750449000.png,1764996000.png 116 | f7f7b65716d33d83,53219833.png,53787067.png,54320933.png 117 | f830083184521d34,60293627.png,65865866.png,71938605.png 118 | f84181beccbd5e13,451984867.png,454654200.png,457523733.png 119 | f8c932575ce9f2a5,47115385.png,50923077.png,54730769.png 120 | f8e283943d00a04d,215649000.png,221855000.png,228995000.png 121 | f91e3039f88375a3,88388300.png,90190100.png,91925167.png 122 | f93cd4d2a4313d38,55989267.png,61761700.png,68201467.png 123 | f976c57572a6ce0e,4217413200.png,4232695133.png,4249144900.png 124 | f98981baa6b8ee65,2185617000.png,2195660000.png,2207472000.png 125 | fa0d910e19270413,228061000.png,232800000.png,238104000.png 126 | fa329483307d59ac,109576000.png,117317000.png,125525000.png 127 | fa35a303ae5d4be6,2936767000.png,2951649000.png,2959323000.png 128 | faf3bb76a633f435,1004703000.png,1002368000.png,1000099000.png 129 | fb155e5f56868f1b,604479479.png,608775442.png,613446780.png 130 | fb54fffc370d1b0a,527994133.png,539372167.png,549682467.png 131 | fb9c37c89297d619,1316148167.png,1331029700.png,1347045700.png 132 | fbc7bf83e0ae84d3,13646967.png,15081733.png,16349667.png 133 | fc00edbd1d936174,38433333.png,43266667.png,47866667.png 134 | fc9bf0a83f1891e8,305680000.png,309851000.png,314397000.png 135 | fd012e8d736a6c8e,38505133.png,41875167.png,45512133.png 136 | fd2f9b4c485544f3,50050000.png,50984000.png,51919000.png 137 | 
fd3acb362205d6eb,52085000.png,54420000.png,56923000.png 138 | fd9d98306e9bf7d0,19186000.png,22022000.png,22556000.png 139 | fe0428833b293005,15448767.png,24524500.png,33333300.png 140 | ff85e9318de5f4bf,69002267.png,75241833.png,81948533.png 141 | ffc56e43145f3ae5,105305200.png,109843067.png,114848067.png 142 | -------------------------------------------------------------------------------- /data/realestate_custom_frames.txt: -------------------------------------------------------------------------------- 1 | f04880903462e8f0,106806000.png,109609000.png,112646000.png 2 | f1171dbc746a3cd6,152919433.png,154621133.png,156256100.png 3 | f1226d5483f1c1cd,96029267.png,97931167.png,99999900.png 4 | f14627d2fc3158d4,56756700.png,60894167.png,65532133.png 5 | f1623ec1d552408c,87220467.png,89222467.png,91291200.png 6 | f22301abb4ab062c,299032000.png,300333000.png,301734000.png 7 | f2309b1e3f43d138,271738133.png,274841233.png,279145533.png 8 | f253339fa19467f5,307574000.png,308809000.png,312346000.png 9 | f2736f39d6213d22,140940000.png,141975000.png,143076000.png 10 | f289b5f7136bba85,46546500.png,50650600.png,55355300.png 11 | f3236d9eb1996c58,140073267.png,142142000.png,144677867.png 12 | f32643006f1260b9,85085000.png,89155733.png,92492400.png 13 | f3534478d06ca677,194727867.png,197096900.png,199532667.png 14 | f372706f15279f2d,212779000.png,215415000.png,218419000.png 15 | f408fd1ecf9639dd,41583000.png,43291000.png,44999000.png 16 | f412f8fa0ea298ff,253553000.png,255755000.png,258892000.png 17 | f413f880cdafe53c,226226000.png,228161267.png,230096533.png 18 | f43420f8216f180b,134300833.png,132732600.png,131097633.png 19 | f4769a32a225c172,178444000.png,182582000.png,187253000.png 20 | f478d5f55e41f4f7,273072800.png,276743133.png,279979700.png 21 | f48829b917629fe0,42642000.png,47113000.png,50650000.png 22 | f5257718f0aae70d,114547767.png,115648867.png,118351567.png 23 | f533f7e3ca2b82b2,255355100.png,259425833.png,264264000.png 24 | f536430fb223d623,290957333.png,292792500.png,294594300.png 25 | f546b4bb3f06e5ad,279613000.png,283216000.png,286319000.png 26 | f591e1921f2c850c,154020533.png,156489667.png,158958800.png 27 | f6057c281758b5f6,104104000.png,108641000.png,113113000.png 28 | f60664acd06e22aa,129729600.png,132098633.png,135268467.png 29 | f626135a90d68e85,177577400.png,179412567.png,181381200.png 30 | f649244a6907838c,52218833.png,54587867.png,57624233.png 31 | f654d9cb6bc8869e,260426833.png,261060800.png,262328733.png 32 | f673068196024955,196196000.png,197865000.png,199733000.png 33 | f6884e5c27acbafe,184384200.png,186219367.png,187921067.png 34 | f714ccd51adc29a8,31398000.png,34234000.png,37037000.png 35 | f74139ac48f19b3c,172005167.png,176543033.png,179712867.png 36 | f793e374e29bbd16,145745600.png,148481667.png,151484667.png 37 | f795469d3856697e,188021167.png,191558033.png,194727867.png 38 | f8545ff1c0ebc410,34900000.png,37833333.png,40466667.png 39 | f882aaab13d8d7a6,186719000.png,189756000.png,192826000.png 40 | f91749843a42fa3e,247247000.png,249649400.png,252185267.png 41 | f954f61234d49919,121254000.png,123290000.png,125258000.png 42 | f956425e6e6f9863,114448000.png,115982000.png,117650000.png 43 | f96049f4355544bd,194260733.png,198097900.png,202001800.png 44 | f968e1bb4d5f9e4b,273239633.png,275208267.png,277610667.png 45 | f99691764cd67e0c,102969533.png,106506400.png,110243467.png 46 | f99730b487e57e9b,217250367.png,221287733.png,225291733.png 47 | fa005e9c2a1c4828,167567400.png,171504667.png,175675500.png 48 | fa14b62a46ffd7b7,30964267.png,35502133.png,39873167.png 49 | 
fa1fe0e1d0ba7450,133199733.png,134701233.png,136402933.png 50 | fa25ffc4897fc142,70303567.png,73373300.png,76309567.png 51 | fa53b8c13f6caabb,45344000.png,47680000.png,52118000.png 52 | fa71d84898d6ac7f,30697333.png,35502133.png,39706333.png 53 | fa7ac0885965196f,78545133.png,83016267.png,85752333.png 54 | fa82d71146e33555,113413300.png,116749967.png,121087633.png 55 | fa95ce310f223e51,157123790.png,158258258.png,159426093.png 56 | faa891b87b4b33ad,270169837.png,273139473.png,276309309.png 57 | faaba1f48348f5c6,262463000.png,263897000.png,265332000.png 58 | fab28f712c0edea4,188772000.png,190732000.png,192609000.png 59 | faf2fb655d5974b5,120820700.png,125391933.png,129796333.png 60 | fb08b3c2668c99c3,204871000.png,208041000.png,211378000.png 61 | fb0d3f523276eefe,67033633.png,71471400.png,75742333.png 62 | fb1bd05ed2a473ad,59759700.png,64064000.png,68568500.png 63 | fb277b237d7bdcb5,249015433.png,250883967.png,253019433.png 64 | fb32e6765f776b18,289622667.png,292292000.png,294877922.png 65 | fb35415a0b8eb135,260026000.png,263196000.png,265965000.png 66 | fb36cf1f6923924b,112145000.png,115248000.png,118718000.png 67 | fb3d48fa6ce76da1,132098633.png,133266467.png,134434300.png 68 | fb424ab019f3e1a1,77978000.png,82315000.png,86653000.png 69 | fb52f951d8a8ad11,295628667.png,297664033.png,299899600.png 70 | fb606e0ed3d19e90,120187000.png,122122000.png,124158000.png 71 | fb723ecfff811097,138538400.png,143009533.png,147213733.png 72 | fb80aacde91ee824,212879333.png,217483933.png,221788233.png 73 | fbb437c54b5d9d78,113380000.png,117617000.png,122121000.png 74 | fbd37246c3e03576,167534033.png,171838333.png,176376200.png 75 | fbdcb1880c6bab8e,81915167.png,84250833.png,86719967.png 76 | fbdf22c23b31ceb4,30730700.png,35101733.png,39606233.png 77 | fbe4055d5dfd480e,40240200.png,43043000.png,46146100.png 78 | fbe9232477f69e33,100166733.png,101801700.png,103336567.png 79 | fbeb367248e705bc,104070633.png,107607500.png,110844067.png 80 | fbef04fb266195f9,81414667.png,83183100.png,85085000.png 81 | fbfed12d638968c2,40807433.png,45078367.png,48715333.png 82 | fc174b681e13cdf4,120854067.png,123356567.png,126092633.png 83 | fc43cc25dadacb18,82682600.png,85051633.png,87587500.png 84 | fc4410feec1308a2,176409567.png,180914067.png,185385200.png 85 | fc50278b5f950186,100133367.png,101734967.png,102335567.png 86 | fc50693501976251,159592767.png,161828333.png,165064900.png 87 | fc508faa91c9807d,145278000.png,146847000.png,148281000.png 88 | fc6f664a700121e9,33833800.png,37037000.png,39939900.png 89 | fc75727847fd61f4,260193267.png,264397467.png,268668400.png 90 | fc87af9c4f0000b1,57023633.png,56423033.png,55055000.png 91 | fc966f25afa3659f,160660500.png,162829333.png,167033533.png 92 | fc991bfad5040ca4,232607378.png,234234000.png,236069167.png 93 | fc99e61da3cfb1b1,54788067.png,57123733.png,59726333.png 94 | fca310e3805d5ab4,78945533.png,80613867.png,81581500.png 95 | fca5ca7d4517812b,253286367.png,254020433.png,258258000.png 96 | fcb0cb3d3da5c2e5,38738700.png,41107733.png,43710333.png 97 | fcd7ebdf0454ef30,149182367.png,151251100.png,152519033.png 98 | fcde616cb28da426,200800000.png,202735000.png,204571000.png 99 | fce3a289eed2d6c6,158058000.png,160828000.png,163563000.png 100 | fcfa188fdd8e4cdc,100199000.png,104804000.png,108141000.png 101 | fd020d75ab5e2fa3,91558133.png,94794700.png,97931167.png 102 | fd2c0fc2bf2befa7,87720967.png,89055633.png,90290200.png 103 | fd3c54bb57284e5c,120019000.png,121254000.png,122455000.png 104 | fd48a65a5e252855,71671000.png,75909000.png,80180000.png 105 | 
fd6ee791edb85b45,202802600.png,205138267.png,207273733.png 106 | fd7890ec5b0a924c,310226000.png,307015000.png,304346000.png 107 | fd9ab1e2b1f8e5aa,78233333.png,80666667.png,83066667.png 108 | fdaab1f49851782c,69736333.png,72305567.png,74874800.png 109 | fdb60d991671bd45,47213833.png,51718333.png,55955900.png 110 | fdceb26461ef9adb,50050000.png,54020633.png,57924533.png 111 | fdd41bb5cab98c86,68701967.png,72105367.png,76042633.png 112 | fdf17f383d76c327,238605000.png,241641000.png,244745000.png 113 | fe04a079dc651d9b,112245000.png,117884000.png,121087000.png 114 | fe26ad446164cb32,112378933.png,114714600.png,118351567.png 115 | fe2fadf89a84e92a,134000533.png,135468667.png,137070267.png 116 | fe625de05cd0a34b,83082000.png,87353000.png,91958000.png 117 | fe8431a1c5bd3c81,260793867.png,265198267.png,269569300.png 118 | fe97f5516fea659b,144544400.png,148014533.png,151484667.png 119 | fea3a0ff96583244,271104167.png,273373100.png,275408467.png 120 | fea544b472e9abd1,44110000.png,48682000.png,53086000.png 121 | feba28a8f1b69de6,102870000.png,107107000.png,111778000.png 122 | fef11d498a882edc,44911533.png,47113733.png,49182467.png 123 | fef6f87f0eb94f9d,243310000.png,246112000.png,248648000.png 124 | fefa62ae2b9915cb,34635000.png,38471000.png,42141000.png 125 | ff036af715175dce,313413100.png,314981333.png,316616300.png 126 | ff03c7906d70bc61,213713000.png,216916000.png,219986000.png 127 | ff0749db248be7fb,115315200.png,117650867.png,120120000.png 128 | ff1c8223873aa02b,93727000.png,95161000.png,96730000.png 129 | ff20a3e943ea6d63,179246000.png,183583000.png,188021000.png 130 | ff58ffaac40eb035,277043433.png,280013067.png,282649033.png 131 | ff63123a7ef312a5,130630500.png,132765967.png,134934800.png 132 | ff6d8ab35e042db5,137237100.png,140006533.png,142775967.png 133 | ff7f5042dddd5e12,113146000.png,115048000.png,118952000.png 134 | ff821ee982e8f194,181981800.png,184517667.png,186986800.png 135 | ff887646981745a6,270970700.png,275275000.png,279479200.png 136 | ff9e265755208438,82849433.png,84717967.png,87287200.png 137 | ffa95c3b40609c76,146179000.png,149449000.png,152986000.png 138 | ffb3b1bb765b96eb,195828967.png,199933067.png,204404200.png 139 | ffbd89833d851c18,177210367.png,179979800.png,182882700.png 140 | ffd80f73ab22500b,181514667.png,184417567.png,187053533.png 141 | ffe67ac537febe41,57290567.png,60727333.png,63596867.png 142 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: geofree 2 | channels: 3 | - pytorch 4 | - defaults 5 | - conda-forge 6 | dependencies: 7 | - python=3.8.5 8 | - pip=20.3 9 | - cudatoolkit=10.1 10 | - cudatoolkit-dev=10.1.243 # for nvcc 11 | - gcc_linux-64=7.3.0 # for nvcc compatibility 12 | - gxx_linux-64=7.3.0 # for nvcc compatibility 13 | - pytorch=1.7.1 14 | - torchvision=0.8.2 15 | - numpy=1.19.2 16 | -------------------------------------------------------------------------------- /geofree/__init__.py: -------------------------------------------------------------------------------- 1 | from geofree.util import pretrained_models 2 | -------------------------------------------------------------------------------- /geofree/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/data/__init__.py 
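Note: the conda environment specified in environment.yaml above is typically created with `conda env create -f environment.yaml` and activated with `conda activate geofree`; the cudatoolkit-dev and gcc/gxx pins exist only so that nvcc is available for building CUDA extensions such as the splatting package imported in geofree/modules/warp/midas.py.

The dataset classes defined in geofree/data/acid.py and geofree/data/realestate.py below share the same item layout. A minimal usage sketch, assuming the sparse COLMAP-style data has already been prepared under data/acid_sparse (e.g. via the sparsify scripts in scripts/); the size value is illustrative:

from torch.utils.data import DataLoader
from geofree.data.acid import ACIDCustomTest

# size is (h, w); max_points pads the sparse point set to a fixed length so examples batch
dataset = ACIDCustomTest(size=(208, 368), max_points=16384)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batch = next(iter(loader))
# each example provides images in (-1, 1), sparse source points and relative camera data:
# dst_img, src_img, src_points, K, K_inv, R_rel, t_rel, seq, label, dst_fname, src_fname
print(batch["src_img"].shape, batch["R_rel"].shape, batch["label"])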
-------------------------------------------------------------------------------- /geofree/data/acid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | from geofree.data.realestate import PRNGMixin, load_sparse_model_example, pad_points 6 | 7 | 8 | class ACIDSparseBase(Dataset, PRNGMixin): 9 | def __init__(self): 10 | self.sparse_dir = "data/acid_sparse" 11 | 12 | def __len__(self): 13 | return len(self.sequences) 14 | 15 | def __getitem__(self, index): 16 | seq = self.sequences[index] 17 | root = os.path.join(self.sequence_dir, seq) 18 | frames = sorted([fname for fname in os.listdir(os.path.join(root, "images")) if fname.endswith(".png")]) 19 | segments = self.prng.choice(3, 2, replace=False) 20 | if segments[0] < segments[1]: # forward 21 | if segments[1]-segments[0] == 1: # small 22 | label = 0 23 | else: 24 | label = 1 # large 25 | else: # backward 26 | if segments[1]-segments[0] == 1: # small 27 | label = 2 28 | else: 29 | label = 3 30 | n = len(frames) 31 | dst_indices = list(range(segments[0]*n//3, (segments[0]+1)*n//3)) 32 | src_indices = list(range(segments[1]*n//3, (segments[1]+1)*n//3)) 33 | dst_index = self.prng.choice(dst_indices) 34 | src_index = self.prng.choice(src_indices) 35 | img_dst = frames[dst_index] 36 | img_src = frames[src_index] 37 | 38 | example = load_sparse_model_example( 39 | root=root, img_dst=img_dst, img_src=img_src, size=self.size) 40 | 41 | for k in example: 42 | example[k] = example[k].astype(np.float32) 43 | 44 | example["src_points"] = pad_points(example["src_points"], 45 | self.max_points) 46 | example["seq"] = seq 47 | example["label"] = label 48 | example["dst_fname"] = img_dst 49 | example["src_fname"] = img_src 50 | 51 | return example 52 | 53 | 54 | class ACIDSparseTrain(ACIDSparseBase): 55 | def __init__(self, size=None, max_points=16384): 56 | super().__init__() 57 | self.size = size 58 | self.max_points = max_points 59 | 60 | self.split = "train" 61 | self.sequence_dir = os.path.join(self.sparse_dir, self.split) 62 | with open("data/acid_train_sequences.txt", "r") as f: 63 | self.sequences = f.read().splitlines() 64 | 65 | 66 | class ACIDSparseValidation(ACIDSparseBase): 67 | def __init__(self, size=None, max_points=16384): 68 | super().__init__() 69 | self.size = size 70 | self.max_points = max_points 71 | 72 | self.split = "validation" 73 | self.sequence_dir = os.path.join(self.sparse_dir, self.split) 74 | with open("data/acid_validation_sequences.txt", "r") as f: 75 | self.sequences = f.read().splitlines() 76 | 77 | 78 | class ACIDSparseTest(ACIDSparseBase): 79 | def __init__(self, size=None, max_points=16384): 80 | super().__init__() 81 | self.size = size 82 | self.max_points = max_points 83 | 84 | self.split = "test" 85 | self.sequence_dir = os.path.join(self.sparse_dir, self.split) 86 | with open("data/acid_test_sequences.txt", "r") as f: 87 | self.sequences = f.read().splitlines() 88 | 89 | 90 | class ACIDCustomTest(Dataset): 91 | def __init__(self, size=None, max_points=16384): 92 | self.size = size 93 | self.max_points = max_points 94 | 95 | self.frames_file = "data/acid_custom_frames.txt" 96 | self.sparse_dir = "data/acid_sparse" 97 | self.split = "test" 98 | 99 | with open(self.frames_file, "r") as f: 100 | frames = f.read().splitlines() 101 | 102 | seq_data = dict() 103 | for line in frames: 104 | seq,a,b,c = line.split(",") 105 | assert not seq in seq_data 106 | seq_data[seq] = [a,b,c] 107 | 108 | # sequential 
list of seq, label, dst, src 109 | # where label is used to disambiguate different warping scenarios 110 | # 0: small forward movement 111 | # 1: large forward movement 112 | # 2: small backward movement (reverse of 0) 113 | # 3: large backward movement (reverse of 1) 114 | frame_data = list() 115 | for seq in sorted(seq_data.keys()): 116 | abc = seq_data[seq] 117 | frame_data.append([seq, 0, abc[1], abc[0]]) # b|a 118 | frame_data.append([seq, 1, abc[2], abc[0]]) # c|a 119 | frame_data.append([seq, 2, abc[0], abc[1]]) # a|b 120 | frame_data.append([seq, 3, abc[0], abc[2]]) # a|c 121 | 122 | self.frame_data = frame_data 123 | 124 | def __len__(self): 125 | return len(self.frame_data) 126 | 127 | def __getitem__(self, index): 128 | seq, label, img_dst, img_src = self.frame_data[index] 129 | root = os.path.join(self.sparse_dir, self.split, seq) 130 | 131 | example = load_sparse_model_example( 132 | root=root, img_dst=img_dst, img_src=img_src, size=self.size) 133 | 134 | for k in example: 135 | example[k] = example[k].astype(np.float32) 136 | 137 | example["src_points"] = pad_points(example["src_points"], 138 | self.max_points) 139 | example["seq"] = seq 140 | example["label"] = label 141 | example["dst_fname"] = img_dst 142 | example["src_fname"] = img_src 143 | 144 | return example 145 | -------------------------------------------------------------------------------- /geofree/data/realestate.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import os 3 | import numpy as np 4 | import torch.utils.data as data 5 | from PIL import Image 6 | from torchvision.transforms import Compose, Normalize, Resize, ToTensor 7 | from geofree.data.read_write_model import read_model 8 | 9 | 10 | class PRNGMixin(object): 11 | """Adds a prng property which is a numpy RandomState which gets 12 | reinitialized whenever the pid changes to avoid synchronized sampling 13 | behavior when used in conjunction with multiprocessing.""" 14 | 15 | @property 16 | def prng(self): 17 | currentpid = os.getpid() 18 | if getattr(self, "_initpid", None) != currentpid: 19 | self._initpid = currentpid 20 | self._prng = np.random.RandomState() 21 | return self._prng 22 | 23 | 24 | def load_sparse_model_example(root, img_dst, img_src, size): 25 | """ 26 | Parameters 27 | root folder containing directory sparse with points3D.bin, 28 | images.bin and cameras.bin, and directory images 29 | img_dst filename of image in images to be used as destination 30 | img_src filename of image in images to be used as source 31 | size size to resize image and parameters to. 
If None nothing is 32 | done, otherwise it should be in (h,w) format 33 | Returns 34 | example dictionary containing 35 | dst_img destination image as (h,w,c) array in range (-1,1) 36 | src_img source image as (h,w,c) array in range (-1,1) 37 | src_points sparse set of 3d points for source image as (N,3) array 38 | with (:,:2) being pixel coordinates and (:,2) depth values 39 | K 3x3 camera intrinsics 40 | K_inv inverse of camera intrinsics 41 | R_rel relative rotation mapping from source to destination 42 | coordinate system 43 | t_rel relative translation mapping from source to destination 44 | coordinate system 45 | """ 46 | # load sparse model 47 | model = os.path.join(root, "sparse") 48 | try: 49 | cameras, images, points3D = read_model(path=model, ext=".bin") 50 | except Exception as e: 51 | raise Exception(f"Failed to load sparse model {model}.") from e 52 | 53 | 54 | # load camera parameters and image size 55 | cam = cameras[1] 56 | h = cam.height 57 | w = cam.width 58 | params = cam.params 59 | K = np.array([[params[0], 0.0, params[2]], 60 | [0.0, params[1], params[3]], 61 | [0.0, 0.0, 1.0]]) 62 | 63 | # find keys of desired dst and src images 64 | key_dst = [k for k in images.keys() if images[k].name==img_dst] 65 | assert len(key_dst)==1, (img_dst, key_dst) 66 | key_src = [k for k in images.keys() if images[k].name==img_src] 67 | assert len(key_src)==1, (img_src, key_src) 68 | keys = [key_dst[0], key_src[0]] 69 | 70 | # load extrinsics 71 | Rs = np.stack([images[k].qvec2rotmat() for k in keys]) 72 | ts = np.stack([images[k].tvec for k in keys]) 73 | 74 | # load sparse 3d points to be able to estimate scale 75 | sparse_points = [None, None] 76 | #for i in range(len(keys)): 77 | for i in [1]: # only need it for source 78 | key = keys[i] 79 | xys = images[key].xys 80 | p3D = images[key].point3D_ids 81 | pmask = p3D > 0 82 | # if verbose: print("Found {} 3d points in sparse model.".format(pmask.sum())) 83 | xys = xys[pmask] 84 | p3D = p3D[pmask] 85 | worlds = np.stack([points3D[id_].xyz for id_ in p3D]) # N, 3 86 | # project to current view 87 | worlds = worlds[..., None] # N,3,1 88 | pixels = K[None,...]@(Rs[i][None,...]@worlds+ts[i][None,...,None]) 89 | pixels = pixels.squeeze(-1) # N,3 90 | 91 | # instead of using provided xys, one could also project pixels, ie 92 | # xys ~ pixels[:,:2]/pixels[:,[2]] 93 | points = np.concatenate([xys, pixels[:,[2]]], axis=1) 94 | sparse_points[i] = points 95 | 96 | # code to convert to sparse depth map 97 | # xys = points[:,:2] 98 | # xys = np.round(xys).astype(np.int) 99 | # xys[:,0] = xys[:,0].clip(min=0,max=w-1) 100 | # xys[:,1] = xys[:,1].clip(min=0,max=h-1) 101 | # indices = xys[:,1]*w+xys[:,0] 102 | # flatdm = np.zeros(h*w) 103 | # flatz = pixels[:,2] 104 | # np.put_along_axis(flatdm, indices, flatz, axis=0) 105 | # sparse_dm = flatdm.reshape(h,w) 106 | 107 | # load images 108 | im_root = os.path.join(root, "images") 109 | im_paths = [os.path.join(im_root, images[k].name) for k in keys] 110 | ims = list() 111 | for path in im_paths: 112 | im = Image.open(path) 113 | ims.append(im) 114 | 115 | if size is not None and (size[0] != h or size[1] != w): 116 | # resize 117 | ## K 118 | K[0,:] = K[0,:]*size[1]/w 119 | K[1,:] = K[1,:]*size[0]/h 120 | ## points 121 | points[:,0] = points[:,0]*size[1]/w 122 | points[:,1] = points[:,1]*size[0]/h 123 | ## img 124 | for i in range(len(ims)): 125 | ims[i] = ims[i].resize((size[1],size[0]), 126 | resample=Image.LANCZOS) 127 | 128 | 129 | for i in range(len(ims)): 130 | ims[i] = np.array(ims[i])/127.5-1.0 
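    # note on the relative pose computed below: the COLMAP extrinsics returned by
    # read_model are world-to-camera, i.e. x_cam = R @ x_world + t. Chaining the inverse
    # source transform with the destination transform gives
    #   x_dst = R_dst @ R_src^T @ (x_src - t_src) + t_dst,
    # hence R_rel = R_dst @ R_src^T and t_rel = t_dst - R_rel @ t_src.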
131 | 132 | 133 | # relative camera 134 | R_dst = Rs[0] 135 | t_dst = ts[0] 136 | R_src_inv = Rs[1].transpose(-1,-2) 137 | t_src = ts[1] 138 | R_rel = R_dst@R_src_inv 139 | t_rel = t_dst-R_rel@t_src 140 | K_inv = np.linalg.inv(K) 141 | 142 | # collect results 143 | example = { 144 | "dst_img": ims[0], 145 | "src_img": ims[1], 146 | "src_points": sparse_points[1], 147 | "K": K, 148 | "K_inv": K_inv, 149 | "R_rel": R_rel, 150 | "t_rel": t_rel, 151 | } 152 | 153 | return example 154 | 155 | def pad_points(points, N): 156 | padded = -1*np.ones((N,3), dtype=points.dtype) 157 | padded[:points.shape[0],:] = points 158 | return padded 159 | 160 | class RealEstate10KCustomTest(data.Dataset): 161 | def __init__(self, size=None, max_points=16384): 162 | self.size = size 163 | self.max_points = max_points 164 | 165 | self.frames_file = "data/realestate_custom_frames.txt" 166 | self.sparse_dir = "data/realestate_sparse" 167 | self.split = "test" 168 | 169 | with open(self.frames_file, "r") as f: 170 | frames = f.read().splitlines() 171 | 172 | seq_data = dict() 173 | for line in frames: 174 | seq,a,b,c = line.split(",") 175 | assert not seq in seq_data 176 | seq_data[seq] = [a,b,c] 177 | 178 | # sequential list of seq, label, dst, src 179 | # where label is used to disambiguate different warping scenarios 180 | # 0: small forward movement 181 | # 1: large forward movement 182 | # 2: small backward movement (reverse of 0) 183 | # 3: large backward movement (reverse of 1) 184 | frame_data = list() 185 | for seq in sorted(seq_data.keys()): 186 | abc = seq_data[seq] 187 | frame_data.append([seq, 0, abc[1], abc[0]]) # b|a 188 | frame_data.append([seq, 1, abc[2], abc[0]]) # c|a 189 | frame_data.append([seq, 2, abc[0], abc[1]]) # a|b 190 | frame_data.append([seq, 3, abc[0], abc[2]]) # a|c 191 | 192 | self.frame_data = frame_data 193 | 194 | def __len__(self): 195 | return len(self.frame_data) 196 | 197 | def __getitem__(self, index): 198 | seq, label, img_dst, img_src = self.frame_data[index] 199 | root = os.path.join(self.sparse_dir, self.split, seq) 200 | 201 | example = load_sparse_model_example( 202 | root=root, img_dst=img_dst, img_src=img_src, size=self.size) 203 | 204 | for k in example: 205 | example[k] = example[k].astype(np.float32) 206 | 207 | example["src_points"] = pad_points(example["src_points"], 208 | self.max_points) 209 | example["seq"] = seq 210 | example["label"] = label 211 | example["dst_fname"] = img_dst 212 | example["src_fname"] = img_src 213 | 214 | return example 215 | 216 | 217 | class RealEstate10KSparseTrain(data.Dataset, PRNGMixin): 218 | def __init__(self, size=None, max_points=16384): 219 | self.size = size 220 | self.max_points = max_points 221 | 222 | self.sparse_dir = "data/realestate_sparse" 223 | self.split = "train" 224 | self.sequence_dir = os.path.join(self.sparse_dir, self.split) 225 | with open("data/realestate_train_sequences.txt", "r") as f: 226 | self.sequences = f.read().splitlines() 227 | 228 | def __len__(self): 229 | return len(self.sequences) 230 | 231 | def __getitem__(self, index): 232 | seq = self.sequences[index] 233 | root = os.path.join(self.sequence_dir, seq) 234 | frames = sorted([fname for fname in os.listdir(os.path.join(root, "images")) if fname.endswith(".png")]) 235 | segments = self.prng.choice(3, 2, replace=False) 236 | if segments[0] < segments[1]: # forward 237 | if segments[1]-segments[0] == 1: # small 238 | label = 0 239 | else: 240 | label = 1 # large 241 | else: # backward 242 | if segments[1]-segments[0] == 1: # small 243 | label = 2 244 
| else: 245 | label = 3 246 | n = len(frames) 247 | dst_indices = list(range(segments[0]*n//3, (segments[0]+1)*n//3)) 248 | src_indices = list(range(segments[1]*n//3, (segments[1]+1)*n//3)) 249 | dst_index = self.prng.choice(dst_indices) 250 | src_index = self.prng.choice(src_indices) 251 | img_dst = frames[dst_index] 252 | img_src = frames[src_index] 253 | 254 | example = load_sparse_model_example( 255 | root=root, img_dst=img_dst, img_src=img_src, size=self.size) 256 | 257 | for k in example: 258 | example[k] = example[k].astype(np.float32) 259 | 260 | example["src_points"] = pad_points(example["src_points"], 261 | self.max_points) 262 | example["seq"] = seq 263 | example["label"] = label 264 | example["dst_fname"] = img_dst 265 | example["src_fname"] = img_src 266 | 267 | return example 268 | 269 | 270 | class RealEstate10KSparseCustom(data.Dataset): 271 | def __init__(self, frame_data, split, size=None, max_points=16384): 272 | self.size = size 273 | self.max_points = max_points 274 | 275 | self.sparse_dir = "data/realestate_sparse" 276 | self.split = split 277 | self.frame_data = frame_data 278 | 279 | def __len__(self): 280 | return len(self.frame_data) 281 | 282 | def __getitem__(self, index): 283 | seq, label, img_dst, img_src = self.frame_data[index] 284 | root = os.path.join(self.sparse_dir, self.split, seq) 285 | 286 | example = load_sparse_model_example( 287 | root=root, img_dst=img_dst, img_src=img_src, size=self.size) 288 | 289 | for k in example: 290 | example[k] = example[k].astype(np.float32) 291 | 292 | example["src_points"] = pad_points(example["src_points"], 293 | self.max_points) 294 | example["seq"] = seq 295 | example["label"] = label 296 | example["dst_fname"] = img_dst 297 | example["src_fname"] = img_src 298 | 299 | return example 300 | -------------------------------------------------------------------------------- /geofree/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/examples/__init__.py -------------------------------------------------------------------------------- /geofree/examples/artist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/examples/artist.jpg -------------------------------------------------------------------------------- /geofree/examples/beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/examples/beach.jpg -------------------------------------------------------------------------------- /geofree/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
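        # schedule(n) below yields a learning-rate multiplier: it increases linearly from
        # lr_start to lr_max over the first warm_up_steps steps and then follows half a
        # cosine from lr_max down to lr_min, which is reached at max_decay_steps.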
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /geofree/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/models/__init__.py -------------------------------------------------------------------------------- /geofree/models/transformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/models/transformers/__init__.py -------------------------------------------------------------------------------- /geofree/models/transformers/warpgpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import torch.nn.functional as F 4 | import pytorch_lightning as pl 5 | from torch.optim.lr_scheduler import LambdaLR 6 | 7 | from geofree.main import instantiate_from_config 8 | 9 | 10 | class WarpTransformer(pl.LightningModule): 11 | """This one relies on WarpGPT, where the warper handles the warping to 12 | support warping of positional embeddings, too. 
13 | first stage to encode dst 14 | cond stage to encode src 15 | """ 16 | def __init__(self, 17 | transformer_config, 18 | first_stage_config, 19 | cond_stage_config, 20 | ckpt_path=None, 21 | ignore_keys=[], 22 | first_stage_key="dst_img", 23 | cond_stage_key="src_img", 24 | use_scheduler=False, 25 | scheduler_config=None, 26 | monitor="val/loss", 27 | pkeep=1.0, 28 | plot_cond_stage=False, 29 | log_det_sample=False, 30 | top_k=None 31 | ): 32 | 33 | super().__init__() 34 | if monitor is not None: 35 | self.monitor = monitor 36 | self.log_det_sample = log_det_sample 37 | self.init_first_stage_from_ckpt(first_stage_config) 38 | self.init_cond_stage_from_ckpt(cond_stage_config) 39 | self.transformer = instantiate_from_config(config=transformer_config) 40 | 41 | if ckpt_path is not None: 42 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 43 | self.first_stage_key = first_stage_key 44 | self.cond_stage_key = cond_stage_key 45 | self.pkeep = pkeep 46 | 47 | self.use_scheduler = use_scheduler 48 | if use_scheduler: 49 | assert scheduler_config is not None 50 | self.scheduler_config = scheduler_config 51 | self.plot_cond_stage = plot_cond_stage 52 | self.top_k = top_k if top_k is not None else 100 53 | self.warpkwargs_keys = { 54 | "x": "src_img", 55 | "points": "src_points", 56 | "R": "R_rel", 57 | "t": "t_rel", 58 | "K_dst": "K", 59 | "K_src_inv": "K_inv", 60 | } 61 | 62 | def init_from_ckpt(self, path, ignore_keys=list()): 63 | sd = torch.load(path, map_location="cpu")["state_dict"] 64 | for k in sd.keys(): 65 | for ik in ignore_keys: 66 | if k.startswith(ik): 67 | self.print("Deleting key {} from state_dict.".format(k)) 68 | del sd[k] 69 | missing, unexpected = self.load_state_dict(sd, strict=False) 70 | print(f"Restored from {path} with {len(missing)} missing keys and {len(unexpected)} unexpected keys.") 71 | 72 | def init_first_stage_from_ckpt(self, config): 73 | model = instantiate_from_config(config) 74 | self.first_stage_model = model.eval() 75 | 76 | def init_cond_stage_from_ckpt(self, config): 77 | if config == "__is_first_stage__": 78 | print("Using first stage also as cond stage.") 79 | self.cond_stage_model = self.first_stage_model 80 | else: 81 | model = instantiate_from_config(config) 82 | self.cond_stage_model = model.eval() 83 | 84 | def forward(self, x, c, warpkwargs): 85 | # one step to produce the logits 86 | _, z_indices = self.encode_to_z(x) 87 | _, c_indices = self.encode_to_c(c) 88 | 89 | if self.training and self.pkeep < 1.0: 90 | mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape, 91 | device=z_indices.device)) 92 | mask = mask.round().to(dtype=torch.int64) 93 | r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size) 94 | a_indices = mask*z_indices+(1-mask)*r_indices 95 | else: 96 | a_indices = z_indices 97 | 98 | cz_indices = torch.cat((c_indices, a_indices), dim=1) 99 | 100 | # target includes all sequence elements (no need to handle first one 101 | # differently because we are conditioning) 102 | target = z_indices 103 | # make the prediction 104 | logits, _ = self.transformer(cz_indices[:, :-1], warpkwargs) 105 | # cut off conditioning outputs - output i corresponds to p(z_i | z_{ 3: 183 | # colorize with random projection 184 | assert xrec.shape[1] > 3 185 | x = self.to_rgb(x) 186 | xrec = self.to_rgb(xrec) 187 | log["inputs"] = x 188 | log["reconstructions"] = xrec 189 | return log 190 | 191 | def to_rgb(self, x): 192 | assert self.image_key == "segmentation" 193 | if not hasattr(self, "colorize"): 194 | 
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 195 | x = F.conv2d(x, weight=self.colorize) 196 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 197 | return x 198 | -------------------------------------------------------------------------------- /geofree/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/__init__.py -------------------------------------------------------------------------------- /geofree/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /geofree/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from geofree.modules.losses.vqperceptual import DummyLoss 2 | -------------------------------------------------------------------------------- /geofree/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class DummyLoss(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | -------------------------------------------------------------------------------- /geofree/modules/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/transformer/__init__.py -------------------------------------------------------------------------------- /geofree/modules/transformer/mingpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | credit to: https://github.com/karpathy/minGPT/ 3 | GPT model: 4 | - the initial stem consists of a combination of token encoding and a positional encoding 5 | - the meat of it is a uniform sequence of Transformer blocks 6 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block 7 | - all blocks feed into a central residual pathway similar to resnets 8 | - the final decoder is a linear projection into a vanilla Softmax classifier 9 | """ 10 | 11 | import math 12 | import logging 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | from geofree.main import instantiate_from_config 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class GPTConfig: 24 | """ base GPT config, params common to all GPT versions """ 25 | embd_pdrop = 0.1 26 | resid_pdrop = 0.1 27 | attn_pdrop = 0.1 28 | 29 | def __init__(self, vocab_size, block_size, **kwargs): 30 | self.vocab_size = vocab_size 31 | self.block_size = block_size 32 | for k,v in kwargs.items(): 33 | setattr(self, k, v) 34 | 35 | 36 | class GPT1Config(GPTConfig): 37 | """ GPT-1 like network roughly 125M params """ 38 | n_layer = 12 39 | n_head = 12 40 | n_embd = 768 41 | 42 | 43 | class CausalSelfAttention(nn.Module): 44 | def __init__(self, config): 45 | super().__init__() 46 | assert config.n_embd % config.n_head == 0, f"n_embd is {config.n_embd} but n_head is {config.n_head}." 
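        # the embedding is split across heads, so each head attends over an
        # n_embd // n_head dimensional slice (e.g. 768 // 12 = 64 dims per head)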
47 | # key, query, value projections for all heads 48 | self.key = nn.Linear(config.n_embd, config.n_embd) 49 | self.query = nn.Linear(config.n_embd, config.n_embd) 50 | self.value = nn.Linear(config.n_embd, config.n_embd) 51 | # regularization 52 | self.attn_drop = nn.Dropout(config.attn_pdrop) 53 | self.resid_drop = nn.Dropout(config.resid_pdrop) 54 | # output projection 55 | self.proj = nn.Linear(config.n_embd, config.n_embd) 56 | # causal mask to ensure that attention is only applied to the left in the input sequence 57 | mask = torch.tril(torch.ones(config.block_size, 58 | config.block_size)) 59 | if hasattr(config, "n_unmasked"): 60 | mask[:config.n_unmasked, :config.n_unmasked] = 1 61 | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) 62 | self.n_head = config.n_head 63 | 64 | def forward(self, x, layer_past=None): 65 | B, T, C = x.size() 66 | 67 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 68 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 69 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 70 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 71 | 72 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 73 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 74 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 75 | att = F.softmax(att, dim=-1) 76 | att = self.attn_drop(att) 77 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 78 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 79 | 80 | # output projection 81 | y = self.resid_drop(self.proj(y)) 82 | return y 83 | 84 | 85 | class Block(nn.Module): 86 | """ an unassuming Transformer block """ 87 | def __init__(self, config): 88 | super().__init__() 89 | self.ln1 = nn.LayerNorm(config.n_embd) 90 | self.ln2 = nn.LayerNorm(config.n_embd) 91 | self.attn = CausalSelfAttention(config) 92 | self.mlp = nn.Sequential( 93 | nn.Linear(config.n_embd, 4 * config.n_embd), 94 | nn.GELU(), # nice 95 | nn.Linear(4 * config.n_embd, config.n_embd), 96 | nn.Dropout(config.resid_pdrop), 97 | ) 98 | 99 | def forward(self, x): 100 | x = x + self.attn(self.ln1(x)) 101 | x = x + self.mlp(self.ln2(x)) 102 | return x 103 | 104 | 105 | class GPT(nn.Module): 106 | """ the full GPT language model, with a context size of block_size """ 107 | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, 108 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0, 109 | input_vocab_size=None): 110 | super().__init__() 111 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size, 112 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, 113 | n_layer=n_layer, n_head=n_head, n_embd=n_embd, 114 | n_unmasked=n_unmasked) 115 | # input embedding stem 116 | in_vocab_size = vocab_size if not input_vocab_size else input_vocab_size 117 | self.tok_emb = nn.Embedding(in_vocab_size, config.n_embd) 118 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 119 | self.drop = nn.Dropout(config.embd_pdrop) 120 | # transformer 121 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 122 | # decoder head 123 | self.ln_f = nn.LayerNorm(config.n_embd) 124 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 125 | self.block_size = 
config.block_size 126 | self.apply(self._init_weights) 127 | self.config = config 128 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) 129 | 130 | def get_block_size(self): 131 | return self.block_size 132 | 133 | def _init_weights(self, module): 134 | if isinstance(module, (nn.Linear, nn.Embedding)): 135 | module.weight.data.normal_(mean=0.0, std=0.02) 136 | if isinstance(module, nn.Linear) and module.bias is not None: 137 | module.bias.data.zero_() 138 | elif isinstance(module, nn.LayerNorm): 139 | module.bias.data.zero_() 140 | module.weight.data.fill_(1.0) 141 | 142 | def forward(self, idx, embeddings=None, targets=None, return_layers=False): 143 | # forward the GPT model 144 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 145 | 146 | if embeddings is not None: # prepend explicit embeddings 147 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 148 | 149 | t = token_embeddings.shape[1] 150 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 151 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 152 | x = self.drop(token_embeddings + position_embeddings) 153 | 154 | if return_layers: 155 | layers = [x] 156 | for block in self.blocks: 157 | x = block(x) 158 | layers.append(x) 159 | return layers 160 | 161 | x = self.blocks(x) 162 | x = self.ln_f(x) 163 | logits = self.head(x) 164 | 165 | # if we are given some desired targets also calculate the loss 166 | loss = None 167 | if targets is not None: 168 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 169 | return logits, loss 170 | 171 | 172 | class DummyGPT(nn.Module): 173 | # for debugging 174 | def __init__(self, add_value=1): 175 | super().__init__() 176 | self.add_value = add_value 177 | 178 | def forward(self, idx): 179 | return idx + self.add_value, None 180 | 181 | 182 | class CodeGPT(nn.Module): 183 | """Takes in semi-embeddings""" 184 | def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256, 185 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0, clf_head=True, 186 | init_weights=True, init_last_layer_to_zero=False): 187 | super().__init__() 188 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size, 189 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, 190 | n_layer=n_layer, n_head=n_head, n_embd=n_embd, 191 | n_unmasked=n_unmasked) 192 | # input embedding stem 193 | self.tok_emb = nn.Linear(in_channels, config.n_embd) 194 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 195 | self.drop = nn.Dropout(config.embd_pdrop) 196 | # transformer 197 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 198 | 199 | self.clf_head = clf_head 200 | if self.clf_head: 201 | # decoder head 202 | self.ln_f = nn.LayerNorm(config.n_embd) 203 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 204 | else: 205 | self.head = nn.Linear(config.n_embd, in_channels) 206 | #print(f"Not using classification head of GPT model. 
n_unmasked is {n_unmasked}.") 207 | self.block_size = config.block_size 208 | if init_weights: 209 | self.apply(self._init_weights) 210 | if init_last_layer_to_zero: 211 | print("### WARNING: Initializing CodeGPT-head to zero ###") 212 | self.head.apply(self._init_to_zero) 213 | self.config = config 214 | #print(f"{self.__class__.__name__}: number of parameters: {sum(p.numel() for p in self.parameters())*1.0e-6:.2f} M.") 215 | 216 | def get_block_size(self): 217 | return self.block_size 218 | 219 | def _init_weights(self, module): 220 | if isinstance(module, (nn.Linear, nn.Embedding)): 221 | module.weight.data.normal_(mean=0.0, std=0.02) 222 | if isinstance(module, nn.Linear) and module.bias is not None: 223 | module.bias.data.zero_() 224 | elif isinstance(module, nn.LayerNorm): 225 | module.bias.data.zero_() 226 | module.weight.data.fill_(1.0) 227 | 228 | def _init_to_zero(self, module): 229 | if isinstance(module, (nn.Linear, nn.Embedding)): 230 | module.weight.data.zero_() 231 | if isinstance(module, nn.Linear) and module.bias is not None: 232 | module.bias.data.zero_() 233 | elif isinstance(module, nn.LayerNorm): 234 | module.bias.data.zero_() 235 | module.weight.data.zero_() 236 | 237 | 238 | def forward(self, idx, embeddings=None, targets=None): 239 | # forward the GPT model 240 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 241 | 242 | if embeddings is not None: # prepend explicit embeddings 243 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 244 | 245 | t = token_embeddings.shape[1] 246 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 247 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 248 | x = self.drop(token_embeddings + position_embeddings) 249 | x = self.blocks(x) 250 | 251 | if not self.clf_head: 252 | # maps back to in_channels 253 | return self.head(x) 254 | 255 | x = self.ln_f(x) 256 | logits = self.head(x) 257 | # if we are given some desired targets also calculate the loss 258 | loss = None 259 | if targets is not None: 260 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 261 | return logits, loss 262 | 263 | 264 | class WarpGPT(GPT): 265 | """ the full GPT language model, with a context size of block_size """ 266 | def __init__(self, **kwargs): 267 | assert "n_unmasked" in kwargs and kwargs["n_unmasked"] != 0 268 | warper_config = kwargs.pop("warper_config") 269 | if warper_config.params is None: 270 | warper_config.params = dict() 271 | warper_config.params["n_unmasked"] = kwargs["n_unmasked"] 272 | warper_config.params["block_size"] = kwargs["block_size"] 273 | warper_config.params["n_embd"] = kwargs["n_embd"] 274 | super().__init__(**kwargs) 275 | self.warper = instantiate_from_config(warper_config) 276 | 277 | def forward(self, idx, warpkwargs, embeddings=None, targets=None): 278 | # forward the GPT model 279 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 280 | 281 | if embeddings is not None: # prepend explicit embeddings 282 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 283 | 284 | t = token_embeddings.shape[1] 285 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
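        # unlike GPT.forward, the conditioning (source) prefix of both token and position
        # embeddings is routed through self.warper below, which warps it into the target
        # view; the trailing target-token positions are left unchanged by the warper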
286 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 287 | 288 | token_embeddings, position_embeddings = self.warper(token_embeddings, 289 | position_embeddings, 290 | warpkwargs) 291 | 292 | x = self.drop(token_embeddings + position_embeddings) 293 | x = self.blocks(x) 294 | x = self.ln_f(x) 295 | logits = self.head(x) 296 | 297 | # if we are given some desired targets also calculate the loss 298 | loss = None 299 | if targets is not None: 300 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 301 | return logits, loss 302 | 303 | #### sampling utils 304 | 305 | def top_k_logits(logits, k): 306 | v, ix = torch.topk(logits, k) 307 | out = logits.clone() 308 | out[out < v[:, [-1]]] = -float('Inf') 309 | return out 310 | 311 | @torch.no_grad() 312 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): 313 | """ 314 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 315 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 316 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 317 | of block_size, unlike an RNN that has an infinite context window. 318 | """ 319 | block_size = model.get_block_size() 320 | model.eval() 321 | for k in range(steps): 322 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 323 | logits, _ = model(x_cond) 324 | # pluck the logits at the final step and scale by temperature 325 | logits = logits[:, -1, :] / temperature 326 | # optionally crop probabilities to only the top k options 327 | if top_k is not None: 328 | logits = top_k_logits(logits, top_k) 329 | # apply softmax to convert to probabilities 330 | probs = F.softmax(logits, dim=-1) 331 | # sample from the distribution or take the most likely 332 | if sample: 333 | ix = torch.multinomial(probs, num_samples=1) 334 | else: 335 | _, ix = torch.topk(probs, k=1, dim=-1) 336 | # append to the sequence and continue 337 | x = torch.cat((x, ix), dim=1) 338 | 339 | return x 340 | -------------------------------------------------------------------------------- /geofree/modules/transformer/warper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | from geofree.modules.warp.midas import Midas 6 | 7 | 8 | def disabled_train(self, mode=True): 9 | """Overwrite model.train with this function to make sure train/eval mode 10 | does not change anymore.""" 11 | return self 12 | 13 | 14 | class AbstractWarper(nn.Module): 15 | def __init__(self, *args, **kwargs): 16 | super().__init__() 17 | self._midas = Midas() 18 | self._midas.eval() 19 | self._midas.train = disabled_train 20 | for param in self._midas.parameters(): 21 | param.requires_grad = False 22 | 23 | self.n_unmasked = kwargs["n_unmasked"] # length of conditioning 24 | self.n_embd = kwargs["n_embd"] 25 | self.block_size = kwargs["block_size"] 26 | self.size = kwargs["size"] # h, w tuple 27 | self.start_idx = kwargs.get("start_idx", 0) # hint to not modify parts 28 | 29 | self._use_cache = False 30 | self.new_emb = None # cache 31 | self.new_pos = None # cache 32 | 33 | def set_cache(self, value): 34 | self._use_cache = value 35 | 36 | def get_embeddings(self, token_embeddings, position_embeddings, warpkwargs): 37 | if self._use_cache: 38 | assert not self.training, "Do you really want to use caching during 
training?" 39 | assert self.new_emb is not None 40 | assert self.new_pos is not None 41 | return self.new_emb, self.new_pos 42 | self.new_emb, self.new_pos = self._get_embeddings(token_embeddings, 43 | position_embeddings, 44 | warpkwargs) 45 | return self.new_emb, self.new_pos 46 | 47 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs): 48 | raise NotImplementedError() 49 | 50 | def forward(self, token_embeddings, position_embeddings, warpkwargs): 51 | new_emb, new_pos = self.get_embeddings(token_embeddings, 52 | position_embeddings, 53 | warpkwargs) 54 | 55 | new_emb = torch.cat([new_emb, token_embeddings[:,self.n_unmasked:,:]], 56 | dim=1) 57 | b = new_pos.shape[0] 58 | new_pos = torch.cat([new_pos, position_embeddings[:,self.n_unmasked:,:][b*[0],...]], 59 | dim=1) 60 | 61 | return new_emb, new_pos 62 | 63 | def _to_sequence(self, x): 64 | x = rearrange(x, 'b c h w -> b (h w) c') 65 | return x 66 | 67 | def _to_imglike(self, x): 68 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.size[0]) 69 | return x 70 | 71 | 72 | class AbstractWarperWithCustomEmbedding(AbstractWarper): 73 | def __init__(self, *args, **kwargs): 74 | super().__init__() 75 | self.pos_emb = nn.Parameter(torch.zeros(1, self.block_size, self.n_embd)) 76 | 77 | 78 | class NoSourceWarper(AbstractWarper): 79 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs): 80 | cond_emb = token_embeddings[:,:self.n_unmasked,:] 81 | cond_pos = position_embeddings[:,:self.n_unmasked,:] 82 | 83 | b, seq_length, chn = cond_emb.shape 84 | cond_emb = self._to_imglike(cond_emb) 85 | 86 | cond_pos = self._to_imglike(cond_pos) 87 | cond_pos = cond_pos[b*[0],...] 88 | 89 | new_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True, 90 | boltzmann_factor=0.0, 91 | **warpkwargs) 92 | new_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True, 93 | boltzmann_factor=0.0, 94 | **warpkwargs) 95 | new_emb = self._filter_nans(new_emb) 96 | new_pos = self._filter_nans(new_pos) 97 | 98 | new_emb = self._to_sequence(new_emb) 99 | new_pos = self._to_sequence(new_pos) 100 | return new_emb, new_pos 101 | 102 | def _filter_nans(self, x): 103 | x[torch.isnan(x)] = 0. 104 | return x 105 | 106 | 107 | class ConvWarper(AbstractWarper): 108 | def __init__(self, *args, **kwargs): 109 | super().__init__(*args, **kwargs) 110 | self.emb_conv = nn.Conv2d(2*self.n_embd, self.n_embd, 111 | kernel_size=1, 112 | padding=0, 113 | bias=False) 114 | self.pos_conv = nn.Conv2d(2*self.n_embd, self.n_embd, 115 | kernel_size=1, 116 | padding=0, 117 | bias=False) 118 | 119 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs): 120 | cond_emb = token_embeddings[:,self.start_idx:self.n_unmasked,:] 121 | cond_pos = position_embeddings[:,self.start_idx:self.n_unmasked,:] 122 | 123 | b, seq_length, chn = cond_emb.shape 124 | cond_emb = cond_emb.reshape(b, self.size[0], self.size[1], chn) 125 | cond_emb = cond_emb.permute(0,3,1,2) 126 | 127 | cond_pos = cond_pos.reshape(1, self.size[0], self.size[1], chn) 128 | cond_pos = cond_pos.permute(0,3,1,2) 129 | cond_pos = cond_pos[b*[0],...] 
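        # the positional embedding has batch dimension 1, so it is index-repeated b times
        # above; below, embeddings and positions are warped using depth estimated by the
        # Midas module (gradients disabled) and then fused with their unwarped
        # counterparts by the 1x1 convolutions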
130 | 131 | with torch.no_grad(): 132 | warp_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True, **warpkwargs) 133 | warp_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True, **warpkwargs) 134 | 135 | new_emb = self.emb_conv(torch.cat([cond_emb, warp_emb], dim=1)) 136 | new_pos = self.pos_conv(torch.cat([cond_pos, warp_pos], dim=1)) 137 | 138 | new_emb = new_emb.permute(0,2,3,1) 139 | new_emb = new_emb.reshape(b,seq_length,chn) 140 | 141 | new_pos = new_pos.permute(0,2,3,1) 142 | new_pos = new_pos.reshape(b,seq_length,chn) 143 | 144 | # prepend unmodified ones again 145 | new_emb = torch.cat((token_embeddings[:,:self.start_idx,:], new_emb), 146 | dim=1) 147 | new_pos = torch.cat((position_embeddings[:,:self.start_idx,:][b*[0],...], new_pos), 148 | dim=1) 149 | 150 | return new_emb, new_pos 151 | -------------------------------------------------------------------------------- /geofree/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AbstractEncoder(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def encode(self, *args, **kwargs): 10 | raise NotImplementedError 11 | 12 | 13 | class SOSProvider(AbstractEncoder): 14 | # for unconditional training 15 | def __init__(self, sos_token, quantize_interface=True): 16 | super().__init__() 17 | self.sos_token = sos_token 18 | self.quantize_interface = quantize_interface 19 | 20 | def encode(self, x): 21 | # get batch size from data and replicate sos_token 22 | c = torch.ones(x.shape[0], 1)*self.sos_token 23 | c = c.long().to(x.device) 24 | if self.quantize_interface: 25 | return None, None, [None, None, c] 26 | return c 27 | 28 | 29 | def count_params(model, verbose=False): 30 | total_params = sum(p.numel() for p in model.parameters()) 31 | if verbose: 32 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 33 | return total_params 34 | 35 | 36 | class ActNorm(nn.Module): 37 | def __init__(self, num_features, logdet=False, affine=True, 38 | allow_reverse_init=False): 39 | assert affine 40 | super().__init__() 41 | self.logdet = logdet 42 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 43 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 44 | self.allow_reverse_init = allow_reverse_init 45 | 46 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 47 | 48 | def initialize(self, input): 49 | with torch.no_grad(): 50 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 51 | mean = ( 52 | flatten.mean(1) 53 | .unsqueeze(1) 54 | .unsqueeze(2) 55 | .unsqueeze(3) 56 | .permute(1, 0, 2, 3) 57 | ) 58 | std = ( 59 | flatten.std(1) 60 | .unsqueeze(1) 61 | .unsqueeze(2) 62 | .unsqueeze(3) 63 | .permute(1, 0, 2, 3) 64 | ) 65 | 66 | self.loc.data.copy_(-mean) 67 | self.scale.data.copy_(1 / (std + 1e-6)) 68 | 69 | def forward(self, input, reverse=False): 70 | if reverse: 71 | return self.reverse(input) 72 | if len(input.shape) == 2: 73 | input = input[:,:,None,None] 74 | squeeze = True 75 | else: 76 | squeeze = False 77 | 78 | _, _, height, width = input.shape 79 | 80 | if self.training and self.initialized.item() == 0: 81 | self.initialize(input) 82 | self.initialized.fill_(1) 83 | 84 | h = self.scale * (input + self.loc) 85 | 86 | if squeeze: 87 | h = h.squeeze(-1).squeeze(-1) 88 | 89 | if self.logdet: 90 | log_abs = torch.log(torch.abs(self.scale)) 91 | logdet = height*width*torch.sum(log_abs) 92 | logdet = logdet * 
torch.ones(input.shape[0]).to(input) 93 | return h, logdet 94 | 95 | return h 96 | 97 | def reverse(self, output): 98 | if self.training and self.initialized.item() == 0: 99 | if not self.allow_reverse_init: 100 | raise RuntimeError( 101 | "Initializing ActNorm in reverse direction is " 102 | "disabled by default. Use allow_reverse_init=True to enable." 103 | ) 104 | else: 105 | self.initialize(output) 106 | self.initialized.fill_(1) 107 | 108 | if len(output.shape) == 2: 109 | output = output[:,:,None,None] 110 | squeeze = True 111 | else: 112 | squeeze = False 113 | 114 | h = output / self.scale - self.loc 115 | 116 | if squeeze: 117 | h = h.squeeze(-1).squeeze(-1) 118 | return h 119 | 120 | 121 | class Embedder(nn.Module): 122 | """to replace the convolutional architecture entirely""" 123 | def __init__(self, n_positions, n_channels, n_embed, bias=False): 124 | super().__init__() 125 | self.n_positions = n_positions 126 | self.n_channels = n_channels 127 | self.n_embed = n_embed 128 | self.fc = nn.Linear(self.n_channels, self.n_embed, bias=bias) 129 | 130 | def forward(self, x): 131 | x = x.reshape(x.shape[0], self.n_positions, self.n_channels) 132 | x = self.fc(x) 133 | return x 134 | 135 | 136 | class MultiEmbedder(nn.Module): 137 | def __init__(self, keys, n_positions, n_channels, n_embed, bias=False): 138 | super().__init__() 139 | self.keys = keys 140 | self.n_positions = n_positions 141 | self.n_channels = n_channels 142 | self.n_embed = n_embed 143 | self.fc = nn.Linear(self.n_channels, self.n_embed, bias=bias) 144 | 145 | def forward(self, **kwargs): 146 | values = [kwargs[k] for k in self.keys] 147 | inputs = list() 148 | for k in self.keys: 149 | entry = kwargs[k].reshape(kwargs[k].shape[0], -1, self.n_channels) 150 | inputs.append(entry) 151 | x = torch.cat(inputs, dim=1) 152 | assert x.shape[1] == self.n_positions, x.shape 153 | x = self.fc(x) 154 | return x 155 | 156 | 157 | class SpatialEmbedder(nn.Module): 158 | def __init__(self, keys, n_channels, n_embed, bias=False, shape=[13, 23]): 159 | # here, n_channels = dim(params) 160 | super().__init__() 161 | self.shape = shape 162 | self.keys = keys 163 | self.n_channels = n_channels 164 | self.n_embed = n_embed 165 | self.linear = nn.Conv2d(self.n_channels, self.n_embed, 1, bias=bias) 166 | 167 | def forward(self, **kwargs): 168 | inputs = list() 169 | for k in self.keys: 170 | entry = kwargs[k].reshape(kwargs[k].shape[0], -1, 1, 1) 171 | inputs.append(entry) 172 | x = torch.cat(inputs, dim=1) # b, n_channels, 1, 1 173 | assert x.shape[1] == self.n_channels, f"expecting {self.n_channels} channels but got {x.shape[1]}" 174 | x = x.repeat(1, 1, self.shape[0], self.shape[1]) # duplicate spatially 175 | x = self.linear(x) 176 | return x 177 | 178 | 179 | def to_rgb(model, x): 180 | if not hasattr(model, "colorize"): 181 | model.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 182 | x = nn.functional.conv2d(x, weight=model.colorize) 183 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
184 | return x 185 | -------------------------------------------------------------------------------- /geofree/modules/vqvae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/vqvae/__init__.py -------------------------------------------------------------------------------- /geofree/modules/vqvae/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import einsum 5 | 6 | 7 | class VectorQuantizer(nn.Module): 8 | """ 9 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py 10 | ____________________________________________ 11 | Discretization bottleneck part of the VQ-VAE. 12 | Inputs: 13 | - n_e : number of embeddings 14 | - e_dim : dimension of embedding 15 | - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 16 | _____________________________________________ 17 | """ 18 | 19 | def __init__(self, n_e, e_dim, beta): 20 | super(VectorQuantizer, self).__init__() 21 | self.n_e = n_e 22 | self.e_dim = e_dim 23 | self.beta = beta 24 | 25 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 26 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 27 | 28 | def forward(self, z): 29 | """ 30 | Inputs the output of the encoder network z and maps it to a discrete 31 | one-hot vector that is the index of the closest embedding vector e_j 32 | z (continuous) -> z_q (discrete) 33 | z.shape = (batch, channel, height, width) 34 | quantization pipeline: 35 | 1. get encoder input (B,C,H,W) 36 | 2. 
flatten input to (B*H*W,C) 37 | """ 38 | # reshape z -> (batch, height, width, channel) and flatten 39 | z = z.permute(0, 2, 3, 1).contiguous() 40 | z_flattened = z.view(-1, self.e_dim) 41 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 42 | 43 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 44 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 45 | torch.matmul(z_flattened, self.embedding.weight.t()) 46 | 47 | # find closest encodings 48 | min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) 49 | 50 | min_encodings = torch.zeros( 51 | min_encoding_indices.shape[0], self.n_e).to(z) 52 | min_encodings.scatter_(1, min_encoding_indices, 1) 53 | 54 | # get quantized latent vectors 55 | z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) 56 | 57 | # compute loss for embedding 58 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 59 | torch.mean((z_q - z.detach()) ** 2) 60 | 61 | # preserve gradients 62 | z_q = z + (z_q - z).detach() 63 | 64 | # perplexity 65 | e_mean = torch.mean(min_encodings, dim=0) 66 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) 67 | 68 | # reshape back to match original input shape 69 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 70 | 71 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 72 | 73 | def get_codebook_entry(self, indices, shape): 74 | # shape specifying (batch, height, width, channel) 75 | min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) 76 | min_encodings.scatter_(1, indices[:,None], 1) 77 | 78 | # get quantized latent vectors 79 | z_q = torch.matmul(min_encodings.float(), self.embedding.weight) 80 | 81 | if shape is not None: 82 | z_q = z_q.view(shape) 83 | 84 | # reshape back to match original input shape 85 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 86 | 87 | return z_q 88 | -------------------------------------------------------------------------------- /geofree/modules/warp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/warp/__init__.py -------------------------------------------------------------------------------- /geofree/modules/warp/midas.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from splatting import splatting_function 8 | 9 | # pretty much like eval_re_midas but b,c,h,w format and module which should be 10 | # finetunable 11 | 12 | def render_forward(src_ims, src_dms, 13 | R, t, 14 | K_src_inv, 15 | K_dst, 16 | alpha=None): 17 | # R: b,3,3 18 | # t: b,3 19 | # K_dst: b,3,3 20 | # K_src_inv: b,3,3 21 | t = t[...,None] 22 | 23 | ####### 24 | 25 | assert len(src_ims.shape) == 4 # b,c,h,w 26 | assert len(src_dms.shape) == 3 # b,h,w 27 | assert src_ims.shape[2:4] == src_dms.shape[1:3], (src_ims.shape, 28 | src_dms.shape) 29 | 30 | x = np.arange(src_ims.shape[3]) 31 | y = np.arange(src_ims.shape[2]) 32 | coord = np.stack(np.meshgrid(x,y), -1) 33 | coord = np.concatenate((coord, np.ones_like(coord)[:,:,[0]]), -1) # z=1 34 | coord = coord.astype(np.float32) 35 | coord = torch.as_tensor(coord, dtype=K_dst.dtype, device=K_dst.device) 36 | coord = coord[None] # b,h,w,3 37 | 38 | D = src_dms[:,:,:,None,None] # b,h,w,1,1 39 | 40 | points = 
K_dst[:,None,None,...]@(R[:,None,None,...]@(D*K_src_inv[:,None,None,...]@coord[:,:,:,:,None])+t[:,None,None,:,:]) 41 | points = points.squeeze(-1) 42 | 43 | new_z = points[:,:,:,[2]].clone().permute(0,3,1,2) # b,1,h,w 44 | points = points/torch.clamp(points[:,:,:,[2]], 1e-8, None) 45 | 46 | flow = points - coord 47 | flow = flow.permute(0,3,1,2)[:,:2,...] 48 | 49 | if alpha is not None: 50 | # used to be 50 but this is unstable even if we subtract the maximum 51 | importance = alpha/new_z 52 | #importance = importance-importance.amin((1,2,3),keepdim=True) 53 | importance = importance.exp() 54 | else: 55 | # use heuristic to rescale import between 0 and 10 to be stable in 56 | # float32 57 | importance = 1.0/new_z 58 | importance_min = importance.amin((1,2,3),keepdim=True) 59 | importance_max = importance.amax((1,2,3),keepdim=True) 60 | importance=(importance-importance_min)/(importance_max-importance_min+1e-6)*10-10 61 | importance = importance.exp() 62 | 63 | input_data = torch.cat([importance*src_ims, importance], 1) 64 | output_data = splatting_function("summation", input_data, flow) 65 | 66 | num = output_data[:,:-1,:,:] 67 | nom = output_data[:,-1:,:,:] 68 | 69 | #rendered = num/(nom+1e-7) 70 | rendered = num/nom.clamp(min=1e-8) 71 | return rendered 72 | 73 | 74 | 75 | class Midas(nn.Module): 76 | def __init__(self): 77 | super().__init__() 78 | midas = torch.hub.load("intel-isl/MiDaS", "MiDaS") 79 | 80 | self.midas = midas 81 | 82 | # parameters to reproduce the provided transform 83 | mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32) 84 | mean = mean.reshape(1,3,1,1) 85 | self.register_buffer("mean", mean) 86 | std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32) 87 | std = std.reshape(1,3,1,1) 88 | self.register_buffer("std", std) 89 | self.__height = 384 90 | self.__width = 384 91 | self.__keep_aspect_ratio = True 92 | self.__resize_method = "upper_bound" 93 | self.__multiple_of = 32 94 | 95 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 96 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 97 | 98 | if max_val is not None and y > max_val: 99 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 100 | 101 | if y < min_val: 102 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 103 | 104 | return y 105 | 106 | def get_size(self, width, height): 107 | # determine new height and width 108 | scale_height = self.__height / height 109 | scale_width = self.__width / width 110 | 111 | if self.__keep_aspect_ratio: 112 | if self.__resize_method == "lower_bound": 113 | # scale such that output size is lower bound 114 | if scale_width > scale_height: 115 | # fit width 116 | scale_height = scale_width 117 | else: 118 | # fit height 119 | scale_width = scale_height 120 | elif self.__resize_method == "upper_bound": 121 | # scale such that output size is upper bound 122 | if scale_width < scale_height: 123 | # fit width 124 | scale_height = scale_width 125 | else: 126 | # fit height 127 | scale_width = scale_height 128 | elif self.__resize_method == "minimal": 129 | # scale as least as possbile 130 | if abs(1 - scale_width) < abs(1 - scale_height): 131 | # fit width 132 | scale_height = scale_width 133 | else: 134 | # fit height 135 | scale_width = scale_height 136 | else: 137 | raise ValueError( 138 | f"resize_method {self.__resize_method} not implemented" 139 | ) 140 | 141 | if self.__resize_method == "lower_bound": 142 | new_height = self.constrain_to_multiple_of( 143 | scale_height * 
height, min_val=self.__height 144 | ) 145 | new_width = self.constrain_to_multiple_of( 146 | scale_width * width, min_val=self.__width 147 | ) 148 | elif self.__resize_method == "upper_bound": 149 | new_height = self.constrain_to_multiple_of( 150 | scale_height * height, max_val=self.__height 151 | ) 152 | new_width = self.constrain_to_multiple_of( 153 | scale_width * width, max_val=self.__width 154 | ) 155 | elif self.__resize_method == "minimal": 156 | new_height = self.constrain_to_multiple_of(scale_height * height) 157 | new_width = self.constrain_to_multiple_of(scale_width * width) 158 | else: 159 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 160 | 161 | return (new_width, new_height) 162 | 163 | def resize(self, x): 164 | assert len(x.shape)==4 165 | assert x.shape[1]==3, x.shape 166 | 167 | width, height = self.get_size( 168 | x.shape[3], x.shape[2] 169 | ) 170 | x = torch.nn.functional.interpolate( 171 | x, 172 | size=(height,width), 173 | mode="bicubic", 174 | align_corners=False, 175 | ) 176 | return x 177 | 178 | def __call__(self, x, clamp=True, out_size="original"): 179 | assert len(x.shape)==4, x.shape 180 | assert x.shape[1]==3, x.shape 181 | assert -1.0 <= x.min() <= x.max() <= 1.0 182 | 183 | # replace provided transform by differentiable one supporting batches 184 | if out_size == "original": 185 | out_size = x.shape[2:4] 186 | # to [0,1] 187 | x = (x+1.0)/2.0 188 | # resize 189 | x = self.resize(x) 190 | # normalize (x-mean)/std 191 | x = (x - self.mean)/self.std 192 | # prepare = transpose to (b,c,h,w) 193 | 194 | x = self.midas(x) 195 | 196 | if out_size is not None: 197 | x = torch.nn.functional.interpolate( 198 | x.unsqueeze(1), 199 | size=out_size, 200 | mode="bicubic", 201 | align_corners=False, 202 | ).squeeze(1) 203 | 204 | if clamp: 205 | # negative values due to resizing 206 | x = x.clamp(min=1e-8) 207 | 208 | return x 209 | 210 | def scaled_depth(self, x, points, return_inverse_depth=False): 211 | b,c,h,w = x.shape 212 | assert c==3, c 213 | b_,_,c_ = points.shape 214 | assert b==b_ 215 | assert c_==3 216 | 217 | dm = self(x) 218 | 219 | xys = points[:,:,:2] 220 | xys = xys.round().to(dtype=torch.long) 221 | xys[:,:,0] = xys[:,:,0].clamp(min=0,max=w-1) 222 | xys[:,:,1] = xys[:,:,1].clamp(min=0,max=h-1) 223 | indices = xys[:,:,1]*w+xys[:,:,0] # b,N 224 | flatdm = torch.zeros(b,h*w, dtype=dm.dtype, device=dm.device) 225 | flatz = points[:,:,2] # b,N 226 | flatdm.scatter_(dim=1, index=indices, src=flatz) 227 | sparse_dm = flatdm.reshape(b,h,w) 228 | mask = (sparse_dm<1e-3).to(dtype=sparse_dm.dtype, 229 | device=sparse_dm.device) 230 | 231 | #error = (1-mask)*(dm[i]-(scale*dm+offset))**2 232 | N = (1-mask).sum(dim=(1,2)) # b 233 | m_sparse_dm = (1-mask)*sparse_dm # was mdm 234 | m_sparse_dm[(1-mask)>0] = 1.0/m_sparse_dm[(1-mask)>0] # we align disparity 235 | m_dm = (1-mask)*dm # was mmi 236 | s = ((m_dm*m_sparse_dm).sum(dim=(1,2))-1/N*m_sparse_dm.sum(dim=(1,2))*m_dm.sum(dim=(1,2))) / ((m_dm**2).sum(dim=(1,2))-1/N*(m_dm.sum(dim=(1,2))**2)) 237 | c = 1/N*(m_sparse_dm.sum(dim=(1,2))-s*m_dm.sum(dim=(1,2))) 238 | 239 | scaled_dm = s[:,None,None]*dm + c[:,None,None] 240 | scaled_dm = scaled_dm.clamp(min=1e-8) 241 | if not return_inverse_depth: 242 | scaled_dm[scaled_dm!=0] = 1.0/scaled_dm[scaled_dm!=0] # disparity to depth 243 | 244 | return scaled_dm 245 | 246 | def fixed_scale_depth(self, x, return_inverse_depth=False, scale=[0.18577382, 0.93059154]): 247 | b,c,h,w = x.shape 248 | assert c==3, c 249 | 250 | dm = self(x) 251 | dmmin = 
dm.amin(dim=(1,2), keepdim=True) 252 | dmmax = dm.amax(dim=(1,2), keepdim=True) 253 | scaled_dm = (dm-dmmin)/(dmmax-dmmin)*(scale[1]-scale[0])+scale[0] 254 | 255 | if not return_inverse_depth: 256 | scaled_dm[scaled_dm!=0] = 1.0/scaled_dm[scaled_dm!=0] # disparity to depth 257 | 258 | return scaled_dm 259 | 260 | def warp(self, x, points, R, t, K_src_inv, K_dst): 261 | src_dms = self.scaled_depth(x, points) 262 | wrp = render_forward(src_ims=x, src_dms=src_dms, 263 | R=R, t=t, 264 | K_src_inv=K_src_inv, K_dst=K_dst) 265 | return wrp, src_dms 266 | 267 | def warp_features(self, f, x, points, R, t, K_src_inv, K_dst, 268 | no_depth_grad=False, boltzmann_factor=None): 269 | b,c,h,w = f.shape 270 | 271 | context = torch.no_grad() if no_depth_grad else nullcontext() 272 | with context: 273 | src_dms = self.scaled_depth(x, points) 274 | 275 | # rescale depth map to feature map size 276 | src_dms = torch.nn.functional.interpolate( 277 | src_dms.unsqueeze(1), 278 | size=(h,w), 279 | mode="bicubic", 280 | align_corners=False, 281 | ).squeeze(1) 282 | 283 | 284 | # rescale intrinsics to feature map size 285 | K_dst = K_dst.clone() 286 | K_dst[:,0,:] *= f.shape[3]/x.shape[3] 287 | K_dst[:,1,:] *= f.shape[2]/x.shape[2] 288 | K_src_inv = K_src_inv.clone() 289 | K_src_inv[:,0,0] /= f.shape[3]/x.shape[3] 290 | K_src_inv[:,1,1] /= f.shape[3]/x.shape[3] 291 | 292 | wrp = render_forward(src_ims=f, src_dms=src_dms, 293 | R=R, t=t, 294 | K_src_inv=K_src_inv, K_dst=K_dst, 295 | alpha=boltzmann_factor) 296 | return wrp, src_dms 297 | -------------------------------------------------------------------------------- /geofree/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import requests 3 | from tqdm import tqdm 4 | import torch 5 | import os 6 | from omegaconf import OmegaConf 7 | 8 | from geofree.models.transformers.geogpt import GeoTransformer 9 | 10 | 11 | URL_MAP = { 12 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1", 13 | "re_first_stage": "https://heibox.uni-heidelberg.de/f/6db3e2beebe34c1e8d06/?dl=1", 14 | "re_depth_stage": "https://heibox.uni-heidelberg.de/f/c5e51b91377942d7be18/?dl=1", 15 | "re_impl_depth_config": "https://heibox.uni-heidelberg.de/f/234e2e21a6414690b663/?dl=1", 16 | "re_impl_depth": "https://heibox.uni-heidelberg.de/f/740100d4cdbd46d4bb39/?dl=1", 17 | "re_impl_nodepth_config": "https://heibox.uni-heidelberg.de/f/21989b42cea544bbbb9e/?dl=1", 18 | "re_impl_nodepth": "https://heibox.uni-heidelberg.de/f/b909c3d4ac7143209387/?dl=1", 19 | "ac_first_stage": "https://heibox.uni-heidelberg.de/f/f21118f64fde4134917b/?dl=1", 20 | "ac_depth_stage": "https://heibox.uni-heidelberg.de/f/fd4bd78496884990b75e/?dl=1", 21 | "ac_impl_depth_config": "https://heibox.uni-heidelberg.de/f/6e3081df54e245cd9aaa/?dl=1", 22 | "ac_impl_depth": "https://heibox.uni-heidelberg.de/f/6e3081df54e245cd9aaa/?dl=1", 23 | "ac_impl_nodepth_config": "https://heibox.uni-heidelberg.de/f/0a2749e895784c0099a4/?dl=1", 24 | "ac_impl_nodepth": "https://heibox.uni-heidelberg.de/f/bead9082ed0c425fb6b4/?dl=1", 25 | } 26 | 27 | CKPT_MAP = { 28 | "vgg_lpips": "geofree/lpips/vgg.pth", 29 | "re_first_stage": "geofree/re_first_stage/last.ckpt", 30 | "re_depth_stage": "geofree/re_depth_stage/last.ckpt", 31 | "re_impl_depth_config": "geofree/re_impl_depth/config.yaml", 32 | "re_impl_depth": "geofree/re_impl_depth/last.ckpt", 33 | "re_impl_nodepth_config": "geofree/re_impl_nodepth/config.yaml", 34 | "re_impl_nodepth": 
"geofree/re_impl_nodepth/last.ckpt", 35 | "ac_first_stage": "geofree/ac_first_stage/last.ckpt", 36 | "ac_depth_stage": "geofree/ac_depth_stage/last.ckpt", 37 | "ac_impl_depth_config": "geofree/ac_impl_depth/config.yaml", 38 | "ac_impl_depth": "geofree/ac_impl_depth/last.ckpt", 39 | "ac_impl_nodepth_config": "geofree/ac_impl_nodepth/config.yaml", 40 | "ac_impl_nodepth": "geofree/ac_impl_nodepth/last.ckpt", 41 | } 42 | CACHE = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) 43 | CKPT_MAP = dict((k, os.path.join(CACHE, CKPT_MAP[k])) for k in CKPT_MAP) 44 | 45 | MD5_MAP = { 46 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a", 47 | "re_first_stage": "b8b999aba6b618757329c1f114e9f5a5", 48 | "re_depth_stage": "ab35861b9050476d02fa8f4c8f761d61", 49 | "re_impl_depth_config": "e75df96667a3a6022ca2d5b27515324b", 50 | "re_impl_depth": "144dcdeb1379760d2f1ae763cabcac85", 51 | "re_impl_nodepth_config": "351c976463c4740fc575a7c64a836624", 52 | "re_impl_nodepth": "b6646db26a756b80840aa2b6aca924d8", 53 | "ac_first_stage": "657f683698477be24254b2be85fbced0", 54 | "ac_depth_stage": "e75ef829cf72be4f5f252450bdeb10db", 55 | "ac_impl_depth_config": "79b38a57fe4a195165a79c795492092c", 56 | "ac_impl_depth": "e92673e24acb28977d74a2fc106bcfb1", 57 | "ac_impl_nodepth_config": "22e1a55122cef561597cc5c040d45fba", 58 | "ac_impl_nodepth": "22e1a55122cef561597cc5c040d45fba", 59 | } 60 | 61 | 62 | def download(url, local_path, chunk_size=1024): 63 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 64 | with requests.get(url, stream=True) as r: 65 | total_size = int(r.headers.get("content-length", 0)) 66 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 67 | with open(local_path, "wb") as f: 68 | for data in r.iter_content(chunk_size=chunk_size): 69 | if data: 70 | f.write(data) 71 | pbar.update(chunk_size) 72 | 73 | 74 | def md5_hash(path): 75 | with open(path, "rb") as f: 76 | content = f.read() 77 | return hashlib.md5(content).hexdigest() 78 | 79 | 80 | def get_local_path(name, root=None, check=False): 81 | path = CKPT_MAP[name] 82 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 83 | assert name in URL_MAP, name 84 | print("Downloading {} from {} to {}".format(name, URL_MAP[name], path)) 85 | download(URL_MAP[name], path) 86 | md5 = md5_hash(path) 87 | assert md5 == MD5_MAP[name], md5 88 | return path 89 | 90 | 91 | def pretrained_models(model="re_impl_nodepth"): 92 | prefix = model[:2] 93 | assert prefix in ["re", "ac"], "not implemented" 94 | 95 | config_path = get_local_path(model+"_config") 96 | config = OmegaConf.load(config_path) 97 | config.model.params.first_stage_config.params["ckpt_path"] = get_local_path(f"{prefix}_first_stage") 98 | config.model.params.depth_stage_config.params["ckpt_path"] = get_local_path(f"{prefix}_depth_stage") 99 | 100 | ckpt_path = get_local_path(model) 101 | 102 | model = GeoTransformer(**config.model.params) 103 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 104 | missing, unexpected = model.load_state_dict(sd, strict=False) 105 | model = model.eval() 106 | print(f"Restored model from {ckpt_path}") 107 | return model 108 | -------------------------------------------------------------------------------- /scripts/braindance.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys, os 3 | import argparse 4 | import pygame 5 | import torch 6 | import numpy as np 7 | from PIL import Image 8 | from splatting import 
splatting_function 9 | from torch.utils.data.dataloader import default_collate 10 | from geofree import pretrained_models 11 | import imageio 12 | 13 | torch.set_grad_enabled(False) 14 | 15 | from geofree.modules.warp.midas import Midas 16 | 17 | from tkinter.filedialog import askopenfilename 18 | from tkinter import Tk 19 | 20 | 21 | def to_surface(x, text=None): 22 | if hasattr(x, "detach"): 23 | x = x.detach().cpu().numpy() 24 | x = x.transpose(1,0,2) 25 | x = (x+1.0)*127.5 26 | x = x.clip(0, 255).astype(np.uint8) 27 | if text is not None: 28 | from PIL import ImageDraw, ImageFont 29 | fontsize=22 30 | try: 31 | font = ImageFont.truetype("/usr/share/fonts/liberation/LiberationSans-BoldItalic.ttf", fontsize) 32 | except OSError: 33 | font = ImageFont.load_default() 34 | margin = 8 35 | 36 | x = x.transpose(1,0,2) 37 | pos = (margin, x.shape[0]-fontsize-margin//2) 38 | x = x.astype(np.float) 39 | x[x.shape[0]-fontsize-margin:,:,:] *= 0.5 40 | x = x.astype(np.uint8) 41 | 42 | img = Image.fromarray(x) 43 | ImageDraw.Draw(img).text(pos, f'{text}', (255, 255, 255), font=font) # coordinates, text, color, font 44 | x = np.array(img) 45 | x = x.transpose(1,0,2) 46 | return pygame.surfarray.make_surface(x), x.transpose(1,0,2) 47 | 48 | 49 | def render_forward(src_ims, src_dms, 50 | Rcam, tcam, 51 | K_src, 52 | K_dst): 53 | Rcam = Rcam.to(device=src_ims.device)[None] 54 | tcam = tcam.to(device=src_ims.device)[None] 55 | 56 | R = Rcam 57 | t = tcam[...,None] 58 | K_src_inv = K_src.inverse() 59 | 60 | assert len(src_ims.shape) == 4 61 | assert len(src_dms.shape) == 3 62 | assert src_ims.shape[1:3] == src_dms.shape[1:3], (src_ims.shape, 63 | src_dms.shape) 64 | 65 | x = np.arange(src_ims[0].shape[1]) 66 | y = np.arange(src_ims[0].shape[0]) 67 | coord = np.stack(np.meshgrid(x,y), -1) 68 | coord = np.concatenate((coord, np.ones_like(coord)[:,:,[0]]), -1) # z=1 69 | coord = coord.astype(np.float32) 70 | coord = torch.as_tensor(coord, dtype=K_src.dtype, device=K_src.device) 71 | coord = coord[None] # bs, h, w, 3 72 | 73 | D = src_dms[:,:,:,None,None] 74 | 75 | points = K_dst[None,None,None,...]@(R[:,None,None,...]@(D*K_src_inv[None,None,None,...]@coord[:,:,:,:,None])+t[:,None,None,:,:]) 76 | points = points.squeeze(-1) 77 | 78 | new_z = points[:,:,:,[2]].clone().permute(0,3,1,2) # b,1,h,w 79 | points = points/torch.clamp(points[:,:,:,[2]], 1e-8, None) 80 | 81 | src_ims = src_ims.permute(0,3,1,2) 82 | flow = points - coord 83 | flow = flow.permute(0,3,1,2)[:,:2,...] 84 | 85 | alpha = 0.5 86 | importance = alpha/new_z 87 | importance_min = importance.amin((1,2,3),keepdim=True) 88 | importance_max = importance.amax((1,2,3),keepdim=True) 89 | importance=(importance-importance_min)/(importance_max-importance_min+1e-6)*10-10 90 | importance = importance.exp() 91 | 92 | input_data = torch.cat([importance*src_ims, importance], 1) 93 | output_data = splatting_function("summation", input_data, flow) 94 | 95 | num = torch.sum(output_data[:,:-1,:,:], dim=0, keepdim=True) 96 | nom = torch.sum(output_data[:,-1:,:,:], dim=0, keepdim=True) 97 | 98 | rendered = num/(nom+1e-7) 99 | rendered = rendered.permute(0,2,3,1)[0,...] 
100 | return rendered 101 | 102 | def normalize(x): 103 | return x/np.linalg.norm(x) 104 | 105 | def cosd(x): 106 | return np.cos(np.deg2rad(x)) 107 | 108 | def sind(x): 109 | return np.sin(np.deg2rad(x)) 110 | 111 | def look_to(camera_pos, camera_dir, camera_up): 112 | camera_right = normalize(np.cross(camera_up, camera_dir)) 113 | R = np.zeros((4, 4)) 114 | R[0,0:3] = normalize(camera_right) 115 | R[1,0:3] = normalize(np.cross(camera_dir, camera_right)) 116 | R[2,0:3] = normalize(camera_dir) 117 | R[3,3] = 1 118 | trans_matrix = np.array([[1.0, 0.0, 0.0, -camera_pos[0]], 119 | [0.0, 1.0, 0.0, -camera_pos[1]], 120 | [0.0, 0.0, 1.0, -camera_pos[2]], 121 | [0.0, 0.0, 0.0, 1.0]]) 122 | tmp = R@trans_matrix 123 | return tmp[:3,:3], tmp[:3,3] 124 | 125 | def rotate_around_axis(angle, axis): 126 | axis = normalize(axis) 127 | rotation = np.array([[cosd(angle)+axis[0]**2*(1-cosd(angle)), 128 | axis[0]*axis[1]*(1-cosd(angle))-axis[2]*sind(angle), 129 | axis[0]*axis[2]*(1-cosd(angle))+axis[1]*sind(angle)], 130 | [axis[1]*axis[0]*(1-cosd(angle))+axis[2]*sind(angle), 131 | cosd(angle)+axis[1]**2*(1-cosd(angle)), 132 | axis[1]*axis[2]*(1-cosd(angle))-axis[0]*sind(angle)], 133 | [axis[2]*axis[0]*(1-cosd(angle))-axis[1]*sind(angle), 134 | axis[2]*axis[1]*(1-cosd(angle))+axis[0]*sind(angle), 135 | cosd(angle)+axis[2]**2*(1-cosd(angle))]]) 136 | return rotation 137 | 138 | 139 | class Renderer(object): 140 | def __init__(self, model, device): 141 | self.model = pretrained_models(model=model) 142 | self.model = self.model.to(device=device) 143 | self._active = False 144 | # rough estimates for min and maximum inverse depth values on the 145 | # training datasets 146 | if model.startswith("re"): 147 | self.scale = [0.18577382, 0.93059154] 148 | else: 149 | self.scale = [1e-8, 0.75] 150 | 151 | def init(self, 152 | start_im, 153 | example, 154 | show_R, 155 | show_t): 156 | self._active = True 157 | self.step = 0 158 | 159 | batch = self.batch = default_collate([example]) 160 | batch["R_rel"] = show_R[None,...] 161 | batch["t_rel"] = show_t[None,...] 
162 | 163 | _, cdict, edict = self.model.get_xce(batch) 164 | for k in cdict: 165 | cdict[k] = cdict[k].to(device=self.model.device) 166 | for k in edict: 167 | edict[k] = edict[k].to(device=self.model.device) 168 | 169 | quant_d, quant_c, dc_indices, embeddings = self.model.get_normalized_c( 170 | cdict, edict, fixed_scale=True, scale=self.scale) 171 | 172 | start_im = start_im[None,...].to(self.model.device).permute(0,3,1,2) 173 | quant_c, c_indices = self.model.encode_to_c(c=start_im) 174 | cond_rec = self.model.cond_stage_model.decode(quant_c) 175 | 176 | self.current_im = cond_rec.permute(0,2,3,1)[0] 177 | self.current_sample = c_indices 178 | 179 | self.quant_c = quant_c # to know shape 180 | # for sampling 181 | self.dc_indices = dc_indices 182 | self.embeddings = embeddings 183 | 184 | def __call__(self): 185 | if self.step < self.current_sample.shape[1]: 186 | z_start_indices = self.current_sample[:, :self.step] 187 | temperature=None 188 | top_k=250 189 | callback=None 190 | index_sample = self.model.sample(z_start_indices, self.dc_indices, 191 | steps=1, 192 | temperature=temperature if temperature is not None else 1.0, 193 | sample=True, 194 | top_k=top_k if top_k is not None else 100, 195 | callback=callback if callback is not None else lambda k: None, 196 | embeddings=self.embeddings) 197 | self.current_sample = torch.cat((index_sample, 198 | self.current_sample[:,self.step+1:]), 199 | dim=1) 200 | 201 | sample_dec = self.model.decode_to_img(self.current_sample, 202 | self.quant_c.shape) 203 | self.current_im = sample_dec.permute(0,2,3,1)[0] 204 | self.step += 1 205 | 206 | if self.step >= self.current_sample.shape[1]: 207 | self._active = False 208 | 209 | return self.current_im 210 | 211 | def active(self): 212 | return self._active 213 | 214 | def reconstruct(self, x): 215 | x = x.to(self.model.device).permute(0,3,1,2) 216 | quant_c, c_indices = self.model.encode_to_c(c=x) 217 | x_rec = self.model.cond_stage_model.decode(quant_c) 218 | return x_rec.permute(0,2,3,1) 219 | 220 | 221 | def load_as_example(path, model="re"): 222 | size = [208, 368] 223 | im = Image.open(path) 224 | w,h = im.size 225 | if np.abs(w/h - size[1]/size[0]) > 0.1: 226 | print(f"Center cropping {path} to AR {size[1]/size[0]}") 227 | if w/h < size[1]/size[0]: 228 | # crop h 229 | left = 0 230 | right = w 231 | top = h/2 - size[0]/size[1]*w/2 232 | bottom = h/2 + size[0]/size[1]*w/2 233 | else: 234 | # crop w 235 | top = 0 236 | bottom = h 237 | left = w/2 - size[1]/size[0]*h 238 | right = w/2 + size[1]/size[0]*h 239 | im = im.crop(box=(left, top, right, bottom)) 240 | 241 | im = im.resize((size[1],size[0]), 242 | resample=Image.LANCZOS) 243 | im = np.array(im)/127.5-1.0 244 | im = im.astype(np.float32) 245 | 246 | example = dict() 247 | example["src_img"] = im 248 | if model.startswith("re"): 249 | example["K"] = np.array([[184.0, 0.0, 184.0], 250 | [0.0, 184.0, 104.0], 251 | [0.0, 0.0, 1.0]], dtype=np.float32) 252 | elif model.startswith("ac"): 253 | example["K"] = np.array([[200.0, 0.0, 184.0], 254 | [0.0, 200.0, 104.0], 255 | [0.0, 0.0, 1.0]], dtype=np.float32) 256 | else: 257 | raise NotImplementedError() 258 | example["K_inv"] = np.linalg.inv(example["K"]) 259 | 260 | ## dummy data not used during inference 261 | example["dst_img"] = np.zeros_like(example["src_img"]) 262 | example["src_points"] = np.zeros((1,3), dtype=np.float32) 263 | 264 | return example 265 | 266 | 267 | if __name__ == "__main__": 268 | helptxt = "What's up, BD-maniacs?\n\n"+"\n".join([ 269 | "{: <12} {: <24}".format("key(s)", 
"action"), 270 | "="*37, 271 | "{: <12} {: <24}".format("wasd", "move around"), 272 | "{: <12} {: <24}".format("arrows", "look around"), 273 | "{: <12} {: <24}".format("m", "enable looking with mouse"), 274 | "{: <12} {: <24}".format("space", "render with transformer"), 275 | "{: <12} {: <24}".format("q", "quit"), 276 | ]) 277 | parser = argparse.ArgumentParser(description=helptxt, 278 | formatter_class=argparse.RawTextHelpFormatter) 279 | parser.add_argument('path', type=str, nargs='?', default=None, 280 | help='path to image or directory from which to select ' 281 | 'image. Default example is used if not specified.') 282 | parser.add_argument('--model', choices=["re_impl_nodepth", "re_impl_depth", 283 | "ac_impl_nodepth", "ac_impl_depth"], 284 | default="re_impl_nodepth", 285 | help='pretrained model to use.') 286 | parser.add_argument('--video', type=str, nargs='?', default=None, 287 | help='path to write video recording to. (no recording if unspecified).') 288 | opt = parser.parse_args() 289 | print(helptxt) 290 | 291 | if torch.cuda.is_available(): 292 | device = torch.device("cuda") 293 | else: 294 | print("Warning: Running on CPU---sampling might take a while...") 295 | device = torch.device("cpu") 296 | midas = Midas().eval().to(device) 297 | # init transformer 298 | renderer = Renderer(model=opt.model, device=device) 299 | 300 | if opt.path is None: 301 | try: 302 | import importlib.resources as pkg_resources 303 | except ImportError: 304 | import importlib_resources as pkg_resources 305 | 306 | example_name = "artist.jpg" if opt.model.startswith("re") else "beach.jpg" 307 | with pkg_resources.path("geofree.examples", example_name) as path: 308 | example = load_as_example(path, model=opt.model) 309 | else: 310 | path = opt.path 311 | if not os.path.isfile(path): 312 | Tk().withdraw() 313 | path = askopenfilename(initialdir=sys.argv[1]) 314 | example = load_as_example(path, model=opt.model) 315 | 316 | ims = example["src_img"][None,...] 
317 | K = example["K"] 318 | 319 | # compute depth for preview 320 | dms = [None] 321 | for i in range(ims.shape[0]): 322 | midas_in = torch.tensor(ims[i])[None,...].permute(0,3,1,2).to(device) 323 | scaled_idepth = midas.fixed_scale_depth(midas_in, 324 | return_inverse_depth=True, 325 | scale=renderer.scale) 326 | dms[i] = 1.0/scaled_idepth[0].cpu().numpy() 327 | 328 | # now switch to pytorch 329 | src_ims = torch.tensor(ims, dtype=torch.float32) 330 | src_dms = torch.tensor(dms, dtype=torch.float32) 331 | K = torch.tensor(K, dtype=torch.float32) 332 | 333 | src_ims = src_ims.to(device=device) 334 | src_dms = src_dms.to(device=device) 335 | K = K.to(device=device) 336 | 337 | K_cam = K.clone().detach() 338 | 339 | RENDERING = False 340 | DISPLAY_REC = True 341 | if DISPLAY_REC: 342 | rec_ims = renderer.reconstruct(src_ims) 343 | 344 | # init pygame 345 | b,h,w,c = src_ims.shape 346 | pygame.init() 347 | display = (w, h) 348 | surface = pygame.display.set_mode(display) 349 | clock = pygame.time.Clock() 350 | 351 | # init camera 352 | camera_pos = np.array([0.0, 0.0, 0.0]) 353 | camera_dir = np.array([0.0, 0.0, 1.0]) 354 | camera_up = np.array([0.0, 1.0, 0.0]) 355 | CAM_SPEED = 0.025 356 | CAM_SPEED_YAW = 0.5 357 | CAM_SPEED_PITCH = 0.25 358 | MOUSE_SENSITIVITY = 0.02 359 | USE_MOUSE = False 360 | if opt.model.startswith("ac"): 361 | CAM_SPEED *= 0.1 362 | CAM_SPEED_YAW *= 0.5 363 | CAM_SPEED_PITCH *= 0.5 364 | 365 | if opt.video is not None: 366 | writer = imageio.get_writer(opt.video, fps=40) 367 | 368 | step = 0 369 | step_PHASE = 0 370 | while True: 371 | ######## Boring stuff 372 | clock.tick(40) 373 | for event in pygame.event.get(): 374 | if event.type == pygame.QUIT: 375 | pygame.quit() 376 | quit() 377 | 378 | keys = pygame.key.get_pressed() 379 | if keys[pygame.K_q]: 380 | if opt.video is not None: 381 | writer.close() 382 | pygame.quit() 383 | quit() 384 | 385 | ######### Camera 386 | camera_yaw = 0 387 | camera_pitch = 0 388 | if keys[pygame.K_a]: 389 | camera_pos += CAM_SPEED*normalize(np.cross(camera_dir, camera_up)) 390 | if keys[pygame.K_d]: 391 | camera_pos -= CAM_SPEED*normalize(np.cross(camera_dir, camera_up)) 392 | if keys[pygame.K_w]: 393 | camera_pos += CAM_SPEED*normalize(camera_dir) 394 | if keys[pygame.K_s]: 395 | camera_pos -= CAM_SPEED*normalize(camera_dir) 396 | if keys[pygame.K_PAGEUP]: 397 | camera_pos -= CAM_SPEED*normalize(camera_up) 398 | if keys[pygame.K_PAGEDOWN]: 399 | camera_pos += CAM_SPEED*normalize(camera_up) 400 | 401 | if keys[pygame.K_LEFT]: 402 | camera_yaw += CAM_SPEED_YAW 403 | if keys[pygame.K_RIGHT]: 404 | camera_yaw -= CAM_SPEED_YAW 405 | if keys[pygame.K_UP]: 406 | camera_pitch -= CAM_SPEED_PITCH 407 | if keys[pygame.K_DOWN]: 408 | camera_pitch += CAM_SPEED_PITCH 409 | 410 | if USE_MOUSE: 411 | dx, dy = pygame.mouse.get_rel() 412 | if not RENDERING: 413 | camera_yaw -= MOUSE_SENSITIVITY*dx 414 | camera_pitch += MOUSE_SENSITIVITY*dy 415 | 416 | if keys[pygame.K_PLUS]: 417 | CAM_SPEED += 0.1 418 | print(CAM_SPEED) 419 | if keys[pygame.K_MINUS]: 420 | CAM_SPEED -= 0.1 421 | print(CAM_SPEED) 422 | 423 | if keys[pygame.K_m]: 424 | if not USE_MOUSE: 425 | pygame.mouse.set_visible(False) 426 | pygame.event.set_grab(True) 427 | USE_MOUSE = True 428 | else: 429 | pygame.mouse.set_visible(True) 430 | pygame.event.set_grab(False) 431 | USE_MOUSE = False 432 | 433 | # adjust for yaw and pitch 434 | rotation = np.array([[cosd(-camera_yaw), 0.0, sind(-camera_yaw)], 435 | [0.0, 1.0, 0.0], 436 | [-sind(-camera_yaw), 0.0, cosd(-camera_yaw)]]) 437 | 
camera_dir = rotation@camera_dir 438 | 439 | rotation = rotate_around_axis(camera_pitch, np.cross(camera_dir, 440 | camera_up)) 441 | camera_dir = rotation@camera_dir 442 | 443 | show_R, show_t = look_to(camera_pos, camera_dir, camera_up) # look from pos in direction dir 444 | show_R = torch.as_tensor(show_R, dtype=torch.float32) 445 | show_t = torch.as_tensor(show_t, dtype=torch.float32) 446 | 447 | ############# /Camera 448 | ###### control rendering 449 | if keys[pygame.K_SPACE]: 450 | RENDERING = True 451 | renderer.init(wrp_im, example, show_R, show_t) 452 | 453 | PRESSED = False 454 | if any(keys[k] for k in [pygame.K_a, pygame.K_d, pygame.K_w, 455 | pygame.K_s]): 456 | RENDERING = False 457 | 458 | # display 459 | if not RENDERING: 460 | with torch.no_grad(): 461 | wrp_im = render_forward(src_ims, src_dms, 462 | show_R, show_t, 463 | K_src=K, 464 | K_dst=K_cam) 465 | else: 466 | with torch.no_grad(): 467 | wrp_im = renderer() 468 | 469 | text = "Sampling" if renderer._active else None 470 | image, frame = to_surface(wrp_im, text) 471 | surface.blit(image, (0,0)) 472 | pygame.display.flip() 473 | if opt.video is not None: 474 | writer.append_data(frame) 475 | 476 | step +=1 477 | -------------------------------------------------------------------------------- /scripts/database.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | # This script is based on an original implementation by True Price. 
33 | 34 | import sys 35 | import sqlite3 36 | import numpy as np 37 | 38 | 39 | IS_PYTHON3 = sys.version_info[0] >= 3 40 | 41 | MAX_IMAGE_ID = 2**31 - 1 42 | 43 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( 44 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 45 | model INTEGER NOT NULL, 46 | width INTEGER NOT NULL, 47 | height INTEGER NOT NULL, 48 | params BLOB, 49 | prior_focal_length INTEGER NOT NULL)""" 50 | 51 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( 52 | image_id INTEGER PRIMARY KEY NOT NULL, 53 | rows INTEGER NOT NULL, 54 | cols INTEGER NOT NULL, 55 | data BLOB, 56 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" 57 | 58 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( 59 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 60 | name TEXT NOT NULL UNIQUE, 61 | camera_id INTEGER NOT NULL, 62 | prior_qw REAL, 63 | prior_qx REAL, 64 | prior_qy REAL, 65 | prior_qz REAL, 66 | prior_tx REAL, 67 | prior_ty REAL, 68 | prior_tz REAL, 69 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}), 70 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id)) 71 | """.format(MAX_IMAGE_ID) 72 | 73 | CREATE_TWO_VIEW_GEOMETRIES_TABLE = """ 74 | CREATE TABLE IF NOT EXISTS two_view_geometries ( 75 | pair_id INTEGER PRIMARY KEY NOT NULL, 76 | rows INTEGER NOT NULL, 77 | cols INTEGER NOT NULL, 78 | data BLOB, 79 | config INTEGER NOT NULL, 80 | F BLOB, 81 | E BLOB, 82 | H BLOB) 83 | """ 84 | 85 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( 86 | image_id INTEGER PRIMARY KEY NOT NULL, 87 | rows INTEGER NOT NULL, 88 | cols INTEGER NOT NULL, 89 | data BLOB, 90 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE) 91 | """ 92 | 93 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( 94 | pair_id INTEGER PRIMARY KEY NOT NULL, 95 | rows INTEGER NOT NULL, 96 | cols INTEGER NOT NULL, 97 | data BLOB)""" 98 | 99 | CREATE_NAME_INDEX = \ 100 | "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" 101 | 102 | CREATE_ALL = "; ".join([ 103 | CREATE_CAMERAS_TABLE, 104 | CREATE_IMAGES_TABLE, 105 | CREATE_KEYPOINTS_TABLE, 106 | CREATE_DESCRIPTORS_TABLE, 107 | CREATE_MATCHES_TABLE, 108 | CREATE_TWO_VIEW_GEOMETRIES_TABLE, 109 | CREATE_NAME_INDEX 110 | ]) 111 | 112 | 113 | def image_ids_to_pair_id(image_id1, image_id2): 114 | if image_id1 > image_id2: 115 | image_id1, image_id2 = image_id2, image_id1 116 | return image_id1 * MAX_IMAGE_ID + image_id2 117 | 118 | 119 | def pair_id_to_image_ids(pair_id): 120 | image_id2 = pair_id % MAX_IMAGE_ID 121 | image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID 122 | return image_id1, image_id2 123 | 124 | 125 | def array_to_blob(array): 126 | if IS_PYTHON3: 127 | return array.tostring() 128 | else: 129 | return np.getbuffer(array) 130 | 131 | 132 | def blob_to_array(blob, dtype, shape=(-1,)): 133 | if IS_PYTHON3: 134 | return np.fromstring(blob, dtype=dtype).reshape(*shape) 135 | else: 136 | return np.frombuffer(blob, dtype=dtype).reshape(*shape) 137 | 138 | 139 | class COLMAPDatabase(sqlite3.Connection): 140 | 141 | @staticmethod 142 | def connect(database_path): 143 | return sqlite3.connect(database_path, factory=COLMAPDatabase) 144 | 145 | 146 | def __init__(self, *args, **kwargs): 147 | super(COLMAPDatabase, self).__init__(*args, **kwargs) 148 | 149 | self.create_tables = lambda: self.executescript(CREATE_ALL) 150 | self.create_cameras_table = \ 151 | lambda: self.executescript(CREATE_CAMERAS_TABLE) 152 | 
self.create_descriptors_table = \ 153 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) 154 | self.create_images_table = \ 155 | lambda: self.executescript(CREATE_IMAGES_TABLE) 156 | self.create_two_view_geometries_table = \ 157 | lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE) 158 | self.create_keypoints_table = \ 159 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE) 160 | self.create_matches_table = \ 161 | lambda: self.executescript(CREATE_MATCHES_TABLE) 162 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) 163 | 164 | def add_camera(self, model, width, height, params, 165 | prior_focal_length=False, camera_id=None): 166 | params = np.asarray(params, np.float64) 167 | cursor = self.execute( 168 | "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", 169 | (camera_id, model, width, height, array_to_blob(params), 170 | prior_focal_length)) 171 | return cursor.lastrowid 172 | 173 | def add_image(self, name, camera_id, 174 | prior_q=np.full(4, np.NaN), prior_t=np.full(3, np.NaN), image_id=None): 175 | cursor = self.execute( 176 | "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 177 | (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2], 178 | prior_q[3], prior_t[0], prior_t[1], prior_t[2])) 179 | return cursor.lastrowid 180 | 181 | def add_keypoints(self, image_id, keypoints): 182 | assert(len(keypoints.shape) == 2) 183 | assert(keypoints.shape[1] in [2, 4, 6]) 184 | 185 | keypoints = np.asarray(keypoints, np.float32) 186 | self.execute( 187 | "INSERT INTO keypoints VALUES (?, ?, ?, ?)", 188 | (image_id,) + keypoints.shape + (array_to_blob(keypoints),)) 189 | 190 | def add_descriptors(self, image_id, descriptors): 191 | descriptors = np.ascontiguousarray(descriptors, np.uint8) 192 | self.execute( 193 | "INSERT INTO descriptors VALUES (?, ?, ?, ?)", 194 | (image_id,) + descriptors.shape + (array_to_blob(descriptors),)) 195 | 196 | def add_matches(self, image_id1, image_id2, matches): 197 | assert(len(matches.shape) == 2) 198 | assert(matches.shape[1] == 2) 199 | 200 | if image_id1 > image_id2: 201 | matches = matches[:,::-1] 202 | 203 | pair_id = image_ids_to_pair_id(image_id1, image_id2) 204 | matches = np.asarray(matches, np.uint32) 205 | self.execute( 206 | "INSERT INTO matches VALUES (?, ?, ?, ?)", 207 | (pair_id,) + matches.shape + (array_to_blob(matches),)) 208 | 209 | def add_two_view_geometry(self, image_id1, image_id2, matches, 210 | F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2): 211 | assert(len(matches.shape) == 2) 212 | assert(matches.shape[1] == 2) 213 | 214 | if image_id1 > image_id2: 215 | matches = matches[:,::-1] 216 | 217 | pair_id = image_ids_to_pair_id(image_id1, image_id2) 218 | matches = np.asarray(matches, np.uint32) 219 | F = np.asarray(F, dtype=np.float64) 220 | E = np.asarray(E, dtype=np.float64) 221 | H = np.asarray(H, dtype=np.float64) 222 | self.execute( 223 | "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)", 224 | (pair_id,) + matches.shape + (array_to_blob(matches), config, 225 | array_to_blob(F), array_to_blob(E), array_to_blob(H))) 226 | 227 | 228 | def example_usage(): 229 | import os 230 | import argparse 231 | 232 | parser = argparse.ArgumentParser() 233 | parser.add_argument("--database_path", default="database.db") 234 | args = parser.parse_args() 235 | 236 | if os.path.exists(args.database_path): 237 | print("ERROR: database path already exists -- will not modify it.") 238 | return 239 | 240 | # Open the database. 
241 | 242 | db = COLMAPDatabase.connect(args.database_path) 243 | 244 | # For convenience, try creating all the tables upfront. 245 | 246 | db.create_tables() 247 | 248 | # Create dummy cameras. 249 | 250 | model1, width1, height1, params1 = \ 251 | 0, 1024, 768, np.array((1024., 512., 384.)) 252 | model2, width2, height2, params2 = \ 253 | 2, 1024, 768, np.array((1024., 512., 384., 0.1)) 254 | 255 | camera_id1 = db.add_camera(model1, width1, height1, params1) 256 | camera_id2 = db.add_camera(model2, width2, height2, params2) 257 | 258 | # Create dummy images. 259 | 260 | image_id1 = db.add_image("image1.png", camera_id1) 261 | image_id2 = db.add_image("image2.png", camera_id1) 262 | image_id3 = db.add_image("image3.png", camera_id2) 263 | image_id4 = db.add_image("image4.png", camera_id2) 264 | 265 | # Create dummy keypoints. 266 | # 267 | # Note that COLMAP supports: 268 | # - 2D keypoints: (x, y) 269 | # - 4D keypoints: (x, y, theta, scale) 270 | # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) 271 | 272 | num_keypoints = 1000 273 | keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1) 274 | keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1) 275 | keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2) 276 | keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2) 277 | 278 | db.add_keypoints(image_id1, keypoints1) 279 | db.add_keypoints(image_id2, keypoints2) 280 | db.add_keypoints(image_id3, keypoints3) 281 | db.add_keypoints(image_id4, keypoints4) 282 | 283 | # Create dummy matches. 284 | 285 | M = 50 286 | matches12 = np.random.randint(num_keypoints, size=(M, 2)) 287 | matches23 = np.random.randint(num_keypoints, size=(M, 2)) 288 | matches34 = np.random.randint(num_keypoints, size=(M, 2)) 289 | 290 | db.add_matches(image_id1, image_id2, matches12) 291 | db.add_matches(image_id2, image_id3, matches23) 292 | db.add_matches(image_id3, image_id4, matches34) 293 | 294 | # Commit the data to the file. 295 | 296 | db.commit() 297 | 298 | # Read and check cameras. 299 | 300 | rows = db.execute("SELECT * FROM cameras") 301 | 302 | camera_id, model, width, height, params, prior = next(rows) 303 | params = blob_to_array(params, np.float64) 304 | assert camera_id == camera_id1 305 | assert model == model1 and width == width1 and height == height1 306 | assert np.allclose(params, params1) 307 | 308 | camera_id, model, width, height, params, prior = next(rows) 309 | params = blob_to_array(params, np.float64) 310 | assert camera_id == camera_id2 311 | assert model == model2 and width == width2 and height == height2 312 | assert np.allclose(params, params2) 313 | 314 | # Read and check keypoints. 315 | 316 | keypoints = dict( 317 | (image_id, blob_to_array(data, np.float32, (-1, 2))) 318 | for image_id, data in db.execute( 319 | "SELECT image_id, data FROM keypoints")) 320 | 321 | assert np.allclose(keypoints[image_id1], keypoints1) 322 | assert np.allclose(keypoints[image_id2], keypoints2) 323 | assert np.allclose(keypoints[image_id3], keypoints3) 324 | assert np.allclose(keypoints[image_id4], keypoints4) 325 | 326 | # Read and check matches. 
327 | 328 | pair_ids = [image_ids_to_pair_id(*pair) for pair in 329 | ((image_id1, image_id2), 330 | (image_id2, image_id3), 331 | (image_id3, image_id4))] 332 | 333 | matches = dict( 334 | (pair_id_to_image_ids(pair_id), 335 | blob_to_array(data, np.uint32, (-1, 2))) 336 | for pair_id, data in db.execute("SELECT pair_id, data FROM matches") 337 | ) 338 | 339 | assert np.all(matches[(image_id1, image_id2)] == matches12) 340 | assert np.all(matches[(image_id2, image_id3)] == matches23) 341 | assert np.all(matches[(image_id3, image_id4)] == matches34) 342 | 343 | # Clean up. 344 | 345 | db.close() 346 | 347 | if os.path.exists(args.database_path): 348 | os.remove(args.database_path) 349 | 350 | 351 | if __name__ == "__main__": 352 | example_usage() 353 | 354 | -------------------------------------------------------------------------------- /scripts/download_vqmodels.py: -------------------------------------------------------------------------------- 1 | # convenience function to download pretrained vqmodels and run with provided training configs 2 | import os 3 | from geofree.util import get_local_path 4 | 5 | 6 | SYMLINK_MAP = { 7 | "re_first_stage": "pretrained_models/realestate_first_stage/last.ckpt", 8 | "re_depth_stage": "pretrained_models/realestate_depth_stage/last.ckpt", 9 | "ac_first_stage": "pretrained_models/acid_first_stage/last.ckpt", 10 | "ac_depth_stage": "pretrained_models/acid_depth_stage/last.ckpt", 11 | } 12 | 13 | 14 | def create_symlink(name, path): 15 | print(f"Creating symlink from {path} to {SYMLINK_MAP[name]}") 16 | os.makedirs("/".join(SYMLINK_MAP[name].split(os.sep)[:-1])) 17 | os.symlink(src=path, dst=SYMLINK_MAP[name], target_is_directory=False) 18 | 19 | 20 | if __name__ == "__main__": 21 | for model in SYMLINK_MAP: 22 | path = get_local_path(model) 23 | create_symlink(model, path) 24 | print("done.") 25 | 26 | -------------------------------------------------------------------------------- /scripts/sparse_from_realestate_format.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob, shutil, subprocess 2 | import numpy as np 3 | 4 | 5 | def pose_to_sparse(txt_src, img_src, spa_dst, DEBUG=False, exists_ok=True, 6 | worker_idx=None, world_size=None): 7 | opt = argparse.Namespace(txt_src=txt_src, img_src=img_src, spa_dst=spa_dst) 8 | 9 | assert os.path.exists(opt.txt_src), opt.txt_src 10 | assert os.path.exists(opt.img_src), opt.img_src 11 | 12 | if not os.path.exists(opt.spa_dst): 13 | os.makedirs(opt.spa_dst) 14 | else: 15 | if DEBUG: 16 | shutil.rmtree(opt.spa_dst) 17 | elif not exists_ok: 18 | print("Output directory exists, doing nothing") 19 | return 0 20 | 21 | 22 | txts = sorted(glob.glob(os.path.join(opt.txt_src, "*.txt"))) 23 | if DEBUG: print(txts) 24 | 25 | if worker_idx is not None and world_size is not None: 26 | txts = txts[worker_idx::world_size] 27 | 28 | if DEBUG: txts = txts[:1] 29 | failed = list() 30 | for txt in txts: 31 | vidid = os.path.splitext(os.path.split(txt)[1])[0] 32 | print(f"Processing {vidid}") 33 | 34 | if (os.path.exists(os.path.join(opt.spa_dst, vidid, "sparse", "cameras.bin")) and 35 | os.path.exists(os.path.join(opt.spa_dst, vidid, "sparse", "images.bin")) and 36 | os.path.exists(os.path.join(opt.spa_dst, vidid, "sparse", "points3D.bin"))): 37 | print("Found sparse model, skipping {}".format(vidid)) 38 | continue 39 | 40 | if os.path.exists(os.path.join(opt.spa_dst, vidid)): 41 | shutil.rmtree(os.path.join(opt.spa_dst, vidid)) 42 | print("Found partial output of 
previous run, removed {}".format(vidid)) 43 | 44 | 45 | # read camera poses for this sequence 46 | with open(txt, "r") as f: 47 | firstline = f.readline() 48 | 49 | if firstline.startswith("http"): 50 | if DEBUG: print("Ignoring first line.") 51 | skiprows = 1 52 | else: 53 | skiprows = 0 54 | 55 | vid_data = np.loadtxt(txt, skiprows=skiprows) 56 | if len(vid_data.shape) != 2: 57 | failed.append(vidid) 58 | print(f"Wrong txt format for {vidid}!") 59 | continue 60 | 61 | timestamps = vid_data[:,0].astype(np.int) 62 | if DEBUG: print(timestamps) 63 | filenames = [str(ts)+".png" for ts in timestamps] 64 | 65 | if not len(filenames) > 1: 66 | failed.append(vidid) 67 | print(f"Less than two frames, skipping {vidid}!") 68 | continue 69 | 70 | if not os.path.exists(os.path.join(opt.img_src, vidid)): 71 | failed.append(vidid) 72 | print(f"Could not find frames, skipping {vidid}!") 73 | continue 74 | 75 | if not len(glob.glob(os.path.join(opt.img_src, vidid, "*.png"))) == len(filenames): 76 | failed.append(vidid) 77 | print(f"Could not find all frames, skipping {vidid}!") 78 | continue 79 | 80 | if DEBUG: print(vid_data[0,1:]) 81 | K_params = vid_data[:,1:7] 82 | Ks = np.zeros((K_params.shape[0], 3, 3)) 83 | Ks[:,0,0] = K_params[:,0] 84 | Ks[:,1,1] = K_params[:,1] 85 | Ks[:,0,2] = K_params[:,2] 86 | Ks[:,1,2] = K_params[:,3] 87 | Ks[:,2,2] = 1 88 | assert (Ks[0,...]==Ks[1,...]).all() 89 | K = Ks[0] 90 | if DEBUG: print(K) 91 | 92 | Rts = vid_data[:,7:].reshape(-1, 3, 4) 93 | if DEBUG: print(Rts[0]) 94 | 95 | # given these intrinsics and extrinsics, find a sparse set of scale 96 | # consistent 3d points following 97 | # https://colmap.github.io/faq.html#reconstruct-sparse-dense-model-from-known-camera-poses 98 | 99 | # extract and match features on frames 100 | dst_dir = os.path.join(opt.spa_dst, vidid) 101 | os.makedirs(dst_dir) 102 | database_path = os.path.join(dst_dir, "database.db") 103 | 104 | # symlink images 105 | image_path = os.path.join(dst_dir, "images") 106 | os.symlink(os.path.abspath(os.path.join(opt.img_src, vidid)), image_path) 107 | 108 | cmd = ["colmap", "feature_extractor", 109 | "--database_path", database_path, 110 | "--image_path", image_path, 111 | "--ImageReader.camera_model", "PINHOLE", 112 | "--ImageReader.single_camera", "1", 113 | "--SiftExtraction.use_gpu", "1"] 114 | if DEBUG: print(" ".join(cmd)) 115 | subprocess.run(cmd, check=True) 116 | 117 | # read the database 118 | from database import COLMAPDatabase, blob_to_array, array_to_blob 119 | db = COLMAPDatabase.connect(database_path) 120 | 121 | # read and update camera 122 | ## https://colmap.github.io/cameras.html 123 | cam = db.execute("SELECT * FROM cameras").fetchone() 124 | camera_id = cam[0] 125 | camera_model = cam[1] 126 | assert camera_model == 1 # PINHOLE 127 | width = cam[2] 128 | height = cam[3] 129 | params = blob_to_array(cam[4], dtype=np.float64) 130 | assert len(params) == 4 # fx, fy, cx, cy for PINHOLE 131 | 132 | # adjust params 133 | params[0] = width*K[0,0] 134 | params[1] = height*K[1,1] 135 | params[2] = width*K[0,2] 136 | params[3] = height*K[1,2] 137 | 138 | # update 139 | db.execute("UPDATE cameras SET params = ? 
WHERE camera_id = ?", 140 | (array_to_blob(params), camera_id)) 141 | db.commit() 142 | 143 | # match features 144 | cmd = ["colmap", "sequential_matcher", 145 | "--database_path", database_path, 146 | "--SiftMatching.use_gpu", "1"] 147 | if DEBUG: print(" ".join(cmd)) 148 | subprocess.run(cmd, check=True) 149 | 150 | # triangulate 151 | ## prepare pose model 152 | ### https://colmap.github.io/format.html#text-format 153 | pose_dir = os.path.join(dst_dir, "pose") 154 | os.makedirs(pose_dir) 155 | cameras_txt = os.path.join(pose_dir, "cameras.txt") 156 | with open(cameras_txt, "w") as f: 157 | f.write("{} PINHOLE {} {} {}".format(camera_id, width, height, 158 | " ".join(["{:.2f}".format(p) for p in params]))) 159 | 160 | images_txt = os.path.join(pose_dir, "images.txt") 161 | # match image ids with filenames and export their extrinsics to images.txt 162 | images = db.execute("SELECT image_id, name, camera_id FROM images").fetchall() 163 | lines = list() 164 | for image in images: 165 | assert image[2] == camera_id 166 | image_id = image[0] 167 | image_name = image[1] 168 | image_idx = filenames.index(image_name) 169 | Rt = Rts[image_idx] 170 | R = Rt[:3,:3] 171 | t = Rt[:3,3] 172 | # convert R to quaternion 173 | from scipy.spatial.transform import Rotation 174 | Q = Rotation.from_matrix(R).as_quat() 175 | # from x,y,z,w to w,x,y,z 176 | line = " ".join(["{:.6f}".format(x) for x in [Q[3],Q[0],Q[1],Q[2],t[0],t[1],t[2]]]) 177 | line = "{} ".format(image_id)+line+" {} {}".format(camera_id, image_name) 178 | lines.append(line) 179 | lines.append("") # empty line for 3d points to be triangulated 180 | with open(images_txt, "w") as f: 181 | f.write("\n".join(lines)+"\n") 182 | 183 | # create empty points3D.txt 184 | points3D_txt = os.path.join(pose_dir, "points3D.txt") 185 | open(points3D_txt, "w").close() 186 | 187 | # run point_triangulator 188 | out_dir = os.path.join(dst_dir, "sparse") 189 | os.makedirs(out_dir) 190 | cmd = ["colmap", "point_triangulator", 191 | "--database_path", database_path, 192 | "--image_path", image_path, 193 | "--input_path", pose_dir, 194 | "--output_path", out_dir] 195 | result = subprocess.run(cmd) 196 | if result.returncode != 0: 197 | print(f"Triangulation failed for {vidid}!") 198 | failed.append(vidid) 199 | 200 | print("Failed sequences:") 201 | print("\n".join(failed)) 202 | print(f"Could not create sparse models for {len(failed)} sequences.") 203 | return len(txts) 204 | 205 | 206 | if __name__ == "__main__": 207 | parser = argparse.ArgumentParser(description='Process some integers.') 208 | parser.add_argument('--txt_src', type=str, 209 | help='path to directory containing .txt files of realestate format') 210 | parser.add_argument('--img_src', type=str, 211 | help='path to directory containing /.png frames') 212 | parser.add_argument('--spa_dst', type=str, 213 | help='path to directory to write sparse models into') 214 | parser.add_argument('--DEBUG', action="store_true", 215 | help='for quick development') 216 | parser.add_argument('--worker_idx', type=int, 217 | help='if world_size is specified, should be 0<=worker_idx 1 224 | assert opt.worker_idx is not None 225 | assert 0<=opt.worker_idx=2.0.0', 20 | 'pytorch-lightning>=1.0.8', 21 | 'pygame', 22 | 'splatting @ git+https://github.com/pesser/splatting@1427d7c4204282d117403b35698d489e0324287f#egg=splatting', 23 | 'einops', 24 | 'importlib-resources', 25 | 'imageio', 26 | 'imageio-ffmpeg', 27 | 'test-tube' 28 | ], 29 | ) 30 | --------------------------------------------------------------------------------