├── License.txt
├── README.md
├── assets
│   ├── acid_long.mp4
│   ├── acid_preview.gif
│   ├── acid_short.mp4
│   ├── firstpage.jpg
│   ├── geofree_variants.png
│   ├── realestate_long.mp4
│   ├── realestate_preview.gif
│   ├── realestate_short.mp4
│   └── rooms_scenic_01_wkr.jpg
├── configs
│   ├── acid
│   │   ├── acid_13x23_expl_emb.yaml
│   │   ├── acid_13x23_expl_feat.yaml
│   │   ├── acid_13x23_expl_img.yaml
│   │   ├── acid_13x23_hybrid.yaml
│   │   ├── acid_13x23_impl_catdepth.yaml
│   │   ├── acid_13x23_impl_depth.yaml
│   │   └── acid_13x23_impl_nodepth.yaml
│   └── realestate
│       ├── realestate_13x23_expl_emb.yaml
│       ├── realestate_13x23_expl_feat.yaml
│       ├── realestate_13x23_expl_img.yaml
│       ├── realestate_13x23_hybrid.yaml
│       ├── realestate_13x23_impl_catdepth.yaml
│       ├── realestate_13x23_impl_depth.yaml
│       └── realestate_13x23_impl_nodepth.yaml
├── data
│   ├── acid_custom_frames.txt
│   ├── acid_train_sequences.txt
│   ├── realestate_custom_frames.txt
│   └── realestate_train_sequences.txt
├── environment.yaml
├── geofree
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── acid.py
│   │   ├── read_write_model.py
│   │   └── realestate.py
│   ├── examples
│   │   ├── __init__.py
│   │   ├── artist.jpg
│   │   └── beach.jpg
│   ├── lr_scheduler.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── transformers
│   │   │   ├── __init__.py
│   │   │   ├── geogpt.py
│   │   │   ├── net2net.py
│   │   │   └── warpgpt.py
│   │   └── vqgan.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   └── model.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   └── vqperceptual.py
│   │   ├── transformer
│   │   │   ├── __init__.py
│   │   │   ├── mingpt.py
│   │   │   └── warper.py
│   │   ├── util.py
│   │   ├── vqvae
│   │   │   ├── __init__.py
│   │   │   └── quantize.py
│   │   └── warp
│   │       ├── __init__.py
│   │       └── midas.py
│   └── util.py
├── scripts
│   ├── braindance.ipynb
│   ├── braindance.py
│   ├── database.py
│   ├── download_vqmodels.py
│   ├── sparse_from_realestate_format.py
│   ├── sparsify_acid.sh
│   ├── sparsify_realestate.sh
│   └── strip_ckpt.py
└── setup.py
/License.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 Robin Rombach and Patrick Esser and Björn Ommer
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
19 | OR OTHER DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Geometry-Free View Synthesis: Transformers and no 3D Priors
2 | 
3 |
4 | [**Geometry-Free View Synthesis: Transformers and no 3D Priors**](https://compvis.github.io/geometry-free-view-synthesis/)
5 | [Robin Rombach](https://github.com/rromb)\*,
6 | [Patrick Esser](https://github.com/pesser)\*,
7 | [Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
8 | \* equal contribution
9 |
10 | [arXiv](https://arxiv.org/abs/2104.07652) | [BibTeX](#bibtex) | [Colab](https://colab.research.google.com/github/CompVis/geometry-free-view-synthesis/blob/master/scripts/braindance.ipynb)
11 |
12 | ### Interactive Scene Exploration Results
13 |
14 | [RealEstate10K](https://google.github.io/realestate10k/):
15 | 
16 | Videos: [short (2min)](assets/realestate_short.mp4) / [long (12min)](assets/realestate_long.mp4)
17 |
18 | [ACID](https://infinite-nature.github.io/):
19 | 
20 | Videos: [short (2min)](assets/acid_short.mp4) / [long (9min)](assets/acid_long.mp4)
21 |
22 | ### Demo
23 |
24 | For a quickstart, you can try the [Colab
25 | demo](https://colab.research.google.com/github/CompVis/geometry-free-view-synthesis/blob/master/scripts/braindance.ipynb),
26 | but for a smoother experience we recommend installing the local demo as
27 | described below.
28 |
29 | #### Installation
30 |
31 | The demo requires building a PyTorch extension. If you have a sane development
32 | environment with PyTorch, g++ and nvcc, you can simply
33 |
34 | ```
35 | pip install git+https://github.com/CompVis/geometry-free-view-synthesis#egg=geometry-free-view-synthesis
36 | ```
37 |
38 | If you run into problems and have a GPU with compute capability below 8, you
39 | can also use the provided conda environment:
40 |
41 | ```
42 | git clone https://github.com/CompVis/geometry-free-view-synthesis
43 | conda env create -f geometry-free-view-synthesis/environment.yaml
44 | conda activate geofree
45 | pip install geometry-free-view-synthesis/
46 | ```
47 |
48 | #### Running
49 |
50 | After [installation](#installation), running
51 |
52 | ```
53 | braindance.py
54 | ```
55 |
56 | will start the demo on [a sample scene](http://walledoffhotel.com/rooms.html).
57 | Explore the scene interactively using the `WASD` keys to move and `arrow keys` to
58 | look around. Once positioned, hit the `space bar` to render the novel view with
59 | GeoGPT.
60 |
61 | You can move again with the WASD keys. Mouse control can be activated with the m
62 | key. Run `braindance.py <path to an image or a directory to select from>` to run the
63 | demo on your own images. By default, it uses the `re_impl_nodepth` model (trained on
64 | RealEstate without explicit transformation and without depth input), which can be
65 | changed with the `--model` flag. The corresponding checkpoints will be
66 | downloaded the first time they are required. Specify an output path using
67 | `--video path/to/vid.mp4` to record a video.
68 |
69 | ```
70 | > braindance.py -h
71 | usage: braindance.py [-h] [--model {re_impl_nodepth,re_impl_depth,ac_impl_nodepth,ac_impl_depth}] [--video [VIDEO]] [path]
72 |
73 | What's up, BD-maniacs?
74 |
75 | key(s)       action
76 | =====================================
77 | wasd         move around
78 | arrows       look around
79 | m            enable looking with mouse
80 | space        render with transformer
81 | q            quit
82 |
83 | positional arguments:
84 |   path                  path to image or directory from which to select image. Default example is used if not specified.
85 |
86 | optional arguments:
87 |   -h, --help            show this help message and exit
88 |   --model {re_impl_nodepth,re_impl_depth,ac_impl_nodepth,ac_impl_depth}
89 |                         pretrained model to use.
90 |   --video [VIDEO]       path to write video recording to. (no recording if unspecified).
91 | ```
92 |
93 | ## Training
94 |
95 | ### Data Preparation
96 |
97 | We support training on [RealEstate10K](https://google.github.io/realestate10k/)
98 | and [ACID](https://infinite-nature.github.io/). Both come in the same [format as
99 | described here](https://google.github.io/realestate10k/download.html) and the
100 | preparation is the same for both of them. You will need to have
101 | [`colmap`](https://github.com/colmap/colmap) installed and available on your
102 | `$PATH`.
103 |
104 | We assume that you have extracted the `.txt` files of the dataset you want to
105 | prepare into `$TXT_ROOT`, e.g. for RealEstate:
106 |
107 | ```
108 | > tree $TXT_ROOT
109 | ├── test
110 | │   ├── 000c3ab189999a83.txt
111 | │   ├── ...
112 | │   └── fff9864727c42c80.txt
113 | └── train
114 |     ├── 0000cc6d8b108390.txt
115 |     ├── ...
116 |     └── ffffe622a4de5489.txt
117 | ```
118 |
119 | and that you have downloaded the frames (we downloaded them in resolution `640
120 | x 360`) into `$IMG_ROOT`, e.g. for RealEstate:
121 |
122 | ```
123 | > tree $IMG_ROOT
124 | ├── test
125 | │   ├── 000c3ab189999a83
126 | │   │   ├── 45979267.png
127 | │   │   ├── ...
128 | │   │   └── 55255200.png
129 | │   ├── ...
130 | │   ├── 0017ce4c6a39d122
131 | │   │   ├── 40874000.png
132 | │   │   ├── ...
133 | │   │   └── 48482000.png
134 | ├── train
135 | │   ├── ...
136 | ```
137 |
138 | To prepare the `$SPLIT` split of the dataset (`$SPLIT` being one of `train`,
139 | `test` for RealEstate and `train`, `test`, `validation` for ACID) in
140 | `$SPA_ROOT`, run the following within the `scripts` directory:
141 |
142 | ```
143 | python sparse_from_realestate_format.py --txt_src ${TXT_ROOT}/${SPLIT} --img_src ${IMG_ROOT}/${SPLIT} --spa_dst ${SPA_ROOT}/${SPLIT}
144 | ```
145 |
146 | You can also simply set `TXT_ROOT`, `IMG_ROOT` and `SPA_ROOT` as environment
147 | variables and run `./sparsify_realestate.sh` or `./sparsify_acid.sh`. Take a
148 | look into the sources to run with multiple workers in parallel.
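
If you prefer to drive the preparation from Python instead of the shell scripts, a minimal sketch that simply loops the command above over the splits could look as follows (the parallelization over sequences used in the shell scripts is omitted here; paths are taken from the environment variables described above):

```python
# Minimal sketch: run the preparation command once per split.
# Assumes TXT_ROOT, IMG_ROOT and SPA_ROOT are set as environment variables.
import os
import subprocess

splits = ["train", "test"]  # add "validation" for ACID

for split in splits:
    subprocess.run(
        [
            "python", "sparse_from_realestate_format.py",
            "--txt_src", os.path.join(os.environ["TXT_ROOT"], split),
            "--img_src", os.path.join(os.environ["IMG_ROOT"], split),
            "--spa_dst", os.path.join(os.environ["SPA_ROOT"], split),
        ],
        check=True,
        cwd="scripts",  # the command is meant to be run from within the scripts directory
    )
```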
149 |
150 | Finally, symlink `$SPA_ROOT` to `data/realestate_sparse`/`data/acid_sparse`.
151 |
152 | ### First Stage Models
153 | As described in [our paper](https://arxiv.org/abs/2104.07652), we train the transformer models in
154 | a compressed, discrete latent space of pretrained VQGANs. These pretrained models can be conveniently
155 | downloaded by running
156 | ```
157 | python scripts/download_vqmodels.py
158 | ```
159 | which will also create symlinks ensuring that the paths specified in the training configs (see `configs/*`) exist.
160 | In case some of the models have already been downloaded, the script will only create the symlinks.
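
All training configs follow the same `target`/`params` convention: `target` names a Python class and `params` are its constructor arguments. As an illustration only (the repository ships its own config loader in `geofree/main.py`; the helper below is a sketch and assumes the pretrained checkpoints/symlinks from above are in place), such a block can be resolved like this:

```python
# Illustrative sketch of the target/params pattern used throughout configs/*.
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config):
    """Import config['target'] and call it with config['params'] as kwargs."""
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))

cfg = OmegaConf.load("configs/realestate/realestate_13x23_impl_nodepth.yaml")
model = instantiate_from_config(cfg.model)  # -> geofree.models.transformers.geogpt.GeoTransformer
```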
161 |
162 | For training custom first stage models, we refer to the [taming transformers
163 | repository](https://github.com/CompVis/taming-transformers).
164 |
165 | ### Running the Training
166 | After both the preparation of the data and the first stage models are done,
167 | the experiments on ACID and RealEstate10K as described in our paper can be reproduced by running
168 | ```
169 | python geofree/main.py --base configs/<dataset>/<dataset>_13x23_<experiment>.yaml -t --gpus 0,
170 | ```
171 | where `<dataset>` is one of `realestate`/`acid` and `<experiment>` is one of
172 | `expl_img`/`expl_feat`/`expl_emb`/`impl_catdepth`/`impl_depth`/`impl_nodepth`/`hybrid`.
173 | These abbreviations correspond to the experiments listed in the following table (see `assets/geofree_variants.png` and Fig. 2 in the main paper).
174 |
175 | 
176 |
177 | Note that each experiment was conducted on a GPU with 40 GB VRAM.
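
As a sanity check, the `13x23` in the config names and the `block_size`/`n_unmasked` values can be derived directly from the config parameters themselves; a small sketch with the numbers copied from the configs:

```python
# Where 13x23 comes from: the first stage downsamples once per ch_mult step
# beyond the first (num_down = len(ch_mult) - 1), and training crops are 208x368.
h, w = 208, 368
num_down = len([1, 1, 2, 2, 4]) - 1          # ddconfig.ch_mult -> 4 downsampling steps
th, tw = h // 2**num_down, w // 2**num_down  # 13, 23
n_cond = th * tw                             # 299 conditioning tokens per image

# e.g. the impl_depth / impl_nodepth / hybrid configs:
n_unmasked = 30 + n_cond                     # 30 camera embeddings + 299 merged cond/depth embeddings = 329
block_size = n_unmasked + n_cond - 1         # 627, matching the configs
print(th, tw, n_cond, n_unmasked, block_size)
```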
178 |
179 | ## BibTeX
180 |
181 | ```
182 | @misc{rombach2021geometryfree,
183 | title={Geometry-Free View Synthesis: Transformers and no 3D Priors},
184 | author={Robin Rombach and Patrick Esser and Björn Ommer},
185 | year={2021},
186 | eprint={2104.07652},
187 | archivePrefix={arXiv},
188 | primaryClass={cs.CV}
189 | }
190 | ```
191 |
--------------------------------------------------------------------------------
/assets/acid_long.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/acid_long.mp4
--------------------------------------------------------------------------------
/assets/acid_preview.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/acid_preview.gif
--------------------------------------------------------------------------------
/assets/acid_short.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/acid_short.mp4
--------------------------------------------------------------------------------
/assets/firstpage.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/firstpage.jpg
--------------------------------------------------------------------------------
/assets/geofree_variants.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/geofree_variants.png
--------------------------------------------------------------------------------
/assets/realestate_long.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/realestate_long.mp4
--------------------------------------------------------------------------------
/assets/realestate_preview.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/realestate_preview.gif
--------------------------------------------------------------------------------
/assets/realestate_short.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/realestate_short.mp4
--------------------------------------------------------------------------------
/assets/rooms_scenic_01_wkr.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/assets/rooms_scenic_01_wkr.jpg
--------------------------------------------------------------------------------
/configs/acid/acid_13x23_expl_emb.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.warpgpt.WarpTransformer
4 | params:
5 | plot_cond_stage: True
6 | monitor: "val/loss"
7 |
8 | use_scheduler: True
9 | scheduler_config:
10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
11 | params:
12 | verbosity_interval: 0 # 0 or negative to disable
13 | warm_up_steps: 5000
14 | max_decay_steps: 500001
15 | lr_start: 2.5e-6
16 | lr_max: 1.5e-4
17 | lr_min: 1.0e-8
18 |
19 | transformer_config:
20 | target: geofree.modules.transformer.mingpt.WarpGPT
21 | params:
22 | vocab_size: 16384
23 | block_size: 597 # conditioning + 299 - 1
24 | n_unmasked: 299 # 299 cond embeddings
25 | n_layer: 32
26 | n_head: 16
27 | n_embd: 1024
28 | warper_config:
29 | target: geofree.modules.transformer.warper.ConvWarper
30 | params:
31 | size: [13, 23]
32 |
33 | first_stage_config:
34 | target: geofree.models.vqgan.VQModel
35 | params:
36 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
37 | embed_dim: 256
38 | n_embed: 16384
39 | ddconfig:
40 | double_z: False
41 | z_channels: 256
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
47 | num_res_blocks: 2
48 | attn_resolutions: [ 16 ]
49 | dropout: 0.0
50 | lossconfig:
51 | target: geofree.modules.losses.vqperceptual.DummyLoss
52 |
53 | cond_stage_config: "__is_first_stage__"
54 |
55 | data:
56 | target: geofree.main.DataModuleFromConfig
57 | params:
58 | # bs 8 and accumulate_grad_batches 2 for 34gb vram
59 | batch_size: 8
60 | num_workers: 16
61 | train:
62 | target: geofree.data.acid.ACIDSparseTrain
63 | params:
64 | size:
65 | - 208
66 | - 368
67 |
68 | validation:
69 | target: geofree.data.acid.ACIDCustomTest
70 | params:
71 | size:
72 | - 208
73 | - 368
74 |
75 | lightning:
76 | trainer:
77 | accumulate_grad_batches: 2
78 | benchmark: True
79 |
--------------------------------------------------------------------------------
/configs/acid/acid_13x23_expl_feat.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.net2net.WarpingFeatureTransformer
4 | params:
5 | plot_cond_stage: True
6 | monitor: "val/loss"
7 |
8 | use_scheduler: True
9 | scheduler_config:
10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
11 | params:
12 | verbosity_interval: 0 # 0 or negative to disable
13 | warm_up_steps: 5000
14 | max_decay_steps: 500001
15 | lr_start: 2.5e-6
16 | lr_max: 1.5e-4
17 | lr_min: 1.0e-8
18 |
19 | transformer_config:
20 | target: geofree.modules.transformer.mingpt.GPT
21 | params:
22 | vocab_size: 16384
23 | block_size: 597 # conditioning + 299 - 1
24 | n_unmasked: 299 # 299 cond embeddings
25 | n_layer: 32
26 | n_head: 16
27 | n_embd: 1024
28 |
29 | first_stage_key:
30 | x: "dst_img"
31 |
32 | cond_stage_key:
33 | c: "src_img"
34 | points: "src_points"
35 | R: "R_rel"
36 | t: "t_rel"
37 | K: "K"
38 | K_inv: "K_inv"
39 |
40 | first_stage_config:
41 | target: geofree.models.vqgan.VQModel
42 | params:
43 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
44 | embed_dim: 256
45 | n_embed: 16384
46 | ddconfig:
47 | double_z: False
48 | z_channels: 256
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
54 | num_res_blocks: 2
55 | attn_resolutions: [ 16 ]
56 | dropout: 0.0
57 | lossconfig:
58 | target: geofree.modules.losses.vqperceptual.DummyLoss
59 |
60 | cond_stage_config: "__is_first_stage__"
61 |
62 | data:
63 | target: geofree.main.DataModuleFromConfig
64 | params:
65 | # bs 8 and accumulate_grad_batches 2 for 34gb vram
66 | batch_size: 8
67 | num_workers: 16
68 | train:
69 | target: geofree.data.acid.ACIDSparseTrain
70 | params:
71 | size:
72 | - 208
73 | - 368
74 |
75 | validation:
76 | target: geofree.data.acid.ACIDCustomTest
77 | params:
78 | size:
79 | - 208
80 | - 368
81 |
82 | lightning:
83 | trainer:
84 | accumulate_grad_batches: 2
85 | benchmark: True
86 |
--------------------------------------------------------------------------------
/configs/acid/acid_13x23_expl_img.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.net2net.WarpingTransformer
4 | params:
5 | plot_cond_stage: True
6 | monitor: "val/loss"
7 |
8 | use_scheduler: True
9 | scheduler_config:
10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
11 | params:
12 | verbosity_interval: 0 # 0 or negative to disable
13 | warm_up_steps: 5000
14 | max_decay_steps: 500001
15 | lr_start: 2.5e-6
16 | lr_max: 1.5e-4
17 | lr_min: 1.0e-8
18 |
19 | transformer_config:
20 | target: geofree.modules.transformer.mingpt.GPT
21 | params:
22 | vocab_size: 16384
23 | block_size: 597 # conditioning + 299 - 1
24 | n_unmasked: 299 # 299 cond embeddings
25 | n_layer: 32
26 | n_head: 16
27 | n_embd: 1024
28 |
29 | first_stage_config:
30 | target: geofree.models.vqgan.VQModel
31 | params:
32 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
33 | embed_dim: 256
34 | n_embed: 16384
35 | ddconfig:
36 | double_z: False
37 | z_channels: 256
38 | resolution: 256
39 | in_channels: 3
40 | out_ch: 3
41 | ch: 128
42 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
43 | num_res_blocks: 2
44 | attn_resolutions: [ 16 ]
45 | dropout: 0.0
46 | lossconfig:
47 | target: geofree.modules.losses.vqperceptual.DummyLoss
48 |
49 | data:
50 | target: geofree.main.DataModuleFromConfig
51 | params:
52 | # bs 8 and accumulate_grad_batches 2 for 34gb vram
53 | batch_size: 8
54 | num_workers: 16
55 | train:
56 | target: geofree.data.acid.ACIDSparseTrain
57 | params:
58 | size:
59 | - 208
60 | - 368
61 |
62 | validation:
63 | target: geofree.data.acid.ACIDCustomTest
64 | params:
65 | size:
66 | - 208
67 | - 368
68 |
69 | lightning:
70 | trainer:
71 | accumulate_grad_batches: 2
72 | benchmark: True
73 |
--------------------------------------------------------------------------------
/configs/acid/acid_13x23_hybrid.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.geogpt.WarpGeoTransformer
4 | params:
5 | merge_channels: 512 # channels of cond vq + depth vq
6 |
7 | plot_cond_stage: True
8 | monitor: "val/loss"
9 |
10 | use_scheduler: True
11 | scheduler_config:
12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0 # 0 or negative to disable
15 | warm_up_steps: 5000
16 | max_decay_steps: 500001
17 | lr_start: 2.5e-6
18 | lr_max: 1.5e-4
19 | lr_min: 1.0e-8
20 |
21 | transformer_config:
22 | target: geofree.modules.transformer.mingpt.WarpGPT
23 | params:
24 | vocab_size: 16384
25 | block_size: 627 # conditioning + 299 - 1
26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings
27 | n_layer: 32
28 | n_head: 16
29 | n_embd: 1024
30 | warper_config:
31 | target: geofree.modules.transformer.warper.ConvWarper
32 | params:
33 | start_idx: 30 # do not warp camera embeddings
34 | size: [13, 23]
35 |
36 | first_stage_key:
37 | x: "dst_img"
38 |
39 | cond_stage_key:
40 | c: "src_img"
41 |
42 | emb_stage_key:
43 | points: "src_points"
44 | R: "R_rel"
45 | t: "t_rel"
46 | K: "K"
47 | K_inv: "K_inv"
48 |
49 | first_stage_config:
50 | target: geofree.models.vqgan.VQModel
51 | params:
52 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
53 | embed_dim: 256
54 | n_embed: 16384
55 | ddconfig:
56 | double_z: False
57 | z_channels: 256
58 | resolution: 256
59 | in_channels: 3
60 | out_ch: 3
61 | ch: 128
62 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
63 | num_res_blocks: 2
64 | attn_resolutions: [ 16 ]
65 | dropout: 0.0
66 | lossconfig:
67 | target: geofree.modules.losses.vqperceptual.DummyLoss
68 |
69 | cond_stage_config: "__is_first_stage__"
70 |
71 | depth_stage_config:
72 | target: geofree.models.vqgan.VQModel
73 | params:
74 | ckpt_path: "pretrained_models/acid_depth_stage/last.ckpt"
75 | embed_dim: 256
76 | n_embed: 1024
77 | ddconfig:
78 | double_z: false
79 | z_channels: 256
80 | resolution: 256
81 | in_channels: 1
82 | out_ch: 1
83 | ch: 128
84 | ch_mult:
85 | - 1
86 | - 1
87 | - 2
88 | - 2
89 | - 4
90 | num_res_blocks: 2
91 | attn_resolutions:
92 | - 16
93 | dropout: 0.0
94 | lossconfig:
95 | target: geofree.modules.losses.vqperceptual.DummyLoss
96 |
97 | emb_stage_config:
98 | target: geofree.modules.util.MultiEmbedder
99 | params:
100 | keys:
101 | - "R"
102 | - "t"
103 | - "K"
104 | - "K_inv"
105 | n_positions: 30
106 | n_channels: 1
107 | n_embed: 1024
108 | bias: False
109 |
110 | data:
111 | target: geofree.main.DataModuleFromConfig
112 | params:
113 | # bs 8 and accumulate_grad_batches 2 for 36gb vram
114 | batch_size: 8
115 | num_workers: 16
116 | train:
117 | target: geofree.data.acid.ACIDSparseTrain
118 | params:
119 | size:
120 | - 208
121 | - 368
122 |
123 | validation:
124 | target: geofree.data.acid.ACIDCustomTest
125 | params:
126 | size:
127 | - 208
128 | - 368
129 |
130 | lightning:
131 | trainer:
132 | accumulate_grad_batches: 2
133 | benchmark: True
134 |
--------------------------------------------------------------------------------
/configs/acid/acid_13x23_impl_catdepth.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.geogpt.GeoTransformer
4 | params:
5 | merge_channels: null
6 |
7 | plot_cond_stage: True
8 | monitor: "val/loss"
9 |
10 | use_scheduler: True
11 | scheduler_config:
12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0 # 0 or negative to disable
15 | warm_up_steps: 5000
16 | max_decay_steps: 500001
17 | lr_start: 2.5e-6
18 | lr_max: 1.5e-4
19 | lr_min: 1.0e-8
20 |
21 | transformer_config:
22 | target: geofree.modules.transformer.mingpt.GPT
23 | params:
24 | vocab_size: 16384
25 | block_size: 926 # conditioning + 299 - 1
26 | n_unmasked: 628 # 30 camera embeddings + 299 depth and 299 cond embeddings
27 | n_layer: 32
28 | n_head: 16
29 | n_embd: 1024
30 |
31 | first_stage_key:
32 | x: "dst_img"
33 |
34 | cond_stage_key:
35 | c: "src_img"
36 |
37 | emb_stage_key:
38 | points: "src_points"
39 | R: "R_rel"
40 | t: "t_rel"
41 | K: "K"
42 | K_inv: "K_inv"
43 |
44 | first_stage_config:
45 | target: geofree.models.vqgan.VQModel
46 | params:
47 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
48 | embed_dim: 256
49 | n_embed: 16384
50 | ddconfig:
51 | double_z: False
52 | z_channels: 256
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
58 | num_res_blocks: 2
59 | attn_resolutions: [ 16 ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: geofree.modules.losses.vqperceptual.DummyLoss
63 |
64 | cond_stage_config: "__is_first_stage__"
65 |
66 | depth_stage_config:
67 | target: geofree.models.vqgan.VQModel
68 | params:
69 | ckpt_path: "pretrained_models/acid_depth_stage/last.ckpt"
70 | embed_dim: 256
71 | n_embed: 1024
72 | ddconfig:
73 | double_z: false
74 | z_channels: 256
75 | resolution: 256
76 | in_channels: 1
77 | out_ch: 1
78 | ch: 128
79 | ch_mult:
80 | - 1
81 | - 1
82 | - 2
83 | - 2
84 | - 4
85 | num_res_blocks: 2
86 | attn_resolutions:
87 | - 16
88 | dropout: 0.0
89 | lossconfig:
90 | target: geofree.modules.losses.vqperceptual.DummyLoss
91 |
92 | emb_stage_config:
93 | target: geofree.modules.util.MultiEmbedder
94 | params:
95 | keys:
96 | - "R"
97 | - "t"
98 | - "K"
99 | - "K_inv"
100 | n_positions: 30
101 | n_channels: 1
102 | n_embed: 1024
103 | bias: False
104 |
105 | data:
106 | target: geofree.main.DataModuleFromConfig
107 | params:
108 | # bs 4 and accumulate_grad_batches 4 for 34gb vram
109 | batch_size: 4
110 | num_workers: 8
111 | train:
112 | target: geofree.data.acid.ACIDSparseTrain
113 | params:
114 | size:
115 | - 208
116 | - 368
117 |
118 | validation:
119 | target: geofree.data.acid.ACIDCustomTest
120 | params:
121 | size:
122 | - 208
123 | - 368
124 |
125 | lightning:
126 | trainer:
127 | accumulate_grad_batches: 4
128 | benchmark: True
129 |
--------------------------------------------------------------------------------
/configs/acid/acid_13x23_impl_depth.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.geogpt.GeoTransformer
4 | params:
5 | merge_channels: 512 # channels of cond vq + depth vq
6 |
7 | plot_cond_stage: True
8 | monitor: "val/loss"
9 |
10 | use_scheduler: True
11 | scheduler_config:
12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0 # 0 or negative to disable
15 | warm_up_steps: 5000
16 | max_decay_steps: 500001
17 | lr_start: 2.5e-6
18 | lr_max: 1.5e-4
19 | lr_min: 1.0e-8
20 |
21 | transformer_config:
22 | target: geofree.modules.transformer.mingpt.GPT
23 | params:
24 | vocab_size: 16384
25 | block_size: 627 # conditioning + 299 - 1
26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings
27 | n_layer: 32
28 | n_head: 16
29 | n_embd: 1024
30 |
31 | first_stage_key:
32 | x: "dst_img"
33 |
34 | cond_stage_key:
35 | c: "src_img"
36 |
37 | emb_stage_key:
38 | points: "src_points"
39 | R: "R_rel"
40 | t: "t_rel"
41 | K: "K"
42 | K_inv: "K_inv"
43 |
44 | first_stage_config:
45 | target: geofree.models.vqgan.VQModel
46 | params:
47 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
48 | embed_dim: 256
49 | n_embed: 16384
50 | ddconfig:
51 | double_z: False
52 | z_channels: 256
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
58 | num_res_blocks: 2
59 | attn_resolutions: [ 16 ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: geofree.modules.losses.vqperceptual.DummyLoss
63 |
64 | cond_stage_config: "__is_first_stage__"
65 |
66 | depth_stage_config:
67 | target: geofree.models.vqgan.VQModel
68 | params:
69 | ckpt_path: "pretrained_models/acid_depth_stage/last.ckpt"
70 | embed_dim: 256
71 | n_embed: 1024
72 | ddconfig:
73 | double_z: false
74 | z_channels: 256
75 | resolution: 256
76 | in_channels: 1
77 | out_ch: 1
78 | ch: 128
79 | ch_mult:
80 | - 1
81 | - 1
82 | - 2
83 | - 2
84 | - 4
85 | num_res_blocks: 2
86 | attn_resolutions:
87 | - 16
88 | dropout: 0.0
89 | lossconfig:
90 | target: geofree.modules.losses.vqperceptual.DummyLoss
91 |
92 | emb_stage_config:
93 | target: geofree.modules.util.MultiEmbedder
94 | params:
95 | keys:
96 | - "R"
97 | - "t"
98 | - "K"
99 | - "K_inv"
100 | n_positions: 30
101 | n_channels: 1
102 | n_embed: 1024
103 | bias: False
104 |
105 | data:
106 | target: geofree.main.DataModuleFromConfig
107 | params:
108 | # bs 8 and accumulate_grad_batches 2 for 36gb vram
109 | batch_size: 8
110 | num_workers: 16
111 | train:
112 | target: geofree.data.acid.ACIDSparseTrain
113 | params:
114 | size:
115 | - 208
116 | - 368
117 |
118 | validation:
119 | target: geofree.data.acid.ACIDCustomTest
120 | params:
121 | size:
122 | - 208
123 | - 368
124 |
125 | lightning:
126 | trainer:
127 | accumulate_grad_batches: 2
128 | benchmark: True
129 |
--------------------------------------------------------------------------------
/configs/acid/acid_13x23_impl_nodepth.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.geogpt.GeoTransformer
4 | params:
5 | use_depth: False # depth is not provided to transformer but only used to rescale t
6 |
7 | plot_cond_stage: True
8 | monitor: "val/loss"
9 |
10 | use_scheduler: True
11 | scheduler_config:
12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0 # 0 or negative to disable
15 | warm_up_steps: 5000
16 | max_decay_steps: 500001
17 | lr_start: 2.5e-6
18 | lr_max: 1.5e-4
19 | lr_min: 1.0e-8
20 |
21 | transformer_config:
22 | target: geofree.modules.transformer.mingpt.GPT
23 | params:
24 | vocab_size: 16384
25 | block_size: 627 # conditioning + 299 - 1
26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings
27 | n_layer: 32
28 | n_head: 16
29 | n_embd: 1024
30 |
31 | first_stage_key:
32 | x: "dst_img"
33 |
34 | cond_stage_key:
35 | c: "src_img"
36 |
37 | emb_stage_key:
38 | points: "src_points"
39 | R: "R_rel"
40 | t: "t_rel"
41 | K: "K"
42 | K_inv: "K_inv"
43 |
44 | first_stage_config:
45 | target: geofree.models.vqgan.VQModel
46 | params:
47 | ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
48 | embed_dim: 256
49 | n_embed: 16384
50 | ddconfig:
51 | double_z: False
52 | z_channels: 256
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
58 | num_res_blocks: 2
59 | attn_resolutions: [ 16 ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: geofree.modules.losses.vqperceptual.DummyLoss
63 |
64 | cond_stage_config: "__is_first_stage__"
65 |
66 | depth_stage_config:
67 | target: geofree.models.vqgan.VQModel
68 | params:
69 | ckpt_path: "pretrained_models/acid_depth_stage/last.ckpt"
70 | embed_dim: 256
71 | n_embed: 1024
72 | ddconfig:
73 | double_z: false
74 | z_channels: 256
75 | resolution: 256
76 | in_channels: 1
77 | out_ch: 1
78 | ch: 128
79 | ch_mult:
80 | - 1
81 | - 1
82 | - 2
83 | - 2
84 | - 4
85 | num_res_blocks: 2
86 | attn_resolutions:
87 | - 16
88 | dropout: 0.0
89 | lossconfig:
90 | target: geofree.modules.losses.vqperceptual.DummyLoss
91 |
92 | emb_stage_config:
93 | target: geofree.modules.util.MultiEmbedder
94 | params:
95 | keys:
96 | - "R"
97 | - "t"
98 | - "K"
99 | - "K_inv"
100 | n_positions: 30
101 | n_channels: 1
102 | n_embed: 1024
103 | bias: False
104 |
105 | data:
106 | target: geofree.main.DataModuleFromConfig
107 | params:
108 | # bs 8 and accumulate_grad_batches 2 for 36gb vram
109 | batch_size: 8
110 | num_workers: 16
111 | train:
112 | target: geofree.data.acid.ACIDSparseTrain
113 | params:
114 | size:
115 | - 208
116 | - 368
117 |
118 | validation:
119 | target: geofree.data.acid.ACIDCustomTest
120 | params:
121 | size:
122 | - 208
123 | - 368
124 |
125 | lightning:
126 | trainer:
127 | accumulate_grad_batches: 2
128 | benchmark: True
129 |
--------------------------------------------------------------------------------
/configs/realestate/realestate_13x23_expl_emb.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.warpgpt.WarpTransformer
4 | params:
5 | plot_cond_stage: True
6 | monitor: "val/loss"
7 |
8 | use_scheduler: True
9 | scheduler_config:
10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
11 | params:
12 | verbosity_interval: 0 # 0 or negative to disable
13 | warm_up_steps: 5000
14 | max_decay_steps: 500001
15 | lr_start: 2.5e-6
16 | lr_max: 1.5e-4
17 | lr_min: 1.0e-8
18 |
19 | transformer_config:
20 | target: geofree.modules.transformer.mingpt.WarpGPT
21 | params:
22 | vocab_size: 16384
23 | block_size: 597 # conditioning + 299 - 1
24 | n_unmasked: 299 # 299 cond embeddings
25 | n_layer: 32
26 | n_head: 16
27 | n_embd: 1024
28 | warper_config:
29 | target: geofree.modules.transformer.warper.ConvWarper
30 | params:
31 | size: [13, 23]
32 |
33 | first_stage_config:
34 | target: geofree.models.vqgan.VQModel
35 | params:
36 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt"
37 | embed_dim: 256
38 | n_embed: 16384
39 | ddconfig:
40 | double_z: False
41 | z_channels: 256
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
47 | num_res_blocks: 2
48 | attn_resolutions: [ 16 ]
49 | dropout: 0.0
50 | lossconfig:
51 | target: geofree.modules.losses.vqperceptual.DummyLoss
52 |
53 | cond_stage_config: "__is_first_stage__"
54 |
55 | data:
56 | target: geofree.main.DataModuleFromConfig
57 | params:
58 | # bs 8 and accumulate_grad_batches 2 for 34gb vram
59 | batch_size: 8
60 | num_workers: 16
61 | train:
62 | target: geofree.data.realestate.RealEstate10KSparseTrain
63 | params:
64 | size:
65 | - 208
66 | - 368
67 |
68 | validation:
69 | target: geofree.data.realestate.RealEstate10KCustomTest
70 | params:
71 | size:
72 | - 208
73 | - 368
74 |
75 | lightning:
76 | trainer:
77 | accumulate_grad_batches: 2
78 | benchmark: True
79 |
--------------------------------------------------------------------------------
/configs/realestate/realestate_13x23_expl_feat.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.net2net.WarpingFeatureTransformer
4 | params:
5 | plot_cond_stage: True
6 | monitor: "val/loss"
7 |
8 | use_scheduler: True
9 | scheduler_config:
10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
11 | params:
12 | verbosity_interval: 0 # 0 or negative to disable
13 | warm_up_steps: 5000
14 | max_decay_steps: 500001
15 | lr_start: 2.5e-6
16 | lr_max: 1.5e-4
17 | lr_min: 1.0e-8
18 |
19 | transformer_config:
20 | target: geofree.modules.transformer.mingpt.GPT
21 | params:
22 | vocab_size: 16384
23 | block_size: 597 # conditioning + 299 - 1
24 | n_unmasked: 299 # 299 cond embeddings
25 | n_layer: 32
26 | n_head: 16
27 | n_embd: 1024
28 |
29 | first_stage_key:
30 | x: "dst_img"
31 |
32 | cond_stage_key:
33 | c: "src_img"
34 | points: "src_points"
35 | R: "R_rel"
36 | t: "t_rel"
37 | K: "K"
38 | K_inv: "K_inv"
39 |
40 | first_stage_config:
41 | target: geofree.models.vqgan.VQModel
42 | params:
43 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt"
44 | embed_dim: 256
45 | n_embed: 16384
46 | ddconfig:
47 | double_z: False
48 | z_channels: 256
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
54 | num_res_blocks: 2
55 | attn_resolutions: [ 16 ]
56 | dropout: 0.0
57 | lossconfig:
58 | target: geofree.modules.losses.vqperceptual.DummyLoss
59 |
60 | cond_stage_config: "__is_first_stage__"
61 |
62 | data:
63 | target: geofree.main.DataModuleFromConfig
64 | params:
65 | # bs 8 and accumulate_grad_batches 2 for 34gb vram
66 | batch_size: 8
67 | num_workers: 16
68 | train:
69 | target: geofree.data.realestate.RealEstate10KSparseTrain
70 | params:
71 | size:
72 | - 208
73 | - 368
74 |
75 | validation:
76 | target: geofree.data.realestate.RealEstate10KCustomTest
77 | params:
78 | size:
79 | - 208
80 | - 368
81 |
82 | lightning:
83 | trainer:
84 | accumulate_grad_batches: 2
85 | benchmark: True
86 |
--------------------------------------------------------------------------------
/configs/realestate/realestate_13x23_expl_img.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.net2net.WarpingTransformer
4 | params:
5 | plot_cond_stage: True
6 | monitor: "val/loss"
7 |
8 | use_scheduler: True
9 | scheduler_config:
10 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
11 | params:
12 | verbosity_interval: 0 # 0 or negative to disable
13 | warm_up_steps: 5000
14 | max_decay_steps: 500001
15 | lr_start: 2.5e-6
16 | lr_max: 1.5e-4
17 | lr_min: 1.0e-8
18 |
19 | transformer_config:
20 | target: geofree.modules.transformer.mingpt.GPT
21 | params:
22 | vocab_size: 16384
23 | block_size: 597 # conditioning + 299 - 1
24 | n_unmasked: 299 # 299 cond embeddings
25 | n_layer: 32
26 | n_head: 16
27 | n_embd: 1024
28 |
29 | first_stage_config:
30 | target: geofree.models.vqgan.VQModel
31 | params:
32 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt"
33 | embed_dim: 256
34 | n_embed: 16384
35 | ddconfig:
36 | double_z: False
37 | z_channels: 256
38 | resolution: 256
39 | in_channels: 3
40 | out_ch: 3
41 | ch: 128
42 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
43 | num_res_blocks: 2
44 | attn_resolutions: [ 16 ]
45 | dropout: 0.0
46 | lossconfig:
47 | target: geofree.modules.losses.vqperceptual.DummyLoss
48 |
49 | data:
50 | target: geofree.main.DataModuleFromConfig
51 | params:
52 | # bs 8 and accumulate_grad_batches 2 for 34gb vram
53 | batch_size: 8
54 | num_workers: 16
55 | train:
56 | target: geofree.data.realestate.RealEstate10KSparseTrain
57 | params:
58 | size:
59 | - 208
60 | - 368
61 |
62 | validation:
63 | target: geofree.data.realestate.RealEstate10KCustomTest
64 | params:
65 | size:
66 | - 208
67 | - 368
68 |
69 | lightning:
70 | trainer:
71 | accumulate_grad_batches: 2
72 | benchmark: True
73 |
--------------------------------------------------------------------------------
/configs/realestate/realestate_13x23_hybrid.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.geogpt.WarpGeoTransformer
4 | params:
5 | merge_channels: 512 # channels of cond vq + depth vq
6 |
7 | plot_cond_stage: True
8 | monitor: "val/loss"
9 |
10 | use_scheduler: True
11 | scheduler_config:
12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0 # 0 or negative to disable
15 | warm_up_steps: 5000
16 | max_decay_steps: 500001
17 | lr_start: 2.5e-6
18 | lr_max: 1.5e-4
19 | lr_min: 1.0e-8
20 |
21 | transformer_config:
22 | target: geofree.modules.transformer.mingpt.WarpGPT
23 | params:
24 | vocab_size: 16384
25 | block_size: 627 # conditioning + 299 - 1
26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings
27 | n_layer: 32
28 | n_head: 16
29 | n_embd: 1024
30 | warper_config:
31 | target: geofree.modules.transformer.warper.ConvWarper
32 | params:
33 | start_idx: 30 # do not warp camera embeddings
34 | size: [13, 23]
35 |
36 | first_stage_key:
37 | x: "dst_img"
38 |
39 | cond_stage_key:
40 | c: "src_img"
41 |
42 | emb_stage_key:
43 | points: "src_points"
44 | R: "R_rel"
45 | t: "t_rel"
46 | K: "K"
47 | K_inv: "K_inv"
48 |
49 | first_stage_config:
50 | target: geofree.models.vqgan.VQModel
51 | params:
52 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt"
53 | embed_dim: 256
54 | n_embed: 16384
55 | ddconfig:
56 | double_z: False
57 | z_channels: 256
58 | resolution: 256
59 | in_channels: 3
60 | out_ch: 3
61 | ch: 128
62 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
63 | num_res_blocks: 2
64 | attn_resolutions: [ 16 ]
65 | dropout: 0.0
66 | lossconfig:
67 | target: geofree.modules.losses.vqperceptual.DummyLoss
68 |
69 | cond_stage_config: "__is_first_stage__"
70 |
71 | depth_stage_config:
72 | target: geofree.models.vqgan.VQModel
73 | params:
74 | ckpt_path: "pretrained_models/realestate_depth_stage/last.ckpt"
75 | embed_dim: 256
76 | n_embed: 1024
77 | ddconfig:
78 | double_z: false
79 | z_channels: 256
80 | resolution: 256
81 | in_channels: 1
82 | out_ch: 1
83 | ch: 128
84 | ch_mult:
85 | - 1
86 | - 1
87 | - 2
88 | - 2
89 | - 4
90 | num_res_blocks: 2
91 | attn_resolutions:
92 | - 16
93 | dropout: 0.0
94 | lossconfig:
95 | target: geofree.modules.losses.vqperceptual.DummyLoss
96 |
97 | emb_stage_config:
98 | target: geofree.modules.util.MultiEmbedder
99 | params:
100 | keys:
101 | - "R"
102 | - "t"
103 | - "K"
104 | - "K_inv"
105 | n_positions: 30
106 | n_channels: 1
107 | n_embed: 1024
108 | bias: False
109 |
110 | data:
111 | target: geofree.main.DataModuleFromConfig
112 | params:
113 | # bs 8 and accumulate_grad_batches 2 for 36gb vram
114 | batch_size: 8
115 | num_workers: 16
116 | train:
117 | target: geofree.data.realestate.RealEstate10KSparseTrain
118 | params:
119 | size:
120 | - 208
121 | - 368
122 |
123 | validation:
124 | target: geofree.data.realestate.RealEstate10KCustomTest
125 | params:
126 | size:
127 | - 208
128 | - 368
129 |
130 | lightning:
131 | trainer:
132 | accumulate_grad_batches: 2
133 | benchmark: True
134 |
--------------------------------------------------------------------------------
/configs/realestate/realestate_13x23_impl_catdepth.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.geogpt.GeoTransformer
4 | params:
5 | merge_channels: null
6 |
7 | plot_cond_stage: True
8 | monitor: "val/loss"
9 |
10 | use_scheduler: True
11 | scheduler_config:
12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0 # 0 or negative to disable
15 | warm_up_steps: 5000
16 | max_decay_steps: 500001
17 | lr_start: 2.5e-6
18 | lr_max: 1.5e-4
19 | lr_min: 1.0e-8
20 |
21 | transformer_config:
22 | target: geofree.modules.transformer.mingpt.GPT
23 | params:
24 | vocab_size: 16384
25 | block_size: 926 # conditioning + 299 - 1
26 | n_unmasked: 628 # 30 camera embeddings + 299 depth and 299 cond embeddings
27 | n_layer: 32
28 | n_head: 16
29 | n_embd: 1024
30 |
31 | first_stage_key:
32 | x: "dst_img"
33 |
34 | cond_stage_key:
35 | c: "src_img"
36 |
37 | emb_stage_key:
38 | points: "src_points"
39 | R: "R_rel"
40 | t: "t_rel"
41 | K: "K"
42 | K_inv: "K_inv"
43 |
44 | first_stage_config:
45 | target: geofree.models.vqgan.VQModel
46 | params:
47 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt"
48 | embed_dim: 256
49 | n_embed: 16384
50 | ddconfig:
51 | double_z: False
52 | z_channels: 256
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
58 | num_res_blocks: 2
59 | attn_resolutions: [ 16 ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: geofree.modules.losses.vqperceptual.DummyLoss
63 |
64 | cond_stage_config: "__is_first_stage__"
65 |
66 | depth_stage_config:
67 | target: geofree.models.vqgan.VQModel
68 | params:
69 | ckpt_path: "pretrained_models/realestate_depth_stage/last.ckpt"
70 | embed_dim: 256
71 | n_embed: 1024
72 | ddconfig:
73 | double_z: false
74 | z_channels: 256
75 | resolution: 256
76 | in_channels: 1
77 | out_ch: 1
78 | ch: 128
79 | ch_mult:
80 | - 1
81 | - 1
82 | - 2
83 | - 2
84 | - 4
85 | num_res_blocks: 2
86 | attn_resolutions:
87 | - 16
88 | dropout: 0.0
89 | lossconfig:
90 | target: geofree.modules.losses.vqperceptual.DummyLoss
91 |
92 | emb_stage_config:
93 | target: geofree.modules.util.MultiEmbedder
94 | params:
95 | keys:
96 | - "R"
97 | - "t"
98 | - "K"
99 | - "K_inv"
100 | n_positions: 30
101 | n_channels: 1
102 | n_embed: 1024
103 | bias: False
104 |
105 | data:
106 | target: geofree.main.DataModuleFromConfig
107 | params:
108 | # bs 4 and accumulate_grad_batches 4 for 34gb vram
109 | batch_size: 4
110 | num_workers: 8
111 | train:
112 | target: geofree.data.realestate.RealEstate10KSparseTrain
113 | params:
114 | size:
115 | - 208
116 | - 368
117 |
118 | validation:
119 | target: geofree.data.realestate.RealEstate10KCustomTest
120 | params:
121 | size:
122 | - 208
123 | - 368
124 |
125 | lightning:
126 | trainer:
127 | accumulate_grad_batches: 4
128 | benchmark: True
129 |
--------------------------------------------------------------------------------
/configs/realestate/realestate_13x23_impl_depth.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.geogpt.GeoTransformer
4 | params:
5 | merge_channels: 512 # channels of cond vq + depth vq
6 |
7 | plot_cond_stage: True
8 | monitor: "val/loss"
9 |
10 | use_scheduler: True
11 | scheduler_config:
12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0 # 0 or negative to disable
15 | warm_up_steps: 5000
16 | max_decay_steps: 500001
17 | lr_start: 2.5e-6
18 | lr_max: 1.5e-4
19 | lr_min: 1.0e-8
20 |
21 | transformer_config:
22 | target: geofree.modules.transformer.mingpt.GPT
23 | params:
24 | vocab_size: 16384
25 | block_size: 627 # conditioning + 299 - 1
26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings
27 | n_layer: 32
28 | n_head: 16
29 | n_embd: 1024
30 |
31 | first_stage_key:
32 | x: "dst_img"
33 |
34 | cond_stage_key:
35 | c: "src_img"
36 |
37 | emb_stage_key:
38 | points: "src_points"
39 | R: "R_rel"
40 | t: "t_rel"
41 | K: "K"
42 | K_inv: "K_inv"
43 |
44 | first_stage_config:
45 | target: geofree.models.vqgan.VQModel
46 | params:
47 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt"
48 | embed_dim: 256
49 | n_embed: 16384
50 | ddconfig:
51 | double_z: False
52 | z_channels: 256
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
58 | num_res_blocks: 2
59 | attn_resolutions: [ 16 ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: geofree.modules.losses.vqperceptual.DummyLoss
63 |
64 | cond_stage_config: "__is_first_stage__"
65 |
66 | depth_stage_config:
67 | target: geofree.models.vqgan.VQModel
68 | params:
69 | ckpt_path: "pretrained_models/realestate_depth_stage/last.ckpt"
70 | embed_dim: 256
71 | n_embed: 1024
72 | ddconfig:
73 | double_z: false
74 | z_channels: 256
75 | resolution: 256
76 | in_channels: 1
77 | out_ch: 1
78 | ch: 128
79 | ch_mult:
80 | - 1
81 | - 1
82 | - 2
83 | - 2
84 | - 4
85 | num_res_blocks: 2
86 | attn_resolutions:
87 | - 16
88 | dropout: 0.0
89 | lossconfig:
90 | target: geofree.modules.losses.vqperceptual.DummyLoss
91 |
92 | emb_stage_config:
93 | target: geofree.modules.util.MultiEmbedder
94 | params:
95 | keys:
96 | - "R"
97 | - "t"
98 | - "K"
99 | - "K_inv"
100 | n_positions: 30
101 | n_channels: 1
102 | n_embed: 1024
103 | bias: False
104 |
105 | data:
106 | target: geofree.main.DataModuleFromConfig
107 | params:
108 | # bs 8 and accumulate_grad_batches 2 for 36gb vram
109 | batch_size: 8
110 | num_workers: 16
111 | train:
112 | target: geofree.data.realestate.RealEstate10KSparseTrain
113 | params:
114 | size:
115 | - 208
116 | - 368
117 |
118 | validation:
119 | target: geofree.data.realestate.RealEstate10KCustomTest
120 | params:
121 | size:
122 | - 208
123 | - 368
124 |
125 | lightning:
126 | trainer:
127 | accumulate_grad_batches: 2
128 | benchmark: True
129 |
--------------------------------------------------------------------------------
/configs/realestate/realestate_13x23_impl_nodepth.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0625
3 | target: geofree.models.transformers.geogpt.GeoTransformer
4 | params:
5 | use_depth: False # depth is not provided to transformer but only used to rescale t
6 |
7 | plot_cond_stage: True
8 | monitor: "val/loss"
9 |
10 | use_scheduler: True
11 | scheduler_config:
12 | target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0 # 0 or negative to disable
15 | warm_up_steps: 5000
16 | max_decay_steps: 500001
17 | lr_start: 2.5e-6
18 | lr_max: 1.5e-4
19 | lr_min: 1.0e-8
20 |
21 | transformer_config:
22 | target: geofree.modules.transformer.mingpt.GPT
23 | params:
24 | vocab_size: 16384
25 | block_size: 627 # conditioning + 299 - 1
26 | n_unmasked: 329 # 30 camera embeddings + 299 merged cond and depth embeddings
27 | n_layer: 32
28 | n_head: 16
29 | n_embd: 1024
30 |
31 | first_stage_key:
32 | x: "dst_img"
33 |
34 | cond_stage_key:
35 | c: "src_img"
36 |
37 | emb_stage_key:
38 | points: "src_points"
39 | R: "R_rel"
40 | t: "t_rel"
41 | K: "K"
42 | K_inv: "K_inv"
43 |
44 | first_stage_config:
45 | target: geofree.models.vqgan.VQModel
46 | params:
47 | ckpt_path: "pretrained_models/realestate_first_stage/last.ckpt"
48 | embed_dim: 256
49 | n_embed: 16384
50 | ddconfig:
51 | double_z: False
52 | z_channels: 256
53 | resolution: 256
54 | in_channels: 3
55 | out_ch: 3
56 | ch: 128
57 | ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
58 | num_res_blocks: 2
59 | attn_resolutions: [ 16 ]
60 | dropout: 0.0
61 | lossconfig:
62 | target: geofree.modules.losses.vqperceptual.DummyLoss
63 |
64 | cond_stage_config: "__is_first_stage__"
65 |
66 | depth_stage_config:
67 | target: geofree.models.vqgan.VQModel
68 | params:
69 | ckpt_path: "pretrained_models/realestate_depth_stage/last.ckpt"
70 | embed_dim: 256
71 | n_embed: 1024
72 | ddconfig:
73 | double_z: false
74 | z_channels: 256
75 | resolution: 256
76 | in_channels: 1
77 | out_ch: 1
78 | ch: 128
79 | ch_mult:
80 | - 1
81 | - 1
82 | - 2
83 | - 2
84 | - 4
85 | num_res_blocks: 2
86 | attn_resolutions:
87 | - 16
88 | dropout: 0.0
89 | lossconfig:
90 | target: geofree.modules.losses.vqperceptual.DummyLoss
91 |
92 | emb_stage_config:
93 | target: geofree.modules.util.MultiEmbedder
94 | params:
95 | keys:
96 | - "R"
97 | - "t"
98 | - "K"
99 | - "K_inv"
100 | n_positions: 30
101 | n_channels: 1
102 | n_embed: 1024
103 | bias: False
104 |
105 | data:
106 | target: geofree.main.DataModuleFromConfig
107 | params:
108 | # bs 8 and accumulate_grad_batches 2 for 36gb vram
109 | batch_size: 8
110 | num_workers: 16
111 | train:
112 | target: geofree.data.realestate.RealEstate10KSparseTrain
113 | params:
114 | size:
115 | - 208
116 | - 368
117 |
118 | validation:
119 | target: geofree.data.realestate.RealEstate10KCustomTest
120 | params:
121 | size:
122 | - 208
123 | - 368
124 |
125 | lightning:
126 | trainer:
127 | accumulate_grad_batches: 2
128 | benchmark: True
129 |
--------------------------------------------------------------------------------
/data/acid_custom_frames.txt:
--------------------------------------------------------------------------------
1 | d0126a869518014c,27933333.png,30600000.png,33366667.png
2 | d080086d1e293422,1096195000.png,1092558000.png,1087886000.png
3 | d1512dba43b6ebd8,188555033.png,192625767.png,197163633.png
4 | d179d50f84f4b83f,217880000.png,224280000.png,231080000.png
5 | d19c9fe809c57c92,22822800.png,25692333.png,28962267.png
6 | d3738aa86d5eda8e,80583333.png,74125000.png,66916667.png
7 | d3835cc526e338bc,71871800.png,75842433.png,79846433.png
8 | d47a3786e17ee61c,47747000.png,53220000.png,58359000.png
9 | d60884a98e4608c7,157790967.png,164531033.png,171738233.png
10 | d674021d935b083e,270671000.png,279946000.png,289389000.png
11 | d68495d47bf40a86,254633000.png,258133000.png,260633000.png
12 | d6bdf9d5236d4b9e,131965167.png,143309833.png,155488667.png
13 | d6cf2995d0ee0f30,509308000.png,518318000.png,527260000.png
14 | d78eccc35066d590,21271250.png,26276250.png,31698333.png
15 | d7a050a81422bdcb,1168167000.png,1183916067.png,1199464933.png
16 | d7e53e00d33b4935,105005000.png,102503000.png,100033000.png
17 | d7fca566a7941495,4456218433.png,4440269167.png,4424686933.png
18 | d82270f5d9180bb6,323657000.png,328395000.png,332799000.png
19 | d8409fae7feda328,42542500.png,50383667.png,56598208.png
20 | d884a63dd53b6919,162863000.png,178378000.png,193794000.png
21 | d90a8043e1ebfeae,1234200000.png,1226392000.png,1216315000.png
22 | d93af9b55cfac3d6,71571000.png,79546000.png,87220000.png
23 | d94b14b1000175be,620586000.png,625291000.png,629896000.png
24 | d994ce5969d83614,876475600.png,891957733.png,907940367.png
25 | da346e30050c6c00,176376000.png,181114000.png,185452000.png
26 | daa02bdc0b99c679,66153844.png,69720278.png,73286711.png
27 | daafe4cf1afae0ac,54687967.png,64697967.png,75108367.png
28 | daea6ab83805891a,183483300.png,196496300.png,209576033.png
29 | db18bbcefe49f79f,142709233.png,145945800.png,149482667.png
30 | db5184b7af9fd1c2,47200000.png,50233333.png,53200000.png
31 | db593b5b5fa8324a,141908000.png,144377000.png,146646000.png
32 | db9b1d074c875cfe,59226000.png,65632000.png,71972000.png
33 | dbd581a742408adf,307015000.png,315190000.png,323156000.png
34 | dc4002fc40c73bcc,188922000.png,180814000.png,173973000.png
35 | dc5f1c50fbd89f2a,201068000.png,202436000.png,203903000.png
36 | dcc8e9a3ec4195e1,30800000.png,27866667.png,24733333.png
37 | dd0c70b67523bfe3,700967000.png,716382000.png,731831000.png
38 | dd3339c318949f11,80046000.png,80981000.png,81882000.png
39 | de813f957e093bd9,31322958.png,35076708.png,38747042.png
40 | ded5b92ff8802e34,4318747767.png,4303432467.png,4287583300.png
41 | e02ba6089a8e6996,319919000.png,313447000.png,307374000.png
42 | e0b669c5a9af8487,74008000.png,76676000.png,79179000.png
43 | e0e9f1803a42d9a6,2860024000.png,2865562000.png,2871235000.png
44 | e0fb6bc955dc2827,71866667.png,70533333.png,69133333.png
45 | e0ff100eaf2ff841,18585000.png,25092000.png,31564000.png
46 | e196d71034cf55dd,170003000.png,175675000.png,180881000.png
47 | e22ba3ffad552f90,33767067.png,48281567.png,61027633.png
48 | e27a6391e3fced31,198656000.png,212587000.png,224516000.png
49 | e28fb40231cc8e31,27894000.png,31164000.png,34468000.png
50 | e33da2efe5ad29ac,1386251533.png,1401166433.png,1416748667.png
51 | e3475cb70a67e1aa,190490300.png,196129267.png,201467933.png
52 | e355f7aaf7d08fc3,164289122.png,165915756.png,167500667.png
53 | e36cd137792f2711,72439033.png,84084000.png,99999900.png
54 | e3792317ae266d83,14047367.png,18885533.png,24057367.png
55 | e3c9b29aaacf1fba,58758000.png,63230000.png,67601000.png
56 | e3d4f7fbbcd3fc3b,12554208.png,13596917.png,14764750.png
57 | e3ef559c3137aa36,2161392000.png,2169934000.png,2178009000.png
58 | e40594059ec1a2a8,69904000.png,73040000.png,77577000.png
59 | e52993cdbcf75b65,109818044.png,120537089.png,131840044.png
60 | e5310a4dc228ec5b,36803433.png,41408033.png,45745700.png
61 | e614b4670c543a72,114966667.png,119566667.png,124466667.png
62 | e668f185a5ba7c38,395028000.png,401801000.png,408875000.png
63 | e70e7383f837f2d4,1274573300.png,1277542933.png,1280579300.png
64 | e71fbf373f35269e,71471467.png,75508833.png,79145800.png
65 | e7443b934b5f4b0a,103603500.png,109209100.png,113980533.png
66 | e74ca33c36a51184,665397462.png,666332665.png,667267869.png
67 | e75b4dd9975c6f71,31200000.png,32733333.png,34500000.png
68 | e7a51eee636a0ba5,110510000.png,114081000.png,117817000.png
69 | e831e23cd75ab6c6,100000000.png,110010000.png,121321000.png
70 | e8426e39152cce71,106205000.png,109042000.png,111645000.png
71 | e86c6167b214bb99,34902000.png,27794000.png,20254000.png
72 | e88f65da150600b1,100033333.png,103000000.png,106133333.png
73 | e8a2758fe5401c9a,738911156.png,741482966.png,744255177.png
74 | e93d2814054ac90c,151584767.png,154287467.png,156789967.png
75 | e9828e82fe988c70,2544709000.png,2538836000.png,2532663000.png
76 | e9d627308856bc62,45278000.png,34268000.png,31398000.png
77 | e9eb29c8c46a8dfe,154554400.png,159792967.png,165431933.png
78 | ea02488ece6242e8,36469767.png,31231200.png,26292933.png
79 | ea36b0ce652cd8b4,155955000.png,158458000.png,161128000.png
80 | eab2bd6815838f8e,2293758000.png,2277141000.png,2261926000.png
81 | eabe884be6f7e517,158066667.png,161200000.png,164300000.png
82 | ebb0c513c19ea84f,144042000.png,152875000.png,163208000.png
83 | ecec6359aa6acbcf,213313100.png,218785233.png,223957067.png
84 | ed01ba5660fd718a,304666667.png,308166667.png,311708333.png
85 | ed0e643389fb2c0f,869469000.png,880880000.png,892792000.png
86 | ed29587ce3e9c577,441308000.png,457291000.png,472506000.png
87 | ed3c89742143039a,44477767.png,47547500.png,50784067.png
88 | edeb8a878709b716,72640000.png,68168000.png,63663000.png
89 | edf5fca1c9f4cb6c,269168900.png,271638033.png,274173900.png
90 | ee09e048af8deba6,17751067.png,30730700.png,42642600.png
91 | ee32b574f99ff150,155155000.png,152277125.png,149649500.png
92 | ee4b7adccd9ecd5c,202667000.png,207125000.png,211917000.png
93 | ee69da79dbad7bb3,70894789.png,73412467.png,76209889.png
94 | eea426ba67142e9b,149416000.png,150651000.png,151852000.png
95 | ef0753c88ee52b6f,105480000.png,109400000.png,113400000.png
96 | ef1f6e769ebc1c7a,208700000.png,211266667.png,214000000.png
97 | efaf202f7c09123c,56347958.png,60018292.png,63313250.png
98 | f02d43a4dabe789d,85518000.png,92425000.png,99799000.png
99 | f16b9ff7838b81c6,52733000.png,55100000.png,57166000.png
100 | f1beb221d42d1b84,359459100.png,356022333.png,352485467.png
101 | f23cd2826eff1565,30230200.png,31064367.png,31798433.png
102 | f2a34c7254940f53,494827000.png,487854000.png,478811000.png
103 | f2aa477766df229c,189166667.png,190900000.png,192533333.png
104 | f31b3bb307daf84a,51250000.png,53083000.png,54833000.png
105 | f3218100d129be59,62729000.png,66800000.png,69269000.png
106 | f33cad5fe9bd65f9,14916667.png,12916667.png,10916667.png
107 | f4421fc161eb69f3,89189100.png,94194100.png,99999900.png
108 | f4ce75753de06c58,66366667.png,75800000.png,86066667.png
109 | f4e0e97f02bcb1e3,245946000.png,247881000.png,249850000.png
110 | f51d8789db1f722c,179500000.png,173800000.png,167900000.png
111 | f58db3b7a2ce9cc0,1125724600.png,1128393933.png,1132064267.png
112 | f5ae92dd3d7e87e0,604303000.png,618385000.png,634434000.png
113 | f5ef3487cdcee3fb,33433000.png,36003000.png,38639000.png
114 | f6996b1f345e96b4,153954000.png,159960000.png,166065000.png
115 | f77008b83e767032,1734933000.png,1750449000.png,1764996000.png
116 | f7f7b65716d33d83,53219833.png,53787067.png,54320933.png
117 | f830083184521d34,60293627.png,65865866.png,71938605.png
118 | f84181beccbd5e13,451984867.png,454654200.png,457523733.png
119 | f8c932575ce9f2a5,47115385.png,50923077.png,54730769.png
120 | f8e283943d00a04d,215649000.png,221855000.png,228995000.png
121 | f91e3039f88375a3,88388300.png,90190100.png,91925167.png
122 | f93cd4d2a4313d38,55989267.png,61761700.png,68201467.png
123 | f976c57572a6ce0e,4217413200.png,4232695133.png,4249144900.png
124 | f98981baa6b8ee65,2185617000.png,2195660000.png,2207472000.png
125 | fa0d910e19270413,228061000.png,232800000.png,238104000.png
126 | fa329483307d59ac,109576000.png,117317000.png,125525000.png
127 | fa35a303ae5d4be6,2936767000.png,2951649000.png,2959323000.png
128 | faf3bb76a633f435,1004703000.png,1002368000.png,1000099000.png
129 | fb155e5f56868f1b,604479479.png,608775442.png,613446780.png
130 | fb54fffc370d1b0a,527994133.png,539372167.png,549682467.png
131 | fb9c37c89297d619,1316148167.png,1331029700.png,1347045700.png
132 | fbc7bf83e0ae84d3,13646967.png,15081733.png,16349667.png
133 | fc00edbd1d936174,38433333.png,43266667.png,47866667.png
134 | fc9bf0a83f1891e8,305680000.png,309851000.png,314397000.png
135 | fd012e8d736a6c8e,38505133.png,41875167.png,45512133.png
136 | fd2f9b4c485544f3,50050000.png,50984000.png,51919000.png
137 | fd3acb362205d6eb,52085000.png,54420000.png,56923000.png
138 | fd9d98306e9bf7d0,19186000.png,22022000.png,22556000.png
139 | fe0428833b293005,15448767.png,24524500.png,33333300.png
140 | ff85e9318de5f4bf,69002267.png,75241833.png,81948533.png
141 | ffc56e43145f3ae5,105305200.png,109843067.png,114848067.png
142 |
--------------------------------------------------------------------------------
/data/realestate_custom_frames.txt:
--------------------------------------------------------------------------------
1 | f04880903462e8f0,106806000.png,109609000.png,112646000.png
2 | f1171dbc746a3cd6,152919433.png,154621133.png,156256100.png
3 | f1226d5483f1c1cd,96029267.png,97931167.png,99999900.png
4 | f14627d2fc3158d4,56756700.png,60894167.png,65532133.png
5 | f1623ec1d552408c,87220467.png,89222467.png,91291200.png
6 | f22301abb4ab062c,299032000.png,300333000.png,301734000.png
7 | f2309b1e3f43d138,271738133.png,274841233.png,279145533.png
8 | f253339fa19467f5,307574000.png,308809000.png,312346000.png
9 | f2736f39d6213d22,140940000.png,141975000.png,143076000.png
10 | f289b5f7136bba85,46546500.png,50650600.png,55355300.png
11 | f3236d9eb1996c58,140073267.png,142142000.png,144677867.png
12 | f32643006f1260b9,85085000.png,89155733.png,92492400.png
13 | f3534478d06ca677,194727867.png,197096900.png,199532667.png
14 | f372706f15279f2d,212779000.png,215415000.png,218419000.png
15 | f408fd1ecf9639dd,41583000.png,43291000.png,44999000.png
16 | f412f8fa0ea298ff,253553000.png,255755000.png,258892000.png
17 | f413f880cdafe53c,226226000.png,228161267.png,230096533.png
18 | f43420f8216f180b,134300833.png,132732600.png,131097633.png
19 | f4769a32a225c172,178444000.png,182582000.png,187253000.png
20 | f478d5f55e41f4f7,273072800.png,276743133.png,279979700.png
21 | f48829b917629fe0,42642000.png,47113000.png,50650000.png
22 | f5257718f0aae70d,114547767.png,115648867.png,118351567.png
23 | f533f7e3ca2b82b2,255355100.png,259425833.png,264264000.png
24 | f536430fb223d623,290957333.png,292792500.png,294594300.png
25 | f546b4bb3f06e5ad,279613000.png,283216000.png,286319000.png
26 | f591e1921f2c850c,154020533.png,156489667.png,158958800.png
27 | f6057c281758b5f6,104104000.png,108641000.png,113113000.png
28 | f60664acd06e22aa,129729600.png,132098633.png,135268467.png
29 | f626135a90d68e85,177577400.png,179412567.png,181381200.png
30 | f649244a6907838c,52218833.png,54587867.png,57624233.png
31 | f654d9cb6bc8869e,260426833.png,261060800.png,262328733.png
32 | f673068196024955,196196000.png,197865000.png,199733000.png
33 | f6884e5c27acbafe,184384200.png,186219367.png,187921067.png
34 | f714ccd51adc29a8,31398000.png,34234000.png,37037000.png
35 | f74139ac48f19b3c,172005167.png,176543033.png,179712867.png
36 | f793e374e29bbd16,145745600.png,148481667.png,151484667.png
37 | f795469d3856697e,188021167.png,191558033.png,194727867.png
38 | f8545ff1c0ebc410,34900000.png,37833333.png,40466667.png
39 | f882aaab13d8d7a6,186719000.png,189756000.png,192826000.png
40 | f91749843a42fa3e,247247000.png,249649400.png,252185267.png
41 | f954f61234d49919,121254000.png,123290000.png,125258000.png
42 | f956425e6e6f9863,114448000.png,115982000.png,117650000.png
43 | f96049f4355544bd,194260733.png,198097900.png,202001800.png
44 | f968e1bb4d5f9e4b,273239633.png,275208267.png,277610667.png
45 | f99691764cd67e0c,102969533.png,106506400.png,110243467.png
46 | f99730b487e57e9b,217250367.png,221287733.png,225291733.png
47 | fa005e9c2a1c4828,167567400.png,171504667.png,175675500.png
48 | fa14b62a46ffd7b7,30964267.png,35502133.png,39873167.png
49 | fa1fe0e1d0ba7450,133199733.png,134701233.png,136402933.png
50 | fa25ffc4897fc142,70303567.png,73373300.png,76309567.png
51 | fa53b8c13f6caabb,45344000.png,47680000.png,52118000.png
52 | fa71d84898d6ac7f,30697333.png,35502133.png,39706333.png
53 | fa7ac0885965196f,78545133.png,83016267.png,85752333.png
54 | fa82d71146e33555,113413300.png,116749967.png,121087633.png
55 | fa95ce310f223e51,157123790.png,158258258.png,159426093.png
56 | faa891b87b4b33ad,270169837.png,273139473.png,276309309.png
57 | faaba1f48348f5c6,262463000.png,263897000.png,265332000.png
58 | fab28f712c0edea4,188772000.png,190732000.png,192609000.png
59 | faf2fb655d5974b5,120820700.png,125391933.png,129796333.png
60 | fb08b3c2668c99c3,204871000.png,208041000.png,211378000.png
61 | fb0d3f523276eefe,67033633.png,71471400.png,75742333.png
62 | fb1bd05ed2a473ad,59759700.png,64064000.png,68568500.png
63 | fb277b237d7bdcb5,249015433.png,250883967.png,253019433.png
64 | fb32e6765f776b18,289622667.png,292292000.png,294877922.png
65 | fb35415a0b8eb135,260026000.png,263196000.png,265965000.png
66 | fb36cf1f6923924b,112145000.png,115248000.png,118718000.png
67 | fb3d48fa6ce76da1,132098633.png,133266467.png,134434300.png
68 | fb424ab019f3e1a1,77978000.png,82315000.png,86653000.png
69 | fb52f951d8a8ad11,295628667.png,297664033.png,299899600.png
70 | fb606e0ed3d19e90,120187000.png,122122000.png,124158000.png
71 | fb723ecfff811097,138538400.png,143009533.png,147213733.png
72 | fb80aacde91ee824,212879333.png,217483933.png,221788233.png
73 | fbb437c54b5d9d78,113380000.png,117617000.png,122121000.png
74 | fbd37246c3e03576,167534033.png,171838333.png,176376200.png
75 | fbdcb1880c6bab8e,81915167.png,84250833.png,86719967.png
76 | fbdf22c23b31ceb4,30730700.png,35101733.png,39606233.png
77 | fbe4055d5dfd480e,40240200.png,43043000.png,46146100.png
78 | fbe9232477f69e33,100166733.png,101801700.png,103336567.png
79 | fbeb367248e705bc,104070633.png,107607500.png,110844067.png
80 | fbef04fb266195f9,81414667.png,83183100.png,85085000.png
81 | fbfed12d638968c2,40807433.png,45078367.png,48715333.png
82 | fc174b681e13cdf4,120854067.png,123356567.png,126092633.png
83 | fc43cc25dadacb18,82682600.png,85051633.png,87587500.png
84 | fc4410feec1308a2,176409567.png,180914067.png,185385200.png
85 | fc50278b5f950186,100133367.png,101734967.png,102335567.png
86 | fc50693501976251,159592767.png,161828333.png,165064900.png
87 | fc508faa91c9807d,145278000.png,146847000.png,148281000.png
88 | fc6f664a700121e9,33833800.png,37037000.png,39939900.png
89 | fc75727847fd61f4,260193267.png,264397467.png,268668400.png
90 | fc87af9c4f0000b1,57023633.png,56423033.png,55055000.png
91 | fc966f25afa3659f,160660500.png,162829333.png,167033533.png
92 | fc991bfad5040ca4,232607378.png,234234000.png,236069167.png
93 | fc99e61da3cfb1b1,54788067.png,57123733.png,59726333.png
94 | fca310e3805d5ab4,78945533.png,80613867.png,81581500.png
95 | fca5ca7d4517812b,253286367.png,254020433.png,258258000.png
96 | fcb0cb3d3da5c2e5,38738700.png,41107733.png,43710333.png
97 | fcd7ebdf0454ef30,149182367.png,151251100.png,152519033.png
98 | fcde616cb28da426,200800000.png,202735000.png,204571000.png
99 | fce3a289eed2d6c6,158058000.png,160828000.png,163563000.png
100 | fcfa188fdd8e4cdc,100199000.png,104804000.png,108141000.png
101 | fd020d75ab5e2fa3,91558133.png,94794700.png,97931167.png
102 | fd2c0fc2bf2befa7,87720967.png,89055633.png,90290200.png
103 | fd3c54bb57284e5c,120019000.png,121254000.png,122455000.png
104 | fd48a65a5e252855,71671000.png,75909000.png,80180000.png
105 | fd6ee791edb85b45,202802600.png,205138267.png,207273733.png
106 | fd7890ec5b0a924c,310226000.png,307015000.png,304346000.png
107 | fd9ab1e2b1f8e5aa,78233333.png,80666667.png,83066667.png
108 | fdaab1f49851782c,69736333.png,72305567.png,74874800.png
109 | fdb60d991671bd45,47213833.png,51718333.png,55955900.png
110 | fdceb26461ef9adb,50050000.png,54020633.png,57924533.png
111 | fdd41bb5cab98c86,68701967.png,72105367.png,76042633.png
112 | fdf17f383d76c327,238605000.png,241641000.png,244745000.png
113 | fe04a079dc651d9b,112245000.png,117884000.png,121087000.png
114 | fe26ad446164cb32,112378933.png,114714600.png,118351567.png
115 | fe2fadf89a84e92a,134000533.png,135468667.png,137070267.png
116 | fe625de05cd0a34b,83082000.png,87353000.png,91958000.png
117 | fe8431a1c5bd3c81,260793867.png,265198267.png,269569300.png
118 | fe97f5516fea659b,144544400.png,148014533.png,151484667.png
119 | fea3a0ff96583244,271104167.png,273373100.png,275408467.png
120 | fea544b472e9abd1,44110000.png,48682000.png,53086000.png
121 | feba28a8f1b69de6,102870000.png,107107000.png,111778000.png
122 | fef11d498a882edc,44911533.png,47113733.png,49182467.png
123 | fef6f87f0eb94f9d,243310000.png,246112000.png,248648000.png
124 | fefa62ae2b9915cb,34635000.png,38471000.png,42141000.png
125 | ff036af715175dce,313413100.png,314981333.png,316616300.png
126 | ff03c7906d70bc61,213713000.png,216916000.png,219986000.png
127 | ff0749db248be7fb,115315200.png,117650867.png,120120000.png
128 | ff1c8223873aa02b,93727000.png,95161000.png,96730000.png
129 | ff20a3e943ea6d63,179246000.png,183583000.png,188021000.png
130 | ff58ffaac40eb035,277043433.png,280013067.png,282649033.png
131 | ff63123a7ef312a5,130630500.png,132765967.png,134934800.png
132 | ff6d8ab35e042db5,137237100.png,140006533.png,142775967.png
133 | ff7f5042dddd5e12,113146000.png,115048000.png,118952000.png
134 | ff821ee982e8f194,181981800.png,184517667.png,186986800.png
135 | ff887646981745a6,270970700.png,275275000.png,279479200.png
136 | ff9e265755208438,82849433.png,84717967.png,87287200.png
137 | ffa95c3b40609c76,146179000.png,149449000.png,152986000.png
138 | ffb3b1bb765b96eb,195828967.png,199933067.png,204404200.png
139 | ffbd89833d851c18,177210367.png,179979800.png,182882700.png
140 | ffd80f73ab22500b,181514667.png,184417567.png,187053533.png
141 | ffe67ac537febe41,57290567.png,60727333.png,63596867.png
142 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: geofree
2 | channels:
3 | - pytorch
4 | - defaults
5 | - conda-forge
6 | dependencies:
7 | - python=3.8.5
8 | - pip=20.3
9 | - cudatoolkit=10.1
10 | - cudatoolkit-dev=10.1.243 # for nvcc
11 | - gcc_linux-64=7.3.0 # for nvcc compatibility
12 | - gxx_linux-64=7.3.0 # for nvcc compatibility
13 | - pytorch=1.7.1
14 | - torchvision=0.8.2
15 | - numpy=1.19.2
16 |
--------------------------------------------------------------------------------
/geofree/__init__.py:
--------------------------------------------------------------------------------
1 | from geofree.util import pretrained_models
2 |
--------------------------------------------------------------------------------
/geofree/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/data/__init__.py
--------------------------------------------------------------------------------
/geofree/data/acid.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 |
5 | from geofree.data.realestate import PRNGMixin, load_sparse_model_example, pad_points
6 |
7 |
8 | class ACIDSparseBase(Dataset, PRNGMixin):
9 | def __init__(self):
10 | self.sparse_dir = "data/acid_sparse"
11 |
12 | def __len__(self):
13 | return len(self.sequences)
14 |
15 | def __getitem__(self, index):
16 | seq = self.sequences[index]
17 | root = os.path.join(self.sequence_dir, seq)
18 | frames = sorted([fname for fname in os.listdir(os.path.join(root, "images")) if fname.endswith(".png")])
19 | segments = self.prng.choice(3, 2, replace=False)
20 | if segments[0] < segments[1]: # forward
21 | if segments[1]-segments[0] == 1: # small
22 | label = 0
23 | else:
24 | label = 1 # large
25 | else: # backward
26 | if segments[1]-segments[0] == 1: # small
27 | label = 2
28 | else:
29 | label = 3
30 | n = len(frames)
31 | dst_indices = list(range(segments[0]*n//3, (segments[0]+1)*n//3))
32 | src_indices = list(range(segments[1]*n//3, (segments[1]+1)*n//3))
33 | dst_index = self.prng.choice(dst_indices)
34 | src_index = self.prng.choice(src_indices)
35 | img_dst = frames[dst_index]
36 | img_src = frames[src_index]
37 |
38 | example = load_sparse_model_example(
39 | root=root, img_dst=img_dst, img_src=img_src, size=self.size)
40 |
41 | for k in example:
42 | example[k] = example[k].astype(np.float32)
43 |
44 | example["src_points"] = pad_points(example["src_points"],
45 | self.max_points)
46 | example["seq"] = seq
47 | example["label"] = label
48 | example["dst_fname"] = img_dst
49 | example["src_fname"] = img_src
50 |
51 | return example
52 |
53 |
54 | class ACIDSparseTrain(ACIDSparseBase):
55 | def __init__(self, size=None, max_points=16384):
56 | super().__init__()
57 | self.size = size
58 | self.max_points = max_points
59 |
60 | self.split = "train"
61 | self.sequence_dir = os.path.join(self.sparse_dir, self.split)
62 | with open("data/acid_train_sequences.txt", "r") as f:
63 | self.sequences = f.read().splitlines()
64 |
65 |
66 | class ACIDSparseValidation(ACIDSparseBase):
67 | def __init__(self, size=None, max_points=16384):
68 | super().__init__()
69 | self.size = size
70 | self.max_points = max_points
71 |
72 | self.split = "validation"
73 | self.sequence_dir = os.path.join(self.sparse_dir, self.split)
74 | with open("data/acid_validation_sequences.txt", "r") as f:
75 | self.sequences = f.read().splitlines()
76 |
77 |
78 | class ACIDSparseTest(ACIDSparseBase):
79 | def __init__(self, size=None, max_points=16384):
80 | super().__init__()
81 | self.size = size
82 | self.max_points = max_points
83 |
84 | self.split = "test"
85 | self.sequence_dir = os.path.join(self.sparse_dir, self.split)
86 | with open("data/acid_test_sequences.txt", "r") as f:
87 | self.sequences = f.read().splitlines()
88 |
89 |
90 | class ACIDCustomTest(Dataset):
91 | def __init__(self, size=None, max_points=16384):
92 | self.size = size
93 | self.max_points = max_points
94 |
95 | self.frames_file = "data/acid_custom_frames.txt"
96 | self.sparse_dir = "data/acid_sparse"
97 | self.split = "test"
98 |
99 | with open(self.frames_file, "r") as f:
100 | frames = f.read().splitlines()
101 |
102 | seq_data = dict()
103 | for line in frames:
104 | seq,a,b,c = line.split(",")
105 | assert not seq in seq_data
106 | seq_data[seq] = [a,b,c]
107 |
108 | # sequential list of seq, label, dst, src
109 | # where label is used to disambiguate different warping scenarios
110 | # 0: small forward movement
111 | # 1: large forward movement
112 | # 2: small backward movement (reverse of 0)
113 | # 3: large backward movement (reverse of 1)
114 | frame_data = list()
115 | for seq in sorted(seq_data.keys()):
116 | abc = seq_data[seq]
117 | frame_data.append([seq, 0, abc[1], abc[0]]) # b|a
118 | frame_data.append([seq, 1, abc[2], abc[0]]) # c|a
119 | frame_data.append([seq, 2, abc[0], abc[1]]) # a|b
120 | frame_data.append([seq, 3, abc[0], abc[2]]) # a|c
121 |
122 | self.frame_data = frame_data
123 |
124 | def __len__(self):
125 | return len(self.frame_data)
126 |
127 | def __getitem__(self, index):
128 | seq, label, img_dst, img_src = self.frame_data[index]
129 | root = os.path.join(self.sparse_dir, self.split, seq)
130 |
131 | example = load_sparse_model_example(
132 | root=root, img_dst=img_dst, img_src=img_src, size=self.size)
133 |
134 | for k in example:
135 | example[k] = example[k].astype(np.float32)
136 |
137 | example["src_points"] = pad_points(example["src_points"],
138 | self.max_points)
139 | example["seq"] = seq
140 | example["label"] = label
141 | example["dst_fname"] = img_dst
142 | example["src_fname"] = img_src
143 |
144 | return example
145 |
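A small worked example of the thirds-based sampling in ACIDSparseBase.__getitem__ above (the numbers are illustrative, not taken from the dataset): with 10 frames and segments = [0, 2], the destination frame is drawn from the first third, the source frame from the last third, and the pair gets label 1 ("large forward"):

    n = 10
    segments = [0, 2]
    dst_indices = list(range(segments[0] * n // 3, (segments[0] + 1) * n // 3))
    src_indices = list(range(segments[1] * n // 3, (segments[1] + 1) * n // 3))
    assert dst_indices == [0, 1, 2]      # destination frame candidates
    assert src_indices == [6, 7, 8, 9]   # source frame candidates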
--------------------------------------------------------------------------------
/geofree/data/realestate.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | import os
3 | import numpy as np
4 | import torch.utils.data as data
5 | from PIL import Image
6 | from torchvision.transforms import Compose, Normalize, Resize, ToTensor
7 | from geofree.data.read_write_model import read_model
8 |
9 |
10 | class PRNGMixin(object):
11 | """Adds a prng property which is a numpy RandomState which gets
12 | reinitialized whenever the pid changes to avoid synchronized sampling
13 | behavior when used in conjunction with multiprocessing."""
14 |
15 | @property
16 | def prng(self):
17 | currentpid = os.getpid()
18 | if getattr(self, "_initpid", None) != currentpid:
19 | self._initpid = currentpid
20 | self._prng = np.random.RandomState()
21 | return self._prng
22 |
23 |
24 | def load_sparse_model_example(root, img_dst, img_src, size):
25 | """
26 | Parameters
27 | root folder containing directory sparse with points3D.bin,
28 | images.bin and cameras.bin, and directory images
29 | img_dst filename of image in images to be used as destination
30 | img_src filename of image in images to be used as source
31 | size size to resize image and parameters to. If None nothing is
32 | done, otherwise it should be in (h,w) format
33 | Returns
34 | example dictionary containing
35 | dst_img destination image as (h,w,c) array in range (-1,1)
36 | src_img source image as (h,w,c) array in range (-1,1)
37 | src_points sparse set of 3d points for source image as (N,3) array
38 | with (:,:2) being pixel coordinates and (:,2) depth values
39 | K 3x3 camera intrinsics
40 | K_inv inverse of camera intrinsics
41 | R_rel relative rotation mapping from source to destination
42 | coordinate system
43 | t_rel relative translation mapping from source to destination
44 | coordinate system
45 | """
46 | # load sparse model
47 | model = os.path.join(root, "sparse")
48 | try:
49 | cameras, images, points3D = read_model(path=model, ext=".bin")
50 | except Exception as e:
51 | raise Exception(f"Failed to load sparse model {model}.") from e
52 |
53 |
54 | # load camera parameters and image size
55 | cam = cameras[1]
56 | h = cam.height
57 | w = cam.width
58 | params = cam.params
59 | K = np.array([[params[0], 0.0, params[2]],
60 | [0.0, params[1], params[3]],
61 | [0.0, 0.0, 1.0]])
62 |
63 | # find keys of desired dst and src images
64 | key_dst = [k for k in images.keys() if images[k].name==img_dst]
65 | assert len(key_dst)==1, (img_dst, key_dst)
66 | key_src = [k for k in images.keys() if images[k].name==img_src]
67 | assert len(key_src)==1, (img_src, key_src)
68 | keys = [key_dst[0], key_src[0]]
69 |
70 | # load extrinsics
71 | Rs = np.stack([images[k].qvec2rotmat() for k in keys])
72 | ts = np.stack([images[k].tvec for k in keys])
73 |
74 | # load sparse 3d points to be able to estimate scale
75 | sparse_points = [None, None]
76 | #for i in range(len(keys)):
77 | for i in [1]: # only need it for source
78 | key = keys[i]
79 | xys = images[key].xys
80 | p3D = images[key].point3D_ids
81 | pmask = p3D > 0
82 | # if verbose: print("Found {} 3d points in sparse model.".format(pmask.sum()))
83 | xys = xys[pmask]
84 | p3D = p3D[pmask]
85 | worlds = np.stack([points3D[id_].xyz for id_ in p3D]) # N, 3
86 | # project to current view
87 | worlds = worlds[..., None] # N,3,1
88 | pixels = K[None,...]@(Rs[i][None,...]@worlds+ts[i][None,...,None])
89 | pixels = pixels.squeeze(-1) # N,3
90 |
91 | # instead of using provided xys, one could also project pixels, ie
92 | # xys ~ pixels[:,:2]/pixels[:,[2]]
93 | points = np.concatenate([xys, pixels[:,[2]]], axis=1)
94 | sparse_points[i] = points
95 |
96 | # code to convert to sparse depth map
97 | # xys = points[:,:2]
98 | # xys = np.round(xys).astype(np.int)
99 | # xys[:,0] = xys[:,0].clip(min=0,max=w-1)
100 | # xys[:,1] = xys[:,1].clip(min=0,max=h-1)
101 | # indices = xys[:,1]*w+xys[:,0]
102 | # flatdm = np.zeros(h*w)
103 | # flatz = pixels[:,2]
104 | # np.put_along_axis(flatdm, indices, flatz, axis=0)
105 | # sparse_dm = flatdm.reshape(h,w)
106 |
107 | # load images
108 | im_root = os.path.join(root, "images")
109 | im_paths = [os.path.join(im_root, images[k].name) for k in keys]
110 | ims = list()
111 | for path in im_paths:
112 | im = Image.open(path)
113 | ims.append(im)
114 |
115 | if size is not None and (size[0] != h or size[1] != w):
116 | # resize
117 | ## K
118 | K[0,:] = K[0,:]*size[1]/w
119 | K[1,:] = K[1,:]*size[0]/h
120 | ## points
121 | points[:,0] = points[:,0]*size[1]/w
122 | points[:,1] = points[:,1]*size[0]/h
123 | ## img
124 | for i in range(len(ims)):
125 | ims[i] = ims[i].resize((size[1],size[0]),
126 | resample=Image.LANCZOS)
127 |
128 |
129 | for i in range(len(ims)):
130 | ims[i] = np.array(ims[i])/127.5-1.0
131 |
132 |
133 | # relative camera
134 | R_dst = Rs[0]
135 | t_dst = ts[0]
136 | R_src_inv = Rs[1].transpose(-1,-2)
137 | t_src = ts[1]
138 | R_rel = R_dst@R_src_inv
139 | t_rel = t_dst-R_rel@t_src
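# Derivation of the relative pose (standard pinhole extrinsics convention):
# with x_src = R_src @ X + t_src and x_dst = R_dst @ X + t_dst for a world
# point X, eliminating X gives
#   x_dst = (R_dst @ R_src^T) @ x_src + (t_dst - R_dst @ R_src^T @ t_src),
# which is exactly the R_rel and t_rel computed above.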
140 | K_inv = np.linalg.inv(K)
141 |
142 | # collect results
143 | example = {
144 | "dst_img": ims[0],
145 | "src_img": ims[1],
146 | "src_points": sparse_points[1],
147 | "K": K,
148 | "K_inv": K_inv,
149 | "R_rel": R_rel,
150 | "t_rel": t_rel,
151 | }
152 |
153 | return example
154 |
155 | def pad_points(points, N):
156 | padded = -1*np.ones((N,3), dtype=points.dtype)
157 | padded[:points.shape[0],:] = points
158 | return padded
159 |
160 | class RealEstate10KCustomTest(data.Dataset):
161 | def __init__(self, size=None, max_points=16384):
162 | self.size = size
163 | self.max_points = max_points
164 |
165 | self.frames_file = "data/realestate_custom_frames.txt"
166 | self.sparse_dir = "data/realestate_sparse"
167 | self.split = "test"
168 |
169 | with open(self.frames_file, "r") as f:
170 | frames = f.read().splitlines()
171 |
172 | seq_data = dict()
173 | for line in frames:
174 | seq,a,b,c = line.split(",")
175 | assert not seq in seq_data
176 | seq_data[seq] = [a,b,c]
177 |
178 | # sequential list of seq, label, dst, src
179 | # where label is used to disambiguate different warping scenarios
180 | # 0: small forward movement
181 | # 1: large forward movement
182 | # 2: small backward movement (reverse of 0)
183 | # 3: large backward movement (reverse of 1)
184 | frame_data = list()
185 | for seq in sorted(seq_data.keys()):
186 | abc = seq_data[seq]
187 | frame_data.append([seq, 0, abc[1], abc[0]]) # b|a
188 | frame_data.append([seq, 1, abc[2], abc[0]]) # c|a
189 | frame_data.append([seq, 2, abc[0], abc[1]]) # a|b
190 | frame_data.append([seq, 3, abc[0], abc[2]]) # a|c
191 |
192 | self.frame_data = frame_data
193 |
194 | def __len__(self):
195 | return len(self.frame_data)
196 |
197 | def __getitem__(self, index):
198 | seq, label, img_dst, img_src = self.frame_data[index]
199 | root = os.path.join(self.sparse_dir, self.split, seq)
200 |
201 | example = load_sparse_model_example(
202 | root=root, img_dst=img_dst, img_src=img_src, size=self.size)
203 |
204 | for k in example:
205 | example[k] = example[k].astype(np.float32)
206 |
207 | example["src_points"] = pad_points(example["src_points"],
208 | self.max_points)
209 | example["seq"] = seq
210 | example["label"] = label
211 | example["dst_fname"] = img_dst
212 | example["src_fname"] = img_src
213 |
214 | return example
215 |
216 |
217 | class RealEstate10KSparseTrain(data.Dataset, PRNGMixin):
218 | def __init__(self, size=None, max_points=16384):
219 | self.size = size
220 | self.max_points = max_points
221 |
222 | self.sparse_dir = "data/realestate_sparse"
223 | self.split = "train"
224 | self.sequence_dir = os.path.join(self.sparse_dir, self.split)
225 | with open("data/realestate_train_sequences.txt", "r") as f:
226 | self.sequences = f.read().splitlines()
227 |
228 | def __len__(self):
229 | return len(self.sequences)
230 |
231 | def __getitem__(self, index):
232 | seq = self.sequences[index]
233 | root = os.path.join(self.sequence_dir, seq)
234 | frames = sorted([fname for fname in os.listdir(os.path.join(root, "images")) if fname.endswith(".png")])
235 | segments = self.prng.choice(3, 2, replace=False)
236 | if segments[0] < segments[1]: # forward
237 | if segments[1]-segments[0] == 1: # small
238 | label = 0
239 | else:
240 | label = 1 # large
241 | else: # backward
242 | if segments[1]-segments[0] == 1: # small
243 | label = 2
244 | else:
245 | label = 3
246 | n = len(frames)
247 | dst_indices = list(range(segments[0]*n//3, (segments[0]+1)*n//3))
248 | src_indices = list(range(segments[1]*n//3, (segments[1]+1)*n//3))
249 | dst_index = self.prng.choice(dst_indices)
250 | src_index = self.prng.choice(src_indices)
251 | img_dst = frames[dst_index]
252 | img_src = frames[src_index]
253 |
254 | example = load_sparse_model_example(
255 | root=root, img_dst=img_dst, img_src=img_src, size=self.size)
256 |
257 | for k in example:
258 | example[k] = example[k].astype(np.float32)
259 |
260 | example["src_points"] = pad_points(example["src_points"],
261 | self.max_points)
262 | example["seq"] = seq
263 | example["label"] = label
264 | example["dst_fname"] = img_dst
265 | example["src_fname"] = img_src
266 |
267 | return example
268 |
269 |
270 | class RealEstate10KSparseCustom(data.Dataset):
271 | def __init__(self, frame_data, split, size=None, max_points=16384):
272 | self.size = size
273 | self.max_points = max_points
274 |
275 | self.sparse_dir = "data/realestate_sparse"
276 | self.split = split
277 | self.frame_data = frame_data
278 |
279 | def __len__(self):
280 | return len(self.frame_data)
281 |
282 | def __getitem__(self, index):
283 | seq, label, img_dst, img_src = self.frame_data[index]
284 | root = os.path.join(self.sparse_dir, self.split, seq)
285 |
286 | example = load_sparse_model_example(
287 | root=root, img_dst=img_dst, img_src=img_src, size=self.size)
288 |
289 | for k in example:
290 | example[k] = example[k].astype(np.float32)
291 |
292 | example["src_points"] = pad_points(example["src_points"],
293 | self.max_points)
294 | example["seq"] = seq
295 | example["label"] = label
296 | example["dst_fname"] = img_dst
297 | example["src_fname"] = img_src
298 |
299 | return example
300 |
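A minimal usage sketch for the custom test split (it assumes the sparse COLMAP models have already been prepared under data/realestate_sparse/test, as RealEstate10KCustomTest expects); the returned keys follow the load_sparse_model_example docstring plus the metadata added in __getitem__:

    from torch.utils.data import DataLoader

    dataset = RealEstate10KCustomTest(size=None)  # size=None keeps the original (h, w)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    batch = next(iter(loader))
    print(batch["src_img"].shape)      # (1, h, w, 3), values in [-1, 1]
    print(batch["src_points"].shape)   # (1, 16384, 3): x, y, depth, padded with -1
    print(batch["K"].shape, batch["R_rel"].shape, batch["t_rel"].shape)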
--------------------------------------------------------------------------------
/geofree/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/examples/__init__.py
--------------------------------------------------------------------------------
/geofree/examples/artist.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/examples/artist.jpg
--------------------------------------------------------------------------------
/geofree/examples/beach.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/examples/beach.jpg
--------------------------------------------------------------------------------
/geofree/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n):
33 | return self.schedule(n)
34 |
35 |
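A minimal sketch of wiring this schedule into an optimizer (hyperparameter values are illustrative); per the docstring, the lambda returns the absolute learning rate, so the optimizer's base lr stays at 1.0:

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1.0)  # base_lr of 1.0
    schedule = LambdaWarmUpCosineScheduler(
        warm_up_steps=10000, lr_min=1e-6, lr_max=1.5e-4, lr_start=1e-6,
        max_decay_steps=500000)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)
    # call scheduler.step() once per training step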
--------------------------------------------------------------------------------
/geofree/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/models/__init__.py
--------------------------------------------------------------------------------
/geofree/models/transformers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/models/transformers/__init__.py
--------------------------------------------------------------------------------
/geofree/models/transformers/warpgpt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import torch.nn.functional as F
4 | import pytorch_lightning as pl
5 | from torch.optim.lr_scheduler import LambdaLR
6 |
7 | from geofree.main import instantiate_from_config
8 |
9 |
10 | class WarpTransformer(pl.LightningModule):
11 | """This one relies on WarpGPT, where the warper handles the warping to
12 | support warping of positional embeddings, too.
13 | The first stage is used to encode dst,
14 | the cond stage is used to encode src.
15 | """
16 | def __init__(self,
17 | transformer_config,
18 | first_stage_config,
19 | cond_stage_config,
20 | ckpt_path=None,
21 | ignore_keys=[],
22 | first_stage_key="dst_img",
23 | cond_stage_key="src_img",
24 | use_scheduler=False,
25 | scheduler_config=None,
26 | monitor="val/loss",
27 | pkeep=1.0,
28 | plot_cond_stage=False,
29 | log_det_sample=False,
30 | top_k=None
31 | ):
32 |
33 | super().__init__()
34 | if monitor is not None:
35 | self.monitor = monitor
36 | self.log_det_sample = log_det_sample
37 | self.init_first_stage_from_ckpt(first_stage_config)
38 | self.init_cond_stage_from_ckpt(cond_stage_config)
39 | self.transformer = instantiate_from_config(config=transformer_config)
40 |
41 | if ckpt_path is not None:
42 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
43 | self.first_stage_key = first_stage_key
44 | self.cond_stage_key = cond_stage_key
45 | self.pkeep = pkeep
46 |
47 | self.use_scheduler = use_scheduler
48 | if use_scheduler:
49 | assert scheduler_config is not None
50 | self.scheduler_config = scheduler_config
51 | self.plot_cond_stage = plot_cond_stage
52 | self.top_k = top_k if top_k is not None else 100
53 | self.warpkwargs_keys = {
54 | "x": "src_img",
55 | "points": "src_points",
56 | "R": "R_rel",
57 | "t": "t_rel",
58 | "K_dst": "K",
59 | "K_src_inv": "K_inv",
60 | }
61 |
62 | def init_from_ckpt(self, path, ignore_keys=list()):
63 | sd = torch.load(path, map_location="cpu")["state_dict"]
64 | for k in list(sd.keys()):
65 | for ik in ignore_keys:
66 | if k.startswith(ik):
67 | self.print("Deleting key {} from state_dict.".format(k))
68 | del sd[k]
69 | missing, unexpected = self.load_state_dict(sd, strict=False)
70 | print(f"Restored from {path} with {len(missing)} missing keys and {len(unexpected)} unexpected keys.")
71 |
72 | def init_first_stage_from_ckpt(self, config):
73 | model = instantiate_from_config(config)
74 | self.first_stage_model = model.eval()
75 |
76 | def init_cond_stage_from_ckpt(self, config):
77 | if config == "__is_first_stage__":
78 | print("Using first stage also as cond stage.")
79 | self.cond_stage_model = self.first_stage_model
80 | else:
81 | model = instantiate_from_config(config)
82 | self.cond_stage_model = model.eval()
83 |
84 | def forward(self, x, c, warpkwargs):
85 | # one step to produce the logits
86 | _, z_indices = self.encode_to_z(x)
87 | _, c_indices = self.encode_to_c(c)
88 |
89 | if self.training and self.pkeep < 1.0:
90 | mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
91 | device=z_indices.device))
92 | mask = mask.round().to(dtype=torch.int64)
93 | r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
94 | a_indices = mask*z_indices+(1-mask)*r_indices
95 | else:
96 | a_indices = z_indices
97 |
98 | cz_indices = torch.cat((c_indices, a_indices), dim=1)
99 |
100 | # target includes all sequence elements (no need to handle first one
101 | # differently because we are conditioning)
102 | target = z_indices
103 | # make the prediction
104 | logits, _ = self.transformer(cz_indices[:, :-1], warpkwargs)
105 | # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
182 | if x.shape[1] > 3:
183 | # colorize with random projection
184 | assert xrec.shape[1] > 3
185 | x = self.to_rgb(x)
186 | xrec = self.to_rgb(xrec)
187 | log["inputs"] = x
188 | log["reconstructions"] = xrec
189 | return log
190 |
191 | def to_rgb(self, x):
192 | assert self.image_key == "segmentation"
193 | if not hasattr(self, "colorize"):
194 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
195 | x = F.conv2d(x, weight=self.colorize)
196 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
197 | return x
198 |
--------------------------------------------------------------------------------
/geofree/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/__init__.py
--------------------------------------------------------------------------------
/geofree/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/geofree/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from geofree.modules.losses.vqperceptual import DummyLoss
2 |
--------------------------------------------------------------------------------
/geofree/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class DummyLoss(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 |
--------------------------------------------------------------------------------
/geofree/modules/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/transformer/__init__.py
--------------------------------------------------------------------------------
/geofree/modules/transformer/mingpt.py:
--------------------------------------------------------------------------------
1 | """
2 | credit to: https://github.com/karpathy/minGPT/
3 | GPT model:
4 | - the initial stem consists of a combination of token encoding and a positional encoding
5 | - the meat of it is a uniform sequence of Transformer blocks
6 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
7 | - all blocks feed into a central residual pathway similar to resnets
8 | - the final decoder is a linear projection into a vanilla Softmax classifier
9 | """
10 |
11 | import math
12 | import logging
13 |
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import functional as F
17 |
18 | from geofree.main import instantiate_from_config
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class GPTConfig:
24 | """ base GPT config, params common to all GPT versions """
25 | embd_pdrop = 0.1
26 | resid_pdrop = 0.1
27 | attn_pdrop = 0.1
28 |
29 | def __init__(self, vocab_size, block_size, **kwargs):
30 | self.vocab_size = vocab_size
31 | self.block_size = block_size
32 | for k,v in kwargs.items():
33 | setattr(self, k, v)
34 |
35 |
36 | class GPT1Config(GPTConfig):
37 | """ GPT-1 like network roughly 125M params """
38 | n_layer = 12
39 | n_head = 12
40 | n_embd = 768
41 |
42 |
43 | class CausalSelfAttention(nn.Module):
44 | def __init__(self, config):
45 | super().__init__()
46 | assert config.n_embd % config.n_head == 0, f"n_embd is {config.n_embd} but n_head is {config.n_head}."
47 | # key, query, value projections for all heads
48 | self.key = nn.Linear(config.n_embd, config.n_embd)
49 | self.query = nn.Linear(config.n_embd, config.n_embd)
50 | self.value = nn.Linear(config.n_embd, config.n_embd)
51 | # regularization
52 | self.attn_drop = nn.Dropout(config.attn_pdrop)
53 | self.resid_drop = nn.Dropout(config.resid_pdrop)
54 | # output projection
55 | self.proj = nn.Linear(config.n_embd, config.n_embd)
56 | # causal mask to ensure that attention is only applied to the left in the input sequence
57 | mask = torch.tril(torch.ones(config.block_size,
58 | config.block_size))
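# with n_unmasked conditioning tokens at the start of the sequence, attention
# among those positions is left unrestricted (non-causal):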
59 | if hasattr(config, "n_unmasked"):
60 | mask[:config.n_unmasked, :config.n_unmasked] = 1
61 | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
62 | self.n_head = config.n_head
63 |
64 | def forward(self, x, layer_past=None):
65 | B, T, C = x.size()
66 |
67 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
68 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
69 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
70 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
71 |
72 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
73 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
74 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
75 | att = F.softmax(att, dim=-1)
76 | att = self.attn_drop(att)
77 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
78 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
79 |
80 | # output projection
81 | y = self.resid_drop(self.proj(y))
82 | return y
83 |
84 |
85 | class Block(nn.Module):
86 | """ an unassuming Transformer block """
87 | def __init__(self, config):
88 | super().__init__()
89 | self.ln1 = nn.LayerNorm(config.n_embd)
90 | self.ln2 = nn.LayerNorm(config.n_embd)
91 | self.attn = CausalSelfAttention(config)
92 | self.mlp = nn.Sequential(
93 | nn.Linear(config.n_embd, 4 * config.n_embd),
94 | nn.GELU(), # nice
95 | nn.Linear(4 * config.n_embd, config.n_embd),
96 | nn.Dropout(config.resid_pdrop),
97 | )
98 |
99 | def forward(self, x):
100 | x = x + self.attn(self.ln1(x))
101 | x = x + self.mlp(self.ln2(x))
102 | return x
103 |
104 |
105 | class GPT(nn.Module):
106 | """ the full GPT language model, with a context size of block_size """
107 | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
108 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0,
109 | input_vocab_size=None):
110 | super().__init__()
111 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
112 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
113 | n_layer=n_layer, n_head=n_head, n_embd=n_embd,
114 | n_unmasked=n_unmasked)
115 | # input embedding stem
116 | in_vocab_size = vocab_size if not input_vocab_size else input_vocab_size
117 | self.tok_emb = nn.Embedding(in_vocab_size, config.n_embd)
118 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
119 | self.drop = nn.Dropout(config.embd_pdrop)
120 | # transformer
121 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
122 | # decoder head
123 | self.ln_f = nn.LayerNorm(config.n_embd)
124 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
125 | self.block_size = config.block_size
126 | self.apply(self._init_weights)
127 | self.config = config
128 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
129 |
130 | def get_block_size(self):
131 | return self.block_size
132 |
133 | def _init_weights(self, module):
134 | if isinstance(module, (nn.Linear, nn.Embedding)):
135 | module.weight.data.normal_(mean=0.0, std=0.02)
136 | if isinstance(module, nn.Linear) and module.bias is not None:
137 | module.bias.data.zero_()
138 | elif isinstance(module, nn.LayerNorm):
139 | module.bias.data.zero_()
140 | module.weight.data.fill_(1.0)
141 |
142 | def forward(self, idx, embeddings=None, targets=None, return_layers=False):
143 | # forward the GPT model
144 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
145 |
146 | if embeddings is not None: # prepend explicit embeddings
147 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
148 |
149 | t = token_embeddings.shape[1]
150 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
151 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
152 | x = self.drop(token_embeddings + position_embeddings)
153 |
154 | if return_layers:
155 | layers = [x]
156 | for block in self.blocks:
157 | x = block(x)
158 | layers.append(x)
159 | return layers
160 |
161 | x = self.blocks(x)
162 | x = self.ln_f(x)
163 | logits = self.head(x)
164 |
165 | # if we are given some desired targets also calculate the loss
166 | loss = None
167 | if targets is not None:
168 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
169 | return logits, loss
170 |
171 |
172 | class DummyGPT(nn.Module):
173 | # for debugging
174 | def __init__(self, add_value=1):
175 | super().__init__()
176 | self.add_value = add_value
177 |
178 | def forward(self, idx):
179 | return idx + self.add_value, None
180 |
181 |
182 | class CodeGPT(nn.Module):
183 | """Takes in semi-embeddings"""
184 | def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
185 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0, clf_head=True,
186 | init_weights=True, init_last_layer_to_zero=False):
187 | super().__init__()
188 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
189 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
190 | n_layer=n_layer, n_head=n_head, n_embd=n_embd,
191 | n_unmasked=n_unmasked)
192 | # input embedding stem
193 | self.tok_emb = nn.Linear(in_channels, config.n_embd)
194 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
195 | self.drop = nn.Dropout(config.embd_pdrop)
196 | # transformer
197 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
198 |
199 | self.clf_head = clf_head
200 | if self.clf_head:
201 | # decoder head
202 | self.ln_f = nn.LayerNorm(config.n_embd)
203 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
204 | else:
205 | self.head = nn.Linear(config.n_embd, in_channels)
206 | #print(f"Not using classification head of GPT model. n_unmasked is {n_unmasked}.")
207 | self.block_size = config.block_size
208 | if init_weights:
209 | self.apply(self._init_weights)
210 | if init_last_layer_to_zero:
211 | print("### WARNING: Initializing CodeGPT-head to zero ###")
212 | self.head.apply(self._init_to_zero)
213 | self.config = config
214 | #print(f"{self.__class__.__name__}: number of parameters: {sum(p.numel() for p in self.parameters())*1.0e-6:.2f} M.")
215 |
216 | def get_block_size(self):
217 | return self.block_size
218 |
219 | def _init_weights(self, module):
220 | if isinstance(module, (nn.Linear, nn.Embedding)):
221 | module.weight.data.normal_(mean=0.0, std=0.02)
222 | if isinstance(module, nn.Linear) and module.bias is not None:
223 | module.bias.data.zero_()
224 | elif isinstance(module, nn.LayerNorm):
225 | module.bias.data.zero_()
226 | module.weight.data.fill_(1.0)
227 |
228 | def _init_to_zero(self, module):
229 | if isinstance(module, (nn.Linear, nn.Embedding)):
230 | module.weight.data.zero_()
231 | if isinstance(module, nn.Linear) and module.bias is not None:
232 | module.bias.data.zero_()
233 | elif isinstance(module, nn.LayerNorm):
234 | module.bias.data.zero_()
235 | module.weight.data.zero_()
236 |
237 |
238 | def forward(self, idx, embeddings=None, targets=None):
239 | # forward the GPT model
240 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
241 |
242 | if embeddings is not None: # prepend explicit embeddings
243 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
244 |
245 | t = token_embeddings.shape[1]
246 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
247 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
248 | x = self.drop(token_embeddings + position_embeddings)
249 | x = self.blocks(x)
250 |
251 | if not self.clf_head:
252 | # maps back to in_channels
253 | return self.head(x)
254 |
255 | x = self.ln_f(x)
256 | logits = self.head(x)
257 | # if we are given some desired targets also calculate the loss
258 | loss = None
259 | if targets is not None:
260 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
261 | return logits, loss
262 |
263 |
264 | class WarpGPT(GPT):
265 | """ the full GPT language model, with a context size of block_size """
266 | def __init__(self, **kwargs):
267 | assert "n_unmasked" in kwargs and kwargs["n_unmasked"] != 0
268 | warper_config = kwargs.pop("warper_config")
269 | if warper_config.params is None:
270 | warper_config.params = dict()
271 | warper_config.params["n_unmasked"] = kwargs["n_unmasked"]
272 | warper_config.params["block_size"] = kwargs["block_size"]
273 | warper_config.params["n_embd"] = kwargs["n_embd"]
274 | super().__init__(**kwargs)
275 | self.warper = instantiate_from_config(warper_config)
276 |
277 | def forward(self, idx, warpkwargs, embeddings=None, targets=None):
278 | # forward the GPT model
279 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
280 |
281 | if embeddings is not None: # prepend explicit embeddings
282 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
283 |
284 | t = token_embeddings.shape[1]
285 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
286 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
287 |
288 | token_embeddings, position_embeddings = self.warper(token_embeddings,
289 | position_embeddings,
290 | warpkwargs)
291 |
292 | x = self.drop(token_embeddings + position_embeddings)
293 | x = self.blocks(x)
294 | x = self.ln_f(x)
295 | logits = self.head(x)
296 |
297 | # if we are given some desired targets also calculate the loss
298 | loss = None
299 | if targets is not None:
300 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
301 | return logits, loss
302 |
303 | #### sampling utils
304 |
305 | def top_k_logits(logits, k):
306 | v, ix = torch.topk(logits, k)
307 | out = logits.clone()
308 | out[out < v[:, [-1]]] = -float('Inf')
309 | return out
310 |
311 | @torch.no_grad()
312 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
313 | """
314 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
315 | the sequence, feeding the predictions back into the model each time. Clearly the sampling
316 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window
317 | of block_size, unlike an RNN that has an infinite context window.
318 | """
319 | block_size = model.get_block_size()
320 | model.eval()
321 | for k in range(steps):
322 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
323 | logits, _ = model(x_cond)
324 | # pluck the logits at the final step and scale by temperature
325 | logits = logits[:, -1, :] / temperature
326 | # optionally crop probabilities to only the top k options
327 | if top_k is not None:
328 | logits = top_k_logits(logits, top_k)
329 | # apply softmax to convert to probabilities
330 | probs = F.softmax(logits, dim=-1)
331 | # sample from the distribution or take the most likely
332 | if sample:
333 | ix = torch.multinomial(probs, num_samples=1)
334 | else:
335 | _, ix = torch.topk(probs, k=1, dim=-1)
336 | # append to the sequence and continue
337 | x = torch.cat((x, ix), dim=1)
338 |
339 | return x
340 |
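A minimal sketch of running the sampling loop above on a toy model (all sizes are illustrative):

    import torch

    model = GPT(vocab_size=1024, block_size=64, n_layer=2, n_head=4, n_embd=128)
    x = torch.randint(0, 1024, (1, 8))   # (b, t) conditioning indices
    out = sample(model, x, steps=16, temperature=1.0, sample=True, top_k=50)
    print(out.shape)                      # torch.Size([1, 24]): 8 context + 16 sampled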
--------------------------------------------------------------------------------
/geofree/modules/transformer/warper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 |
5 | from geofree.modules.warp.midas import Midas
6 |
7 |
8 | def disabled_train(self, mode=True):
9 | """Overwrite model.train with this function to make sure train/eval mode
10 | does not change anymore."""
11 | return self
12 |
13 |
14 | class AbstractWarper(nn.Module):
15 | def __init__(self, *args, **kwargs):
16 | super().__init__()
17 | self._midas = Midas()
18 | self._midas.eval()
19 | self._midas.train = disabled_train
20 | for param in self._midas.parameters():
21 | param.requires_grad = False
22 |
23 | self.n_unmasked = kwargs["n_unmasked"] # length of conditioning
24 | self.n_embd = kwargs["n_embd"]
25 | self.block_size = kwargs["block_size"]
26 | self.size = kwargs["size"] # h, w tuple
27 | self.start_idx = kwargs.get("start_idx", 0) # hint to not modify parts
28 |
29 | self._use_cache = False
30 | self.new_emb = None # cache
31 | self.new_pos = None # cache
32 |
33 | def set_cache(self, value):
34 | self._use_cache = value
35 |
36 | def get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
37 | if self._use_cache:
38 | assert not self.training, "Do you really want to use caching during training?"
39 | assert self.new_emb is not None
40 | assert self.new_pos is not None
41 | return self.new_emb, self.new_pos
42 | self.new_emb, self.new_pos = self._get_embeddings(token_embeddings,
43 | position_embeddings,
44 | warpkwargs)
45 | return self.new_emb, self.new_pos
46 |
47 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
48 | raise NotImplementedError()
49 |
50 | def forward(self, token_embeddings, position_embeddings, warpkwargs):
51 | new_emb, new_pos = self.get_embeddings(token_embeddings,
52 | position_embeddings,
53 | warpkwargs)
54 |
55 | new_emb = torch.cat([new_emb, token_embeddings[:,self.n_unmasked:,:]],
56 | dim=1)
57 | b = new_pos.shape[0]
58 | new_pos = torch.cat([new_pos, position_embeddings[:,self.n_unmasked:,:][b*[0],...]],
59 | dim=1)
60 |
61 | return new_emb, new_pos
62 |
63 | def _to_sequence(self, x):
64 | x = rearrange(x, 'b c h w -> b (h w) c')
65 | return x
66 |
67 | def _to_imglike(self, x):
68 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.size[0])
69 | return x
70 |
71 |
72 | class AbstractWarperWithCustomEmbedding(AbstractWarper):
73 | def __init__(self, *args, **kwargs):
74 | super().__init__(*args, **kwargs)
75 | self.pos_emb = nn.Parameter(torch.zeros(1, self.block_size, self.n_embd))
76 |
77 |
78 | class NoSourceWarper(AbstractWarper):
79 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
80 | cond_emb = token_embeddings[:,:self.n_unmasked,:]
81 | cond_pos = position_embeddings[:,:self.n_unmasked,:]
82 |
83 | b, seq_length, chn = cond_emb.shape
84 | cond_emb = self._to_imglike(cond_emb)
85 |
86 | cond_pos = self._to_imglike(cond_pos)
87 | cond_pos = cond_pos[b*[0],...]
88 |
89 | new_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True,
90 | boltzmann_factor=0.0,
91 | **warpkwargs)
92 | new_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True,
93 | boltzmann_factor=0.0,
94 | **warpkwargs)
95 | new_emb = self._filter_nans(new_emb)
96 | new_pos = self._filter_nans(new_pos)
97 |
98 | new_emb = self._to_sequence(new_emb)
99 | new_pos = self._to_sequence(new_pos)
100 | return new_emb, new_pos
101 |
102 | def _filter_nans(self, x):
103 | x[torch.isnan(x)] = 0.
104 | return x
105 |
106 |
107 | class ConvWarper(AbstractWarper):
108 | def __init__(self, *args, **kwargs):
109 | super().__init__(*args, **kwargs)
110 | self.emb_conv = nn.Conv2d(2*self.n_embd, self.n_embd,
111 | kernel_size=1,
112 | padding=0,
113 | bias=False)
114 | self.pos_conv = nn.Conv2d(2*self.n_embd, self.n_embd,
115 | kernel_size=1,
116 | padding=0,
117 | bias=False)
118 |
119 | def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
120 | cond_emb = token_embeddings[:,self.start_idx:self.n_unmasked,:]
121 | cond_pos = position_embeddings[:,self.start_idx:self.n_unmasked,:]
122 |
123 | b, seq_length, chn = cond_emb.shape
124 | cond_emb = cond_emb.reshape(b, self.size[0], self.size[1], chn)
125 | cond_emb = cond_emb.permute(0,3,1,2)
126 |
127 | cond_pos = cond_pos.reshape(1, self.size[0], self.size[1], chn)
128 | cond_pos = cond_pos.permute(0,3,1,2)
129 | cond_pos = cond_pos[b*[0],...]
130 |
131 | with torch.no_grad():
132 | warp_emb, _ = self._midas.warp_features(f=cond_emb, no_depth_grad=True, **warpkwargs)
133 | warp_pos, _ = self._midas.warp_features(f=cond_pos, no_depth_grad=True, **warpkwargs)
134 |
135 | new_emb = self.emb_conv(torch.cat([cond_emb, warp_emb], dim=1))
136 | new_pos = self.pos_conv(torch.cat([cond_pos, warp_pos], dim=1))
137 |
138 | new_emb = new_emb.permute(0,2,3,1)
139 | new_emb = new_emb.reshape(b,seq_length,chn)
140 |
141 | new_pos = new_pos.permute(0,2,3,1)
142 | new_pos = new_pos.reshape(b,seq_length,chn)
143 |
144 | # prepend unmodified ones again
145 | new_emb = torch.cat((token_embeddings[:,:self.start_idx,:], new_emb),
146 | dim=1)
147 | new_pos = torch.cat((position_embeddings[:,:self.start_idx,:][b*[0],...], new_pos),
148 | dim=1)
149 |
150 | return new_emb, new_pos
151 |
--------------------------------------------------------------------------------
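Note (illustrative, not part of the repo): the warpers above move between the transformer's token sequence and an image-like grid with einops, exactly as in `_to_sequence`/`_to_imglike`. A minimal round-trip sketch, where batch size, embedding width and the 13x23 grid are assumptions:

import torch
from einops import rearrange

b, c, h, w = 2, 1024, 13, 23                       # assumed batch, embed dim, token grid
feat = torch.randn(b, c, h, w)                     # image-like conditioning features
seq = rearrange(feat, 'b c h w -> b (h w) c')      # sequence of 299 tokens
back = rearrange(seq, 'b (h w) c -> b c h w', h=h)
assert torch.equal(feat, back)                     # lossless round trip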
/geofree/modules/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AbstractEncoder(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 |
9 | def encode(self, *args, **kwargs):
10 | raise NotImplementedError
11 |
12 |
13 | class SOSProvider(AbstractEncoder):
14 | # for unconditional training
15 | def __init__(self, sos_token, quantize_interface=True):
16 | super().__init__()
17 | self.sos_token = sos_token
18 | self.quantize_interface = quantize_interface
19 |
20 | def encode(self, x):
21 | # get batch size from data and replicate sos_token
22 | c = torch.ones(x.shape[0], 1)*self.sos_token
23 | c = c.long().to(x.device)
24 | if self.quantize_interface:
25 | return None, None, [None, None, c]
26 | return c
27 |
28 |
29 | def count_params(model, verbose=False):
30 | total_params = sum(p.numel() for p in model.parameters())
31 | if verbose:
32 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
33 | return total_params
34 |
35 |
36 | class ActNorm(nn.Module):
37 | def __init__(self, num_features, logdet=False, affine=True,
38 | allow_reverse_init=False):
39 | assert affine
40 | super().__init__()
41 | self.logdet = logdet
42 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
43 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
44 | self.allow_reverse_init = allow_reverse_init
45 |
46 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
47 |
48 | def initialize(self, input):
49 | with torch.no_grad():
50 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
51 | mean = (
52 | flatten.mean(1)
53 | .unsqueeze(1)
54 | .unsqueeze(2)
55 | .unsqueeze(3)
56 | .permute(1, 0, 2, 3)
57 | )
58 | std = (
59 | flatten.std(1)
60 | .unsqueeze(1)
61 | .unsqueeze(2)
62 | .unsqueeze(3)
63 | .permute(1, 0, 2, 3)
64 | )
65 |
66 | self.loc.data.copy_(-mean)
67 | self.scale.data.copy_(1 / (std + 1e-6))
68 |
69 | def forward(self, input, reverse=False):
70 | if reverse:
71 | return self.reverse(input)
72 | if len(input.shape) == 2:
73 | input = input[:,:,None,None]
74 | squeeze = True
75 | else:
76 | squeeze = False
77 |
78 | _, _, height, width = input.shape
79 |
80 | if self.training and self.initialized.item() == 0:
81 | self.initialize(input)
82 | self.initialized.fill_(1)
83 |
84 | h = self.scale * (input + self.loc)
85 |
86 | if squeeze:
87 | h = h.squeeze(-1).squeeze(-1)
88 |
89 | if self.logdet:
90 | log_abs = torch.log(torch.abs(self.scale))
91 | logdet = height*width*torch.sum(log_abs)
92 | logdet = logdet * torch.ones(input.shape[0]).to(input)
93 | return h, logdet
94 |
95 | return h
96 |
97 | def reverse(self, output):
98 | if self.training and self.initialized.item() == 0:
99 | if not self.allow_reverse_init:
100 | raise RuntimeError(
101 | "Initializing ActNorm in reverse direction is "
102 | "disabled by default. Use allow_reverse_init=True to enable."
103 | )
104 | else:
105 | self.initialize(output)
106 | self.initialized.fill_(1)
107 |
108 | if len(output.shape) == 2:
109 | output = output[:,:,None,None]
110 | squeeze = True
111 | else:
112 | squeeze = False
113 |
114 | h = output / self.scale - self.loc
115 |
116 | if squeeze:
117 | h = h.squeeze(-1).squeeze(-1)
118 | return h
119 |
120 |
121 | class Embedder(nn.Module):
122 | """to replace the convolutional architecture entirely"""
123 | def __init__(self, n_positions, n_channels, n_embed, bias=False):
124 | super().__init__()
125 | self.n_positions = n_positions
126 | self.n_channels = n_channels
127 | self.n_embed = n_embed
128 | self.fc = nn.Linear(self.n_channels, self.n_embed, bias=bias)
129 |
130 | def forward(self, x):
131 | x = x.reshape(x.shape[0], self.n_positions, self.n_channels)
132 | x = self.fc(x)
133 | return x
134 |
135 |
136 | class MultiEmbedder(nn.Module):
137 | def __init__(self, keys, n_positions, n_channels, n_embed, bias=False):
138 | super().__init__()
139 | self.keys = keys
140 | self.n_positions = n_positions
141 | self.n_channels = n_channels
142 | self.n_embed = n_embed
143 | self.fc = nn.Linear(self.n_channels, self.n_embed, bias=bias)
144 |
145 | def forward(self, **kwargs):
146 | values = [kwargs[k] for k in self.keys]
147 | inputs = list()
148 | for k in self.keys:
149 | entry = kwargs[k].reshape(kwargs[k].shape[0], -1, self.n_channels)
150 | inputs.append(entry)
151 | x = torch.cat(inputs, dim=1)
152 | assert x.shape[1] == self.n_positions, x.shape
153 | x = self.fc(x)
154 | return x
155 |
156 |
157 | class SpatialEmbedder(nn.Module):
158 | def __init__(self, keys, n_channels, n_embed, bias=False, shape=[13, 23]):
159 | # here, n_channels = dim(params)
160 | super().__init__()
161 | self.shape = shape
162 | self.keys = keys
163 | self.n_channels = n_channels
164 | self.n_embed = n_embed
165 | self.linear = nn.Conv2d(self.n_channels, self.n_embed, 1, bias=bias)
166 |
167 | def forward(self, **kwargs):
168 | inputs = list()
169 | for k in self.keys:
170 | entry = kwargs[k].reshape(kwargs[k].shape[0], -1, 1, 1)
171 | inputs.append(entry)
172 | x = torch.cat(inputs, dim=1) # b, n_channels, 1, 1
173 | assert x.shape[1] == self.n_channels, f"expecting {self.n_channels} channels but got {x.shape[1]}"
174 | x = x.repeat(1, 1, self.shape[0], self.shape[1]) # duplicate spatially
175 | x = self.linear(x)
176 | return x
177 |
178 |
179 | def to_rgb(model, x):
180 | if not hasattr(model, "colorize"):
181 | model.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
182 | x = nn.functional.conv2d(x, weight=model.colorize)
183 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
184 | return x
185 |
--------------------------------------------------------------------------------
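Note (illustrative, not part of the repo): `SpatialEmbedder` turns a set of flat camera parameters into a spatial embedding by broadcasting them over the token grid and applying a 1x1 convolution. A minimal sketch; the key names and channel counts (flattened 3x3 rotation, translation vector, 3x3 intrinsics) are assumptions chosen for illustration:

import torch
from geofree.modules.util import SpatialEmbedder

emb = SpatialEmbedder(keys=["R_rel", "t_rel", "K"],
                      n_channels=9 + 3 + 9,    # flattened 3x3 R, 3-vector t, 3x3 K
                      n_embed=1024)
out = emb(R_rel=torch.randn(2, 3, 3),
          t_rel=torch.randn(2, 3),
          K=torch.randn(2, 3, 3))
print(out.shape)  # torch.Size([2, 1024, 13, 23]) -- one embedding per grid position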
/geofree/modules/vqvae/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/vqvae/__init__.py
--------------------------------------------------------------------------------
/geofree/modules/vqvae/quantize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch import einsum
5 |
6 |
7 | class VectorQuantizer(nn.Module):
8 | """
9 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
10 | ____________________________________________
11 | Discretization bottleneck part of the VQ-VAE.
12 | Inputs:
13 | - n_e : number of embeddings
14 | - e_dim : dimension of embedding
15 | - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
16 | _____________________________________________
17 | """
18 |
19 | def __init__(self, n_e, e_dim, beta):
20 | super(VectorQuantizer, self).__init__()
21 | self.n_e = n_e
22 | self.e_dim = e_dim
23 | self.beta = beta
24 |
25 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
26 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
27 |
28 | def forward(self, z):
29 | """
30 | Inputs the output of the encoder network z and maps it to a discrete
31 | one-hot vector that is the index of the closest embedding vector e_j
32 | z (continuous) -> z_q (discrete)
33 | z.shape = (batch, channel, height, width)
34 | quantization pipeline:
35 | 1. get encoder input (B,C,H,W)
36 | 2. flatten input to (B*H*W,C)
37 | """
38 | # reshape z -> (batch, height, width, channel) and flatten
39 | z = z.permute(0, 2, 3, 1).contiguous()
40 | z_flattened = z.view(-1, self.e_dim)
41 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
42 |
43 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
44 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
45 | torch.matmul(z_flattened, self.embedding.weight.t())
46 |
47 | # find closest encodings
48 | min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
49 |
50 | min_encodings = torch.zeros(
51 | min_encoding_indices.shape[0], self.n_e).to(z)
52 | min_encodings.scatter_(1, min_encoding_indices, 1)
53 |
54 | # get quantized latent vectors
55 | z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
56 |
57 | # compute loss for embedding
58 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
59 | torch.mean((z_q - z.detach()) ** 2)
60 |
61 | # preserve gradients
62 | z_q = z + (z_q - z).detach()
63 |
64 | # perplexity
65 | e_mean = torch.mean(min_encodings, dim=0)
66 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
67 |
68 | # reshape back to match original input shape
69 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
70 |
71 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
72 |
73 | def get_codebook_entry(self, indices, shape):
74 | # shape specifying (batch, height, width, channel)
75 | min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
76 | min_encodings.scatter_(1, indices[:,None], 1)
77 |
78 | # get quantized latent vectors
79 | z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
80 |
81 | if shape is not None:
82 | z_q = z_q.view(shape)
83 |
84 | # reshape back to match original input shape
85 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
86 |
87 | return z_q
88 |
--------------------------------------------------------------------------------
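Note (illustrative, not part of the repo): the quantizer snaps each spatial latent to its nearest codebook entry and keeps gradients flowing to the encoder via the straight-through estimator (`z_q = z + (z_q - z).detach()`). A minimal sketch with assumed codebook size and latent shape:

import torch
from geofree.modules.vqvae.quantize import VectorQuantizer

quantizer = VectorQuantizer(n_e=1024, e_dim=256, beta=0.25)
z = torch.randn(2, 256, 13, 23, requires_grad=True)   # fake encoder output b,c,h,w
z_q, loss, (perplexity, _, indices) = quantizer(z)
print(z_q.shape, loss.item(), perplexity.item())
z_q.sum().backward()                                   # gradients reach z via straight-through
print(z.grad.shape)                                    # torch.Size([2, 256, 13, 23])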
/geofree/modules/warp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/geometry-free-view-synthesis/00dc639c98dfb9246bee0009649c5be8f8b58e1e/geofree/modules/warp/__init__.py
--------------------------------------------------------------------------------
/geofree/modules/warp/midas.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from splatting import splatting_function
8 |
9 | # pretty much like eval_re_midas, but in b,c,h,w format and as a module
10 | # which can be fine-tuned
11 |
12 | def render_forward(src_ims, src_dms,
13 | R, t,
14 | K_src_inv,
15 | K_dst,
16 | alpha=None):
17 | # R: b,3,3
18 | # t: b,3
19 | # K_dst: b,3,3
20 | # K_src_inv: b,3,3
21 | t = t[...,None]
22 |
23 | #######
24 |
25 | assert len(src_ims.shape) == 4 # b,c,h,w
26 | assert len(src_dms.shape) == 3 # b,h,w
27 | assert src_ims.shape[2:4] == src_dms.shape[1:3], (src_ims.shape,
28 | src_dms.shape)
29 |
30 | x = np.arange(src_ims.shape[3])
31 | y = np.arange(src_ims.shape[2])
32 | coord = np.stack(np.meshgrid(x,y), -1)
33 | coord = np.concatenate((coord, np.ones_like(coord)[:,:,[0]]), -1) # z=1
34 | coord = coord.astype(np.float32)
35 | coord = torch.as_tensor(coord, dtype=K_dst.dtype, device=K_dst.device)
36 | coord = coord[None] # b,h,w,3
37 |
38 | D = src_dms[:,:,:,None,None] # b,h,w,1,1
39 |
40 | points = K_dst[:,None,None,...]@(R[:,None,None,...]@(D*K_src_inv[:,None,None,...]@coord[:,:,:,:,None])+t[:,None,None,:,:])
41 | points = points.squeeze(-1)
42 |
43 | new_z = points[:,:,:,[2]].clone().permute(0,3,1,2) # b,1,h,w
44 | points = points/torch.clamp(points[:,:,:,[2]], 1e-8, None)
45 |
46 | flow = points - coord
47 | flow = flow.permute(0,3,1,2)[:,:2,...]
48 |
49 | if alpha is not None:
50 | # used to be 50 but this is unstable even if we subtract the maximum
51 | importance = alpha/new_z
52 | #importance = importance-importance.amin((1,2,3),keepdim=True)
53 | importance = importance.exp()
54 | else:
55 | # use heuristic to rescale importance to [-10, 0] so that exp stays
56 | # stable in float32
57 | importance = 1.0/new_z
58 | importance_min = importance.amin((1,2,3),keepdim=True)
59 | importance_max = importance.amax((1,2,3),keepdim=True)
60 | importance=(importance-importance_min)/(importance_max-importance_min+1e-6)*10-10
61 | importance = importance.exp()
62 |
63 | input_data = torch.cat([importance*src_ims, importance], 1)
64 | output_data = splatting_function("summation", input_data, flow)
65 |
66 | num = output_data[:,:-1,:,:]
67 | nom = output_data[:,-1:,:,:]
68 |
69 | #rendered = num/(nom+1e-7)
70 | rendered = num/nom.clamp(min=1e-8)
71 | return rendered
72 |
73 |
74 |
75 | class Midas(nn.Module):
76 | def __init__(self):
77 | super().__init__()
78 | midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
79 |
80 | self.midas = midas
81 |
82 | # parameters to reproduce the provided transform
83 | mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
84 | mean = mean.reshape(1,3,1,1)
85 | self.register_buffer("mean", mean)
86 | std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
87 | std = std.reshape(1,3,1,1)
88 | self.register_buffer("std", std)
89 | self.__height = 384
90 | self.__width = 384
91 | self.__keep_aspect_ratio = True
92 | self.__resize_method = "upper_bound"
93 | self.__multiple_of = 32
94 |
95 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
96 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
97 |
98 | if max_val is not None and y > max_val:
99 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
100 |
101 | if y < min_val:
102 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
103 |
104 | return y
105 |
106 | def get_size(self, width, height):
107 | # determine new height and width
108 | scale_height = self.__height / height
109 | scale_width = self.__width / width
110 |
111 | if self.__keep_aspect_ratio:
112 | if self.__resize_method == "lower_bound":
113 | # scale such that output size is lower bound
114 | if scale_width > scale_height:
115 | # fit width
116 | scale_height = scale_width
117 | else:
118 | # fit height
119 | scale_width = scale_height
120 | elif self.__resize_method == "upper_bound":
121 | # scale such that output size is upper bound
122 | if scale_width < scale_height:
123 | # fit width
124 | scale_height = scale_width
125 | else:
126 | # fit height
127 | scale_width = scale_height
128 | elif self.__resize_method == "minimal":
129 | # scale as little as possible
130 | if abs(1 - scale_width) < abs(1 - scale_height):
131 | # fit width
132 | scale_height = scale_width
133 | else:
134 | # fit height
135 | scale_width = scale_height
136 | else:
137 | raise ValueError(
138 | f"resize_method {self.__resize_method} not implemented"
139 | )
140 |
141 | if self.__resize_method == "lower_bound":
142 | new_height = self.constrain_to_multiple_of(
143 | scale_height * height, min_val=self.__height
144 | )
145 | new_width = self.constrain_to_multiple_of(
146 | scale_width * width, min_val=self.__width
147 | )
148 | elif self.__resize_method == "upper_bound":
149 | new_height = self.constrain_to_multiple_of(
150 | scale_height * height, max_val=self.__height
151 | )
152 | new_width = self.constrain_to_multiple_of(
153 | scale_width * width, max_val=self.__width
154 | )
155 | elif self.__resize_method == "minimal":
156 | new_height = self.constrain_to_multiple_of(scale_height * height)
157 | new_width = self.constrain_to_multiple_of(scale_width * width)
158 | else:
159 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
160 |
161 | return (new_width, new_height)
162 |
163 | def resize(self, x):
164 | assert len(x.shape)==4
165 | assert x.shape[1]==3, x.shape
166 |
167 | width, height = self.get_size(
168 | x.shape[3], x.shape[2]
169 | )
170 | x = torch.nn.functional.interpolate(
171 | x,
172 | size=(height,width),
173 | mode="bicubic",
174 | align_corners=False,
175 | )
176 | return x
177 |
178 | def __call__(self, x, clamp=True, out_size="original"):
179 | assert len(x.shape)==4, x.shape
180 | assert x.shape[1]==3, x.shape
181 | assert -1.0 <= x.min() <= x.max() <= 1.0
182 |
183 | # replace provided transform by differentiable one supporting batches
184 | if out_size == "original":
185 | out_size = x.shape[2:4]
186 | # to [0,1]
187 | x = (x+1.0)/2.0
188 | # resize
189 | x = self.resize(x)
190 | # normalize (x-mean)/std
191 | x = (x - self.mean)/self.std
192 | # prepare = transpose to (b,c,h,w)
193 |
194 | x = self.midas(x)
195 |
196 | if out_size is not None:
197 | x = torch.nn.functional.interpolate(
198 | x.unsqueeze(1),
199 | size=out_size,
200 | mode="bicubic",
201 | align_corners=False,
202 | ).squeeze(1)
203 |
204 | if clamp:
205 | # negative values due to resizing
206 | x = x.clamp(min=1e-8)
207 |
208 | return x
209 |
210 | def scaled_depth(self, x, points, return_inverse_depth=False):
211 | b,c,h,w = x.shape
212 | assert c==3, c
213 | b_,_,c_ = points.shape
214 | assert b==b_
215 | assert c_==3
216 |
217 | dm = self(x)
218 |
219 | xys = points[:,:,:2]
220 | xys = xys.round().to(dtype=torch.long)
221 | xys[:,:,0] = xys[:,:,0].clamp(min=0,max=w-1)
222 | xys[:,:,1] = xys[:,:,1].clamp(min=0,max=h-1)
223 | indices = xys[:,:,1]*w+xys[:,:,0] # b,N
224 | flatdm = torch.zeros(b,h*w, dtype=dm.dtype, device=dm.device)
225 | flatz = points[:,:,2] # b,N
226 | flatdm.scatter_(dim=1, index=indices, src=flatz)
227 | sparse_dm = flatdm.reshape(b,h,w)
228 | mask = (sparse_dm<1e-3).to(dtype=sparse_dm.dtype,
229 | device=sparse_dm.device)
230 |
231 | #error = (1-mask)*(dm[i]-(scale*dm+offset))**2
232 | N = (1-mask).sum(dim=(1,2)) # b
233 | m_sparse_dm = (1-mask)*sparse_dm # was mdm
234 | m_sparse_dm[(1-mask)>0] = 1.0/m_sparse_dm[(1-mask)>0] # we align disparity
235 | m_dm = (1-mask)*dm # was mmi
236 | s = ((m_dm*m_sparse_dm).sum(dim=(1,2))-1/N*m_sparse_dm.sum(dim=(1,2))*m_dm.sum(dim=(1,2))) / ((m_dm**2).sum(dim=(1,2))-1/N*(m_dm.sum(dim=(1,2))**2))
237 | c = 1/N*(m_sparse_dm.sum(dim=(1,2))-s*m_dm.sum(dim=(1,2)))
238 |
239 | scaled_dm = s[:,None,None]*dm + c[:,None,None]
240 | scaled_dm = scaled_dm.clamp(min=1e-8)
241 | if not return_inverse_depth:
242 | scaled_dm[scaled_dm!=0] = 1.0/scaled_dm[scaled_dm!=0] # disparity to depth
243 |
244 | return scaled_dm
245 |
246 | def fixed_scale_depth(self, x, return_inverse_depth=False, scale=[0.18577382, 0.93059154]):
247 | b,c,h,w = x.shape
248 | assert c==3, c
249 |
250 | dm = self(x)
251 | dmmin = dm.amin(dim=(1,2), keepdim=True)
252 | dmmax = dm.amax(dim=(1,2), keepdim=True)
253 | scaled_dm = (dm-dmmin)/(dmmax-dmmin)*(scale[1]-scale[0])+scale[0]
254 |
255 | if not return_inverse_depth:
256 | scaled_dm[scaled_dm!=0] = 1.0/scaled_dm[scaled_dm!=0] # disparity to depth
257 |
258 | return scaled_dm
259 |
260 | def warp(self, x, points, R, t, K_src_inv, K_dst):
261 | src_dms = self.scaled_depth(x, points)
262 | wrp = render_forward(src_ims=x, src_dms=src_dms,
263 | R=R, t=t,
264 | K_src_inv=K_src_inv, K_dst=K_dst)
265 | return wrp, src_dms
266 |
267 | def warp_features(self, f, x, points, R, t, K_src_inv, K_dst,
268 | no_depth_grad=False, boltzmann_factor=None):
269 | b,c,h,w = f.shape
270 |
271 | context = torch.no_grad() if no_depth_grad else nullcontext()
272 | with context:
273 | src_dms = self.scaled_depth(x, points)
274 |
275 | # rescale depth map to feature map size
276 | src_dms = torch.nn.functional.interpolate(
277 | src_dms.unsqueeze(1),
278 | size=(h,w),
279 | mode="bicubic",
280 | align_corners=False,
281 | ).squeeze(1)
282 |
283 |
284 | # rescale intrinsics to feature map size
285 | K_dst = K_dst.clone()
286 | K_dst[:,0,:] *= f.shape[3]/x.shape[3]
287 | K_dst[:,1,:] *= f.shape[2]/x.shape[2]
288 | K_src_inv = K_src_inv.clone()
289 | K_src_inv[:,0,0] /= f.shape[3]/x.shape[3]
290 | K_src_inv[:,1,1] /= f.shape[2]/x.shape[2]
291 |
292 | wrp = render_forward(src_ims=f, src_dms=src_dms,
293 | R=R, t=t,
294 | K_src_inv=K_src_inv, K_dst=K_dst,
295 | alpha=boltzmann_factor)
296 | return wrp, src_dms
297 |
--------------------------------------------------------------------------------
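Note (synthetic check, not part of the repo): `scaled_depth` aligns the MiDaS prediction to the sparse COLMAP points with a closed-form least-squares scale `s` and shift `c`. The same sums, run on fake data where the true relation is known:

import torch

d_pred = torch.rand(1, 13, 23)                   # fake predicted inverse depth
d_sparse = 0.7 * d_pred + 0.1                    # fake sparse targets: s=0.7, c=0.1
m = (torch.rand_like(d_pred) > 0.5).float()      # validity mask (1 = observed)

N = m.sum(dim=(1, 2))
md, mt = m * d_pred, m * d_sparse
s = ((md * mt).sum((1, 2)) - mt.sum((1, 2)) * md.sum((1, 2)) / N) / \
    ((md ** 2).sum((1, 2)) - md.sum((1, 2)) ** 2 / N)
c = (mt.sum((1, 2)) - s * md.sum((1, 2))) / N
print(s.item(), c.item())                        # recovers ~0.7 and ~0.1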
/geofree/util.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import requests
3 | from tqdm import tqdm
4 | import torch
5 | import os
6 | from omegaconf import OmegaConf
7 |
8 | from geofree.models.transformers.geogpt import GeoTransformer
9 |
10 |
11 | URL_MAP = {
12 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1",
13 | "re_first_stage": "https://heibox.uni-heidelberg.de/f/6db3e2beebe34c1e8d06/?dl=1",
14 | "re_depth_stage": "https://heibox.uni-heidelberg.de/f/c5e51b91377942d7be18/?dl=1",
15 | "re_impl_depth_config": "https://heibox.uni-heidelberg.de/f/234e2e21a6414690b663/?dl=1",
16 | "re_impl_depth": "https://heibox.uni-heidelberg.de/f/740100d4cdbd46d4bb39/?dl=1",
17 | "re_impl_nodepth_config": "https://heibox.uni-heidelberg.de/f/21989b42cea544bbbb9e/?dl=1",
18 | "re_impl_nodepth": "https://heibox.uni-heidelberg.de/f/b909c3d4ac7143209387/?dl=1",
19 | "ac_first_stage": "https://heibox.uni-heidelberg.de/f/f21118f64fde4134917b/?dl=1",
20 | "ac_depth_stage": "https://heibox.uni-heidelberg.de/f/fd4bd78496884990b75e/?dl=1",
21 | "ac_impl_depth_config": "https://heibox.uni-heidelberg.de/f/6e3081df54e245cd9aaa/?dl=1",
22 | "ac_impl_depth": "https://heibox.uni-heidelberg.de/f/6e3081df54e245cd9aaa/?dl=1",
23 | "ac_impl_nodepth_config": "https://heibox.uni-heidelberg.de/f/0a2749e895784c0099a4/?dl=1",
24 | "ac_impl_nodepth": "https://heibox.uni-heidelberg.de/f/bead9082ed0c425fb6b4/?dl=1",
25 | }
26 |
27 | CKPT_MAP = {
28 | "vgg_lpips": "geofree/lpips/vgg.pth",
29 | "re_first_stage": "geofree/re_first_stage/last.ckpt",
30 | "re_depth_stage": "geofree/re_depth_stage/last.ckpt",
31 | "re_impl_depth_config": "geofree/re_impl_depth/config.yaml",
32 | "re_impl_depth": "geofree/re_impl_depth/last.ckpt",
33 | "re_impl_nodepth_config": "geofree/re_impl_nodepth/config.yaml",
34 | "re_impl_nodepth": "geofree/re_impl_nodepth/last.ckpt",
35 | "ac_first_stage": "geofree/ac_first_stage/last.ckpt",
36 | "ac_depth_stage": "geofree/ac_depth_stage/last.ckpt",
37 | "ac_impl_depth_config": "geofree/ac_impl_depth/config.yaml",
38 | "ac_impl_depth": "geofree/ac_impl_depth/last.ckpt",
39 | "ac_impl_nodepth_config": "geofree/ac_impl_nodepth/config.yaml",
40 | "ac_impl_nodepth": "geofree/ac_impl_nodepth/last.ckpt",
41 | }
42 | CACHE = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
43 | CKPT_MAP = dict((k, os.path.join(CACHE, CKPT_MAP[k])) for k in CKPT_MAP)
44 |
45 | MD5_MAP = {
46 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a",
47 | "re_first_stage": "b8b999aba6b618757329c1f114e9f5a5",
48 | "re_depth_stage": "ab35861b9050476d02fa8f4c8f761d61",
49 | "re_impl_depth_config": "e75df96667a3a6022ca2d5b27515324b",
50 | "re_impl_depth": "144dcdeb1379760d2f1ae763cabcac85",
51 | "re_impl_nodepth_config": "351c976463c4740fc575a7c64a836624",
52 | "re_impl_nodepth": "b6646db26a756b80840aa2b6aca924d8",
53 | "ac_first_stage": "657f683698477be24254b2be85fbced0",
54 | "ac_depth_stage": "e75ef829cf72be4f5f252450bdeb10db",
55 | "ac_impl_depth_config": "79b38a57fe4a195165a79c795492092c",
56 | "ac_impl_depth": "e92673e24acb28977d74a2fc106bcfb1",
57 | "ac_impl_nodepth_config": "22e1a55122cef561597cc5c040d45fba",
58 | "ac_impl_nodepth": "22e1a55122cef561597cc5c040d45fba",
59 | }
60 |
61 |
62 | def download(url, local_path, chunk_size=1024):
63 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
64 | with requests.get(url, stream=True) as r:
65 | total_size = int(r.headers.get("content-length", 0))
66 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
67 | with open(local_path, "wb") as f:
68 | for data in r.iter_content(chunk_size=chunk_size):
69 | if data:
70 | f.write(data)
71 | pbar.update(chunk_size)
72 |
73 |
74 | def md5_hash(path):
75 | with open(path, "rb") as f:
76 | content = f.read()
77 | return hashlib.md5(content).hexdigest()
78 |
79 |
80 | def get_local_path(name, root=None, check=False):
81 | path = CKPT_MAP[name]
82 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
83 | assert name in URL_MAP, name
84 | print("Downloading {} from {} to {}".format(name, URL_MAP[name], path))
85 | download(URL_MAP[name], path)
86 | md5 = md5_hash(path)
87 | assert md5 == MD5_MAP[name], md5
88 | return path
89 |
90 |
91 | def pretrained_models(model="re_impl_nodepth"):
92 | prefix = model[:2]
93 | assert prefix in ["re", "ac"], "not implemented"
94 |
95 | config_path = get_local_path(model+"_config")
96 | config = OmegaConf.load(config_path)
97 | config.model.params.first_stage_config.params["ckpt_path"] = get_local_path(f"{prefix}_first_stage")
98 | config.model.params.depth_stage_config.params["ckpt_path"] = get_local_path(f"{prefix}_depth_stage")
99 |
100 | ckpt_path = get_local_path(model)
101 |
102 | model = GeoTransformer(**config.model.params)
103 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
104 | missing, unexpected = model.load_state_dict(sd, strict=False)
105 | model = model.eval()
106 | print(f"Restored model from {ckpt_path}")
107 | return model
108 |
--------------------------------------------------------------------------------
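Note (usage sketch): `pretrained_models` downloads the config, the first/depth stage and the transformer checkpoint into `$XDG_CACHE_HOME` (or `~/.cache`) and restores a ready-to-use `GeoTransformer`; this mirrors how `scripts/braindance.py` loads its model.

import torch
from geofree import pretrained_models

model = pretrained_models(model="re_impl_nodepth")
model = model.to("cuda" if torch.cuda.is_available() else "cpu")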
/scripts/braindance.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import sys, os
3 | import argparse
4 | import pygame
5 | import torch
6 | import numpy as np
7 | from PIL import Image
8 | from splatting import splatting_function
9 | from torch.utils.data.dataloader import default_collate
10 | from geofree import pretrained_models
11 | import imageio
12 |
13 | torch.set_grad_enabled(False)
14 |
15 | from geofree.modules.warp.midas import Midas
16 |
17 | from tkinter.filedialog import askopenfilename
18 | from tkinter import Tk
19 |
20 |
21 | def to_surface(x, text=None):
22 | if hasattr(x, "detach"):
23 | x = x.detach().cpu().numpy()
24 | x = x.transpose(1,0,2)
25 | x = (x+1.0)*127.5
26 | x = x.clip(0, 255).astype(np.uint8)
27 | if text is not None:
28 | from PIL import ImageDraw, ImageFont
29 | fontsize=22
30 | try:
31 | font = ImageFont.truetype("/usr/share/fonts/liberation/LiberationSans-BoldItalic.ttf", fontsize)
32 | except OSError:
33 | font = ImageFont.load_default()
34 | margin = 8
35 |
36 | x = x.transpose(1,0,2)
37 | pos = (margin, x.shape[0]-fontsize-margin//2)
38 | x = x.astype(float)
39 | x[x.shape[0]-fontsize-margin:,:,:] *= 0.5
40 | x = x.astype(np.uint8)
41 |
42 | img = Image.fromarray(x)
43 | ImageDraw.Draw(img).text(pos, f'{text}', (255, 255, 255), font=font) # coordinates, text, color, font
44 | x = np.array(img)
45 | x = x.transpose(1,0,2)
46 | return pygame.surfarray.make_surface(x), x.transpose(1,0,2)
47 |
48 |
49 | def render_forward(src_ims, src_dms,
50 | Rcam, tcam,
51 | K_src,
52 | K_dst):
53 | Rcam = Rcam.to(device=src_ims.device)[None]
54 | tcam = tcam.to(device=src_ims.device)[None]
55 |
56 | R = Rcam
57 | t = tcam[...,None]
58 | K_src_inv = K_src.inverse()
59 |
60 | assert len(src_ims.shape) == 4
61 | assert len(src_dms.shape) == 3
62 | assert src_ims.shape[1:3] == src_dms.shape[1:3], (src_ims.shape,
63 | src_dms.shape)
64 |
65 | x = np.arange(src_ims[0].shape[1])
66 | y = np.arange(src_ims[0].shape[0])
67 | coord = np.stack(np.meshgrid(x,y), -1)
68 | coord = np.concatenate((coord, np.ones_like(coord)[:,:,[0]]), -1) # z=1
69 | coord = coord.astype(np.float32)
70 | coord = torch.as_tensor(coord, dtype=K_src.dtype, device=K_src.device)
71 | coord = coord[None] # bs, h, w, 3
72 |
73 | D = src_dms[:,:,:,None,None]
74 |
75 | points = K_dst[None,None,None,...]@(R[:,None,None,...]@(D*K_src_inv[None,None,None,...]@coord[:,:,:,:,None])+t[:,None,None,:,:])
76 | points = points.squeeze(-1)
77 |
78 | new_z = points[:,:,:,[2]].clone().permute(0,3,1,2) # b,1,h,w
79 | points = points/torch.clamp(points[:,:,:,[2]], 1e-8, None)
80 |
81 | src_ims = src_ims.permute(0,3,1,2)
82 | flow = points - coord
83 | flow = flow.permute(0,3,1,2)[:,:2,...]
84 |
85 | alpha = 0.5
86 | importance = alpha/new_z
87 | importance_min = importance.amin((1,2,3),keepdim=True)
88 | importance_max = importance.amax((1,2,3),keepdim=True)
89 | importance=(importance-importance_min)/(importance_max-importance_min+1e-6)*10-10
90 | importance = importance.exp()
91 |
92 | input_data = torch.cat([importance*src_ims, importance], 1)
93 | output_data = splatting_function("summation", input_data, flow)
94 |
95 | num = torch.sum(output_data[:,:-1,:,:], dim=0, keepdim=True)
96 | nom = torch.sum(output_data[:,-1:,:,:], dim=0, keepdim=True)
97 |
98 | rendered = num/(nom+1e-7)
99 | rendered = rendered.permute(0,2,3,1)[0,...]
100 | return rendered
101 |
102 | def normalize(x):
103 | return x/np.linalg.norm(x)
104 |
105 | def cosd(x):
106 | return np.cos(np.deg2rad(x))
107 |
108 | def sind(x):
109 | return np.sin(np.deg2rad(x))
110 |
111 | def look_to(camera_pos, camera_dir, camera_up):
112 | camera_right = normalize(np.cross(camera_up, camera_dir))
113 | R = np.zeros((4, 4))
114 | R[0,0:3] = normalize(camera_right)
115 | R[1,0:3] = normalize(np.cross(camera_dir, camera_right))
116 | R[2,0:3] = normalize(camera_dir)
117 | R[3,3] = 1
118 | trans_matrix = np.array([[1.0, 0.0, 0.0, -camera_pos[0]],
119 | [0.0, 1.0, 0.0, -camera_pos[1]],
120 | [0.0, 0.0, 1.0, -camera_pos[2]],
121 | [0.0, 0.0, 0.0, 1.0]])
122 | tmp = R@trans_matrix
123 | return tmp[:3,:3], tmp[:3,3]
124 |
125 | def rotate_around_axis(angle, axis):
126 | axis = normalize(axis)
127 | rotation = np.array([[cosd(angle)+axis[0]**2*(1-cosd(angle)),
128 | axis[0]*axis[1]*(1-cosd(angle))-axis[2]*sind(angle),
129 | axis[0]*axis[2]*(1-cosd(angle))+axis[1]*sind(angle)],
130 | [axis[1]*axis[0]*(1-cosd(angle))+axis[2]*sind(angle),
131 | cosd(angle)+axis[1]**2*(1-cosd(angle)),
132 | axis[1]*axis[2]*(1-cosd(angle))-axis[0]*sind(angle)],
133 | [axis[2]*axis[0]*(1-cosd(angle))-axis[1]*sind(angle),
134 | axis[2]*axis[1]*(1-cosd(angle))+axis[0]*sind(angle),
135 | cosd(angle)+axis[2]**2*(1-cosd(angle))]])
136 | return rotation
137 |
138 |
139 | class Renderer(object):
140 | def __init__(self, model, device):
141 | self.model = pretrained_models(model=model)
142 | self.model = self.model.to(device=device)
143 | self._active = False
144 | # rough estimates for min and maximum inverse depth values on the
145 | # training datasets
146 | if model.startswith("re"):
147 | self.scale = [0.18577382, 0.93059154]
148 | else:
149 | self.scale = [1e-8, 0.75]
150 |
151 | def init(self,
152 | start_im,
153 | example,
154 | show_R,
155 | show_t):
156 | self._active = True
157 | self.step = 0
158 |
159 | batch = self.batch = default_collate([example])
160 | batch["R_rel"] = show_R[None,...]
161 | batch["t_rel"] = show_t[None,...]
162 |
163 | _, cdict, edict = self.model.get_xce(batch)
164 | for k in cdict:
165 | cdict[k] = cdict[k].to(device=self.model.device)
166 | for k in edict:
167 | edict[k] = edict[k].to(device=self.model.device)
168 |
169 | quant_d, quant_c, dc_indices, embeddings = self.model.get_normalized_c(
170 | cdict, edict, fixed_scale=True, scale=self.scale)
171 |
172 | start_im = start_im[None,...].to(self.model.device).permute(0,3,1,2)
173 | quant_c, c_indices = self.model.encode_to_c(c=start_im)
174 | cond_rec = self.model.cond_stage_model.decode(quant_c)
175 |
176 | self.current_im = cond_rec.permute(0,2,3,1)[0]
177 | self.current_sample = c_indices
178 |
179 | self.quant_c = quant_c # to know shape
180 | # for sampling
181 | self.dc_indices = dc_indices
182 | self.embeddings = embeddings
183 |
184 | def __call__(self):
185 | if self.step < self.current_sample.shape[1]:
186 | z_start_indices = self.current_sample[:, :self.step]
187 | temperature=None
188 | top_k=250
189 | callback=None
190 | index_sample = self.model.sample(z_start_indices, self.dc_indices,
191 | steps=1,
192 | temperature=temperature if temperature is not None else 1.0,
193 | sample=True,
194 | top_k=top_k if top_k is not None else 100,
195 | callback=callback if callback is not None else lambda k: None,
196 | embeddings=self.embeddings)
197 | self.current_sample = torch.cat((index_sample,
198 | self.current_sample[:,self.step+1:]),
199 | dim=1)
200 |
201 | sample_dec = self.model.decode_to_img(self.current_sample,
202 | self.quant_c.shape)
203 | self.current_im = sample_dec.permute(0,2,3,1)[0]
204 | self.step += 1
205 |
206 | if self.step >= self.current_sample.shape[1]:
207 | self._active = False
208 |
209 | return self.current_im
210 |
211 | def active(self):
212 | return self._active
213 |
214 | def reconstruct(self, x):
215 | x = x.to(self.model.device).permute(0,3,1,2)
216 | quant_c, c_indices = self.model.encode_to_c(c=x)
217 | x_rec = self.model.cond_stage_model.decode(quant_c)
218 | return x_rec.permute(0,2,3,1)
219 |
220 |
221 | def load_as_example(path, model="re"):
222 | size = [208, 368]
223 | im = Image.open(path)
224 | w,h = im.size
225 | if np.abs(w/h - size[1]/size[0]) > 0.1:
226 | print(f"Center cropping {path} to AR {size[1]/size[0]}")
227 | if w/h < size[1]/size[0]:
228 | # crop h
229 | left = 0
230 | right = w
231 | top = h/2 - size[0]/size[1]*w/2
232 | bottom = h/2 + size[0]/size[1]*w/2
233 | else:
234 | # crop w
235 | top = 0
236 | bottom = h
237 | left = w/2 - size[1]/size[0]*h/2
238 | right = w/2 + size[1]/size[0]*h/2
239 | im = im.crop(box=(left, top, right, bottom))
240 |
241 | im = im.resize((size[1],size[0]),
242 | resample=Image.LANCZOS)
243 | im = np.array(im)/127.5-1.0
244 | im = im.astype(np.float32)
245 |
246 | example = dict()
247 | example["src_img"] = im
248 | if model.startswith("re"):
249 | example["K"] = np.array([[184.0, 0.0, 184.0],
250 | [0.0, 184.0, 104.0],
251 | [0.0, 0.0, 1.0]], dtype=np.float32)
252 | elif model.startswith("ac"):
253 | example["K"] = np.array([[200.0, 0.0, 184.0],
254 | [0.0, 200.0, 104.0],
255 | [0.0, 0.0, 1.0]], dtype=np.float32)
256 | else:
257 | raise NotImplementedError()
258 | example["K_inv"] = np.linalg.inv(example["K"])
259 |
260 | ## dummy data not used during inference
261 | example["dst_img"] = np.zeros_like(example["src_img"])
262 | example["src_points"] = np.zeros((1,3), dtype=np.float32)
263 |
264 | return example
265 |
266 |
267 | if __name__ == "__main__":
268 | helptxt = "What's up, BD-maniacs?\n\n"+"\n".join([
269 | "{: <12} {: <24}".format("key(s)", "action"),
270 | "="*37,
271 | "{: <12} {: <24}".format("wasd", "move around"),
272 | "{: <12} {: <24}".format("arrows", "look around"),
273 | "{: <12} {: <24}".format("m", "enable looking with mouse"),
274 | "{: <12} {: <24}".format("space", "render with transformer"),
275 | "{: <12} {: <24}".format("q", "quit"),
276 | ])
277 | parser = argparse.ArgumentParser(description=helptxt,
278 | formatter_class=argparse.RawTextHelpFormatter)
279 | parser.add_argument('path', type=str, nargs='?', default=None,
280 | help='path to image or directory from which to select '
281 | 'image. Default example is used if not specified.')
282 | parser.add_argument('--model', choices=["re_impl_nodepth", "re_impl_depth",
283 | "ac_impl_nodepth", "ac_impl_depth"],
284 | default="re_impl_nodepth",
285 | help='pretrained model to use.')
286 | parser.add_argument('--video', type=str, nargs='?', default=None,
287 | help='path to write video recording to. (no recording if unspecified).')
288 | opt = parser.parse_args()
289 | print(helptxt)
290 |
291 | if torch.cuda.is_available():
292 | device = torch.device("cuda")
293 | else:
294 | print("Warning: Running on CPU---sampling might take a while...")
295 | device = torch.device("cpu")
296 | midas = Midas().eval().to(device)
297 | # init transformer
298 | renderer = Renderer(model=opt.model, device=device)
299 |
300 | if opt.path is None:
301 | try:
302 | import importlib.resources as pkg_resources
303 | except ImportError:
304 | import importlib_resources as pkg_resources
305 |
306 | example_name = "artist.jpg" if opt.model.startswith("re") else "beach.jpg"
307 | with pkg_resources.path("geofree.examples", example_name) as path:
308 | example = load_as_example(path, model=opt.model)
309 | else:
310 | path = opt.path
311 | if not os.path.isfile(path):
312 | Tk().withdraw()
313 | path = askopenfilename(initialdir=path)
314 | example = load_as_example(path, model=opt.model)
315 |
316 | ims = example["src_img"][None,...]
317 | K = example["K"]
318 |
319 | # compute depth for preview
320 | dms = [None]
321 | for i in range(ims.shape[0]):
322 | midas_in = torch.tensor(ims[i])[None,...].permute(0,3,1,2).to(device)
323 | scaled_idepth = midas.fixed_scale_depth(midas_in,
324 | return_inverse_depth=True,
325 | scale=renderer.scale)
326 | dms[i] = 1.0/scaled_idepth[0].cpu().numpy()
327 |
328 | # now switch to pytorch
329 | src_ims = torch.tensor(ims, dtype=torch.float32)
330 | src_dms = torch.tensor(dms, dtype=torch.float32)
331 | K = torch.tensor(K, dtype=torch.float32)
332 |
333 | src_ims = src_ims.to(device=device)
334 | src_dms = src_dms.to(device=device)
335 | K = K.to(device=device)
336 |
337 | K_cam = K.clone().detach()
338 |
339 | RENDERING = False
340 | DISPLAY_REC = True
341 | if DISPLAY_REC:
342 | rec_ims = renderer.reconstruct(src_ims)
343 |
344 | # init pygame
345 | b,h,w,c = src_ims.shape
346 | pygame.init()
347 | display = (w, h)
348 | surface = pygame.display.set_mode(display)
349 | clock = pygame.time.Clock()
350 |
351 | # init camera
352 | camera_pos = np.array([0.0, 0.0, 0.0])
353 | camera_dir = np.array([0.0, 0.0, 1.0])
354 | camera_up = np.array([0.0, 1.0, 0.0])
355 | CAM_SPEED = 0.025
356 | CAM_SPEED_YAW = 0.5
357 | CAM_SPEED_PITCH = 0.25
358 | MOUSE_SENSITIVITY = 0.02
359 | USE_MOUSE = False
360 | if opt.model.startswith("ac"):
361 | CAM_SPEED *= 0.1
362 | CAM_SPEED_YAW *= 0.5
363 | CAM_SPEED_PITCH *= 0.5
364 |
365 | if opt.video is not None:
366 | writer = imageio.get_writer(opt.video, fps=40)
367 |
368 | step = 0
369 | step_PHASE = 0
370 | while True:
371 | ######## Boring stuff
372 | clock.tick(40)
373 | for event in pygame.event.get():
374 | if event.type == pygame.QUIT:
375 | pygame.quit()
376 | quit()
377 |
378 | keys = pygame.key.get_pressed()
379 | if keys[pygame.K_q]:
380 | if opt.video is not None:
381 | writer.close()
382 | pygame.quit()
383 | quit()
384 |
385 | ######### Camera
386 | camera_yaw = 0
387 | camera_pitch = 0
388 | if keys[pygame.K_a]:
389 | camera_pos += CAM_SPEED*normalize(np.cross(camera_dir, camera_up))
390 | if keys[pygame.K_d]:
391 | camera_pos -= CAM_SPEED*normalize(np.cross(camera_dir, camera_up))
392 | if keys[pygame.K_w]:
393 | camera_pos += CAM_SPEED*normalize(camera_dir)
394 | if keys[pygame.K_s]:
395 | camera_pos -= CAM_SPEED*normalize(camera_dir)
396 | if keys[pygame.K_PAGEUP]:
397 | camera_pos -= CAM_SPEED*normalize(camera_up)
398 | if keys[pygame.K_PAGEDOWN]:
399 | camera_pos += CAM_SPEED*normalize(camera_up)
400 |
401 | if keys[pygame.K_LEFT]:
402 | camera_yaw += CAM_SPEED_YAW
403 | if keys[pygame.K_RIGHT]:
404 | camera_yaw -= CAM_SPEED_YAW
405 | if keys[pygame.K_UP]:
406 | camera_pitch -= CAM_SPEED_PITCH
407 | if keys[pygame.K_DOWN]:
408 | camera_pitch += CAM_SPEED_PITCH
409 |
410 | if USE_MOUSE:
411 | dx, dy = pygame.mouse.get_rel()
412 | if not RENDERING:
413 | camera_yaw -= MOUSE_SENSITIVITY*dx
414 | camera_pitch += MOUSE_SENSITIVITY*dy
415 |
416 | if keys[pygame.K_PLUS]:
417 | CAM_SPEED += 0.1
418 | print(CAM_SPEED)
419 | if keys[pygame.K_MINUS]:
420 | CAM_SPEED -= 0.1
421 | print(CAM_SPEED)
422 |
423 | if keys[pygame.K_m]:
424 | if not USE_MOUSE:
425 | pygame.mouse.set_visible(False)
426 | pygame.event.set_grab(True)
427 | USE_MOUSE = True
428 | else:
429 | pygame.mouse.set_visible(True)
430 | pygame.event.set_grab(False)
431 | USE_MOUSE = False
432 |
433 | # adjust for yaw and pitch
434 | rotation = np.array([[cosd(-camera_yaw), 0.0, sind(-camera_yaw)],
435 | [0.0, 1.0, 0.0],
436 | [-sind(-camera_yaw), 0.0, cosd(-camera_yaw)]])
437 | camera_dir = rotation@camera_dir
438 |
439 | rotation = rotate_around_axis(camera_pitch, np.cross(camera_dir,
440 | camera_up))
441 | camera_dir = rotation@camera_dir
442 |
443 | show_R, show_t = look_to(camera_pos, camera_dir, camera_up) # look from pos in direction dir
444 | show_R = torch.as_tensor(show_R, dtype=torch.float32)
445 | show_t = torch.as_tensor(show_t, dtype=torch.float32)
446 |
447 | ############# /Camera
448 | ###### control rendering
449 | if keys[pygame.K_SPACE]:
450 | RENDERING = True
451 | renderer.init(wrp_im, example, show_R, show_t)
452 |
453 | PRESSED = False
454 | if any(keys[k] for k in [pygame.K_a, pygame.K_d, pygame.K_w,
455 | pygame.K_s]):
456 | RENDERING = False
457 |
458 | # display
459 | if not RENDERING:
460 | with torch.no_grad():
461 | wrp_im = render_forward(src_ims, src_dms,
462 | show_R, show_t,
463 | K_src=K,
464 | K_dst=K_cam)
465 | else:
466 | with torch.no_grad():
467 | wrp_im = renderer()
468 |
469 | text = "Sampling" if renderer._active else None
470 | image, frame = to_surface(wrp_im, text)
471 | surface.blit(image, (0,0))
472 | pygame.display.flip()
473 | if opt.video is not None:
474 | writer.append_data(frame)
475 |
476 | step +=1
477 |
--------------------------------------------------------------------------------
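Note (illustrative sketch, with the helpers above in scope): the camera controls boil down to rotating the viewing direction and re-deriving the relative pose with `look_to`, which `Renderer.init` then writes into `batch["R_rel"]`/`batch["t_rel"]`:

import numpy as np

camera_pos = np.array([0.0, 0.0, 0.0])
camera_dir = np.array([0.0, 0.0, 1.0])
camera_up = np.array([0.0, 1.0, 0.0])

camera_dir = rotate_around_axis(10.0, camera_up) @ camera_dir   # yaw 10 degrees
show_R, show_t = look_to(camera_pos, camera_dir, camera_up)     # world-to-camera pose
print(show_R.shape, show_t.shape)                               # (3, 3) (3,)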
/scripts/database.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright
8 | # notice, this list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 | # its contributors may be used to endorse or promote products derived
16 | # from this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 | # POSSIBILITY OF SUCH DAMAGE.
29 | #
30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31 |
32 | # This script is based on an original implementation by True Price.
33 |
34 | import sys
35 | import sqlite3
36 | import numpy as np
37 |
38 |
39 | IS_PYTHON3 = sys.version_info[0] >= 3
40 |
41 | MAX_IMAGE_ID = 2**31 - 1
42 |
43 | CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
44 | camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
45 | model INTEGER NOT NULL,
46 | width INTEGER NOT NULL,
47 | height INTEGER NOT NULL,
48 | params BLOB,
49 | prior_focal_length INTEGER NOT NULL)"""
50 |
51 | CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
52 | image_id INTEGER PRIMARY KEY NOT NULL,
53 | rows INTEGER NOT NULL,
54 | cols INTEGER NOT NULL,
55 | data BLOB,
56 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
57 |
58 | CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
59 | image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
60 | name TEXT NOT NULL UNIQUE,
61 | camera_id INTEGER NOT NULL,
62 | prior_qw REAL,
63 | prior_qx REAL,
64 | prior_qy REAL,
65 | prior_qz REAL,
66 | prior_tx REAL,
67 | prior_ty REAL,
68 | prior_tz REAL,
69 | CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
70 | FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
71 | """.format(MAX_IMAGE_ID)
72 |
73 | CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
74 | CREATE TABLE IF NOT EXISTS two_view_geometries (
75 | pair_id INTEGER PRIMARY KEY NOT NULL,
76 | rows INTEGER NOT NULL,
77 | cols INTEGER NOT NULL,
78 | data BLOB,
79 | config INTEGER NOT NULL,
80 | F BLOB,
81 | E BLOB,
82 | H BLOB)
83 | """
84 |
85 | CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
86 | image_id INTEGER PRIMARY KEY NOT NULL,
87 | rows INTEGER NOT NULL,
88 | cols INTEGER NOT NULL,
89 | data BLOB,
90 | FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
91 | """
92 |
93 | CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
94 | pair_id INTEGER PRIMARY KEY NOT NULL,
95 | rows INTEGER NOT NULL,
96 | cols INTEGER NOT NULL,
97 | data BLOB)"""
98 |
99 | CREATE_NAME_INDEX = \
100 | "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
101 |
102 | CREATE_ALL = "; ".join([
103 | CREATE_CAMERAS_TABLE,
104 | CREATE_IMAGES_TABLE,
105 | CREATE_KEYPOINTS_TABLE,
106 | CREATE_DESCRIPTORS_TABLE,
107 | CREATE_MATCHES_TABLE,
108 | CREATE_TWO_VIEW_GEOMETRIES_TABLE,
109 | CREATE_NAME_INDEX
110 | ])
111 |
112 |
113 | def image_ids_to_pair_id(image_id1, image_id2):
114 | if image_id1 > image_id2:
115 | image_id1, image_id2 = image_id2, image_id1
116 | return image_id1 * MAX_IMAGE_ID + image_id2
117 |
118 |
119 | def pair_id_to_image_ids(pair_id):
120 | image_id2 = pair_id % MAX_IMAGE_ID
121 | image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID
122 | return image_id1, image_id2
123 |
124 |
125 | def array_to_blob(array):
126 | if IS_PYTHON3:
127 | return array.tobytes()
128 | else:
129 | return np.getbuffer(array)
130 |
131 |
132 | def blob_to_array(blob, dtype, shape=(-1,)):
133 | if IS_PYTHON3:
134 | return np.frombuffer(blob, dtype=dtype).reshape(*shape)
135 | else:
136 | return np.frombuffer(blob, dtype=dtype).reshape(*shape)
137 |
138 |
139 | class COLMAPDatabase(sqlite3.Connection):
140 |
141 | @staticmethod
142 | def connect(database_path):
143 | return sqlite3.connect(database_path, factory=COLMAPDatabase)
144 |
145 |
146 | def __init__(self, *args, **kwargs):
147 | super(COLMAPDatabase, self).__init__(*args, **kwargs)
148 |
149 | self.create_tables = lambda: self.executescript(CREATE_ALL)
150 | self.create_cameras_table = \
151 | lambda: self.executescript(CREATE_CAMERAS_TABLE)
152 | self.create_descriptors_table = \
153 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
154 | self.create_images_table = \
155 | lambda: self.executescript(CREATE_IMAGES_TABLE)
156 | self.create_two_view_geometries_table = \
157 | lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
158 | self.create_keypoints_table = \
159 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
160 | self.create_matches_table = \
161 | lambda: self.executescript(CREATE_MATCHES_TABLE)
162 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
163 |
164 | def add_camera(self, model, width, height, params,
165 | prior_focal_length=False, camera_id=None):
166 | params = np.asarray(params, np.float64)
167 | cursor = self.execute(
168 | "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
169 | (camera_id, model, width, height, array_to_blob(params),
170 | prior_focal_length))
171 | return cursor.lastrowid
172 |
173 | def add_image(self, name, camera_id,
174 | prior_q=np.full(4, np.nan), prior_t=np.full(3, np.nan), image_id=None):
175 | cursor = self.execute(
176 | "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
177 | (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
178 | prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
179 | return cursor.lastrowid
180 |
181 | def add_keypoints(self, image_id, keypoints):
182 | assert(len(keypoints.shape) == 2)
183 | assert(keypoints.shape[1] in [2, 4, 6])
184 |
185 | keypoints = np.asarray(keypoints, np.float32)
186 | self.execute(
187 | "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
188 | (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
189 |
190 | def add_descriptors(self, image_id, descriptors):
191 | descriptors = np.ascontiguousarray(descriptors, np.uint8)
192 | self.execute(
193 | "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
194 | (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
195 |
196 | def add_matches(self, image_id1, image_id2, matches):
197 | assert(len(matches.shape) == 2)
198 | assert(matches.shape[1] == 2)
199 |
200 | if image_id1 > image_id2:
201 | matches = matches[:,::-1]
202 |
203 | pair_id = image_ids_to_pair_id(image_id1, image_id2)
204 | matches = np.asarray(matches, np.uint32)
205 | self.execute(
206 | "INSERT INTO matches VALUES (?, ?, ?, ?)",
207 | (pair_id,) + matches.shape + (array_to_blob(matches),))
208 |
209 | def add_two_view_geometry(self, image_id1, image_id2, matches,
210 | F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2):
211 | assert(len(matches.shape) == 2)
212 | assert(matches.shape[1] == 2)
213 |
214 | if image_id1 > image_id2:
215 | matches = matches[:,::-1]
216 |
217 | pair_id = image_ids_to_pair_id(image_id1, image_id2)
218 | matches = np.asarray(matches, np.uint32)
219 | F = np.asarray(F, dtype=np.float64)
220 | E = np.asarray(E, dtype=np.float64)
221 | H = np.asarray(H, dtype=np.float64)
222 | self.execute(
223 | "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
224 | (pair_id,) + matches.shape + (array_to_blob(matches), config,
225 | array_to_blob(F), array_to_blob(E), array_to_blob(H)))
226 |
227 |
228 | def example_usage():
229 | import os
230 | import argparse
231 |
232 | parser = argparse.ArgumentParser()
233 | parser.add_argument("--database_path", default="database.db")
234 | args = parser.parse_args()
235 |
236 | if os.path.exists(args.database_path):
237 | print("ERROR: database path already exists -- will not modify it.")
238 | return
239 |
240 | # Open the database.
241 |
242 | db = COLMAPDatabase.connect(args.database_path)
243 |
244 | # For convenience, try creating all the tables upfront.
245 |
246 | db.create_tables()
247 |
248 | # Create dummy cameras.
249 |
250 | model1, width1, height1, params1 = \
251 | 0, 1024, 768, np.array((1024., 512., 384.))
252 | model2, width2, height2, params2 = \
253 | 2, 1024, 768, np.array((1024., 512., 384., 0.1))
254 |
255 | camera_id1 = db.add_camera(model1, width1, height1, params1)
256 | camera_id2 = db.add_camera(model2, width2, height2, params2)
257 |
258 | # Create dummy images.
259 |
260 | image_id1 = db.add_image("image1.png", camera_id1)
261 | image_id2 = db.add_image("image2.png", camera_id1)
262 | image_id3 = db.add_image("image3.png", camera_id2)
263 | image_id4 = db.add_image("image4.png", camera_id2)
264 |
265 | # Create dummy keypoints.
266 | #
267 | # Note that COLMAP supports:
268 | # - 2D keypoints: (x, y)
269 | # - 4D keypoints: (x, y, theta, scale)
270 | # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)
271 |
272 | num_keypoints = 1000
273 | keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
274 | keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
275 | keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
276 | keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)
277 |
278 | db.add_keypoints(image_id1, keypoints1)
279 | db.add_keypoints(image_id2, keypoints2)
280 | db.add_keypoints(image_id3, keypoints3)
281 | db.add_keypoints(image_id4, keypoints4)
282 |
283 | # Create dummy matches.
284 |
285 | M = 50
286 | matches12 = np.random.randint(num_keypoints, size=(M, 2))
287 | matches23 = np.random.randint(num_keypoints, size=(M, 2))
288 | matches34 = np.random.randint(num_keypoints, size=(M, 2))
289 |
290 | db.add_matches(image_id1, image_id2, matches12)
291 | db.add_matches(image_id2, image_id3, matches23)
292 | db.add_matches(image_id3, image_id4, matches34)
293 |
294 | # Commit the data to the file.
295 |
296 | db.commit()
297 |
298 | # Read and check cameras.
299 |
300 | rows = db.execute("SELECT * FROM cameras")
301 |
302 | camera_id, model, width, height, params, prior = next(rows)
303 | params = blob_to_array(params, np.float64)
304 | assert camera_id == camera_id1
305 | assert model == model1 and width == width1 and height == height1
306 | assert np.allclose(params, params1)
307 |
308 | camera_id, model, width, height, params, prior = next(rows)
309 | params = blob_to_array(params, np.float64)
310 | assert camera_id == camera_id2
311 | assert model == model2 and width == width2 and height == height2
312 | assert np.allclose(params, params2)
313 |
314 | # Read and check keypoints.
315 |
316 | keypoints = dict(
317 | (image_id, blob_to_array(data, np.float32, (-1, 2)))
318 | for image_id, data in db.execute(
319 | "SELECT image_id, data FROM keypoints"))
320 |
321 | assert np.allclose(keypoints[image_id1], keypoints1)
322 | assert np.allclose(keypoints[image_id2], keypoints2)
323 | assert np.allclose(keypoints[image_id3], keypoints3)
324 | assert np.allclose(keypoints[image_id4], keypoints4)
325 |
326 | # Read and check matches.
327 |
328 | pair_ids = [image_ids_to_pair_id(*pair) for pair in
329 | ((image_id1, image_id2),
330 | (image_id2, image_id3),
331 | (image_id3, image_id4))]
332 |
333 | matches = dict(
334 | (pair_id_to_image_ids(pair_id),
335 | blob_to_array(data, np.uint32, (-1, 2)))
336 | for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
337 | )
338 |
339 | assert np.all(matches[(image_id1, image_id2)] == matches12)
340 | assert np.all(matches[(image_id2, image_id3)] == matches23)
341 | assert np.all(matches[(image_id3, image_id4)] == matches34)
342 |
343 | # Clean up.
344 |
345 | db.close()
346 |
347 | if os.path.exists(args.database_path):
348 | os.remove(args.database_path)
349 |
350 |
351 | if __name__ == "__main__":
352 | example_usage()
353 |
354 |
--------------------------------------------------------------------------------
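Note (quick sanity sketch, assuming it is run from `scripts/` so that `database` is importable): COLMAP packs an unordered image pair into a single integer key; the smaller id always comes first, so the encoding is symmetric and invertible:

from database import image_ids_to_pair_id, pair_id_to_image_ids

assert image_ids_to_pair_id(7, 3) == image_ids_to_pair_id(3, 7)
pair_id = image_ids_to_pair_id(3, 7)        # 3 * MAX_IMAGE_ID + 7
assert pair_id_to_image_ids(pair_id) == (3, 7)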
/scripts/download_vqmodels.py:
--------------------------------------------------------------------------------
1 | # convenience script to download pretrained vqmodels for use with the provided training configs
2 | import os
3 | from geofree.util import get_local_path
4 |
5 |
6 | SYMLINK_MAP = {
7 | "re_first_stage": "pretrained_models/realestate_first_stage/last.ckpt",
8 | "re_depth_stage": "pretrained_models/realestate_depth_stage/last.ckpt",
9 | "ac_first_stage": "pretrained_models/acid_first_stage/last.ckpt",
10 | "ac_depth_stage": "pretrained_models/acid_depth_stage/last.ckpt",
11 | }
12 |
13 |
14 | def create_symlink(name, path):
15 | print(f"Creating symlink from {path} to {SYMLINK_MAP[name]}")
16 | os.makedirs(os.path.dirname(SYMLINK_MAP[name]), exist_ok=True)
17 | os.symlink(src=path, dst=SYMLINK_MAP[name], target_is_directory=False)
18 |
19 |
20 | if __name__ == "__main__":
21 | for model in SYMLINK_MAP:
22 | path = get_local_path(model)
23 | create_symlink(model, path)
24 | print("done.")
25 |
26 |
--------------------------------------------------------------------------------
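Note (usage): running `python scripts/download_vqmodels.py` fetches the first-stage and depth-stage checkpoints and symlinks them to the paths the provided training configs expect. A single checkpoint can also be fetched manually (sketch):

from geofree.util import get_local_path

path = get_local_path("re_first_stage")     # downloads and md5-checks on first call
print(path)                                 # e.g. ~/.cache/geofree/re_first_stage/last.ckpt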
/scripts/sparse_from_realestate_format.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob, shutil, subprocess
2 | import numpy as np
3 |
4 |
5 | def pose_to_sparse(txt_src, img_src, spa_dst, DEBUG=False, exists_ok=True,
6 | worker_idx=None, world_size=None):
7 | opt = argparse.Namespace(txt_src=txt_src, img_src=img_src, spa_dst=spa_dst)
8 |
9 | assert os.path.exists(opt.txt_src), opt.txt_src
10 | assert os.path.exists(opt.img_src), opt.img_src
11 |
12 | if not os.path.exists(opt.spa_dst):
13 | os.makedirs(opt.spa_dst)
14 | else:
15 | if DEBUG:
16 | shutil.rmtree(opt.spa_dst)
17 | elif not exists_ok:
18 | print("Output directory exists, doing nothing")
19 | return 0
20 |
21 |
22 | txts = sorted(glob.glob(os.path.join(opt.txt_src, "*.txt")))
23 | if DEBUG: print(txts)
24 |
25 | if worker_idx is not None and world_size is not None:
26 | txts = txts[worker_idx::world_size]
27 |
28 | if DEBUG: txts = txts[:1]
29 | failed = list()
30 | for txt in txts:
31 | vidid = os.path.splitext(os.path.split(txt)[1])[0]
32 | print(f"Processing {vidid}")
33 |
34 | if (os.path.exists(os.path.join(opt.spa_dst, vidid, "sparse", "cameras.bin")) and
35 | os.path.exists(os.path.join(opt.spa_dst, vidid, "sparse", "images.bin")) and
36 | os.path.exists(os.path.join(opt.spa_dst, vidid, "sparse", "points3D.bin"))):
37 | print("Found sparse model, skipping {}".format(vidid))
38 | continue
39 |
40 | if os.path.exists(os.path.join(opt.spa_dst, vidid)):
41 | shutil.rmtree(os.path.join(opt.spa_dst, vidid))
42 | print("Found partial output of previous run, removed {}".format(vidid))
43 |
44 |
45 | # read camera poses for this sequence
46 | with open(txt, "r") as f:
47 | firstline = f.readline()
48 |
49 | if firstline.startswith("http"):
50 | if DEBUG: print("Ignoring first line.")
51 | skiprows = 1
52 | else:
53 | skiprows = 0
54 |
55 | vid_data = np.loadtxt(txt, skiprows=skiprows)
56 | if len(vid_data.shape) != 2:
57 | failed.append(vidid)
58 | print(f"Wrong txt format for {vidid}!")
59 | continue
60 |
61 | timestamps = vid_data[:,0].astype(np.int64)  # np.int is removed in recent numpy versions
62 | if DEBUG: print(timestamps)
63 | filenames = [str(ts)+".png" for ts in timestamps]
64 |
65 | if not len(filenames) > 1:
66 | failed.append(vidid)
67 | print(f"Less than two frames, skipping {vidid}!")
68 | continue
69 |
70 | if not os.path.exists(os.path.join(opt.img_src, vidid)):
71 | failed.append(vidid)
72 | print(f"Could not find frames, skipping {vidid}!")
73 | continue
74 |
75 | if not len(glob.glob(os.path.join(opt.img_src, vidid, "*.png"))) == len(filenames):
76 | failed.append(vidid)
77 | print(f"Could not find all frames, skipping {vidid}!")
78 | continue
79 |
80 | if DEBUG: print(vid_data[0,1:])
81 | K_params = vid_data[:,1:7]
82 | Ks = np.zeros((K_params.shape[0], 3, 3))
83 | Ks[:,0,0] = K_params[:,0]
84 | Ks[:,1,1] = K_params[:,1]
85 | Ks[:,0,2] = K_params[:,2]
86 | Ks[:,1,2] = K_params[:,3]
87 | Ks[:,2,2] = 1
88 | assert (Ks[0,...]==Ks[1,...]).all()
89 | K = Ks[0]
90 | if DEBUG: print(K)
91 |
92 | Rts = vid_data[:,7:].reshape(-1, 3, 4)
93 | if DEBUG: print(Rts[0])
94 |
95 | # given these intrinsics and extrinsics, find a sparse set of scale
96 | # consistent 3d points following
97 | # https://colmap.github.io/faq.html#reconstruct-sparse-dense-model-from-known-camera-poses
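# in outline, the known-pose workflow implemented below is:
#   1. colmap feature_extractor  - extract SIFT features, one shared PINHOLE camera for all frames
#   2. overwrite the camera parameters in database.db with the known intrinsics K
#   3. colmap sequential_matcher - match features between neighbouring frames
#   4. write cameras.txt/images.txt with the known poses plus an empty points3D.txt
#   5. colmap point_triangulator - triangulate 3D points while keeping the given poses fixed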
98 |
99 | # extract and match features on frames
100 | dst_dir = os.path.join(opt.spa_dst, vidid)
101 | os.makedirs(dst_dir)
102 | database_path = os.path.join(dst_dir, "database.db")
103 |
104 | # symlink images
105 | image_path = os.path.join(dst_dir, "images")
106 | os.symlink(os.path.abspath(os.path.join(opt.img_src, vidid)), image_path)
107 |
108 | cmd = ["colmap", "feature_extractor",
109 | "--database_path", database_path,
110 | "--image_path", image_path,
111 | "--ImageReader.camera_model", "PINHOLE",
112 | "--ImageReader.single_camera", "1",
113 | "--SiftExtraction.use_gpu", "1"]
114 | if DEBUG: print(" ".join(cmd))
115 | subprocess.run(cmd, check=True)
116 |
117 | # read the database
118 | from database import COLMAPDatabase, blob_to_array, array_to_blob
119 | db = COLMAPDatabase.connect(database_path)
120 |
121 | # read and update camera
122 | ## https://colmap.github.io/cameras.html
123 | cam = db.execute("SELECT * FROM cameras").fetchone()
124 | camera_id = cam[0]
125 | camera_model = cam[1]
126 | assert camera_model == 1 # PINHOLE
127 | width = cam[2]
128 | height = cam[3]
129 | params = blob_to_array(cam[4], dtype=np.float64)
130 | assert len(params) == 4 # fx, fy, cx, cy for PINHOLE
131 |
132 | # adjust params: the txt intrinsics are normalized by image size, so rescale to pixels (fx*W, fy*H, cx*W, cy*H)
133 | params[0] = width*K[0,0]
134 | params[1] = height*K[1,1]
135 | params[2] = width*K[0,2]
136 | params[3] = height*K[1,2]
137 |
138 | # update
139 | db.execute("UPDATE cameras SET params = ? WHERE camera_id = ?",
140 | (array_to_blob(params), camera_id))
141 | db.commit()
142 |
143 | # match features
144 | cmd = ["colmap", "sequential_matcher",
145 | "--database_path", database_path,
146 | "--SiftMatching.use_gpu", "1"]
147 | if DEBUG: print(" ".join(cmd))
148 | subprocess.run(cmd, check=True)
149 |
150 | # triangulate
151 | ## prepare pose model
152 | ### https://colmap.github.io/format.html#text-format
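# layout of the text-format "pose model" prepared here:
#   cameras.txt  - one line: CAMERA_ID MODEL WIDTH HEIGHT PARAMS[] (fx fy cx cy for PINHOLE)
#   images.txt   - per frame: IMAGE_ID QW QX QY QZ TX TY TZ CAMERA_ID NAME, followed by an
#                  empty line where the 2D-3D correspondences would go
#   points3D.txt - left empty; point_triangulator writes the triangulated points to the output model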
153 | pose_dir = os.path.join(dst_dir, "pose")
154 | os.makedirs(pose_dir)
155 | cameras_txt = os.path.join(pose_dir, "cameras.txt")
156 | with open(cameras_txt, "w") as f:
157 | f.write("{} PINHOLE {} {} {}".format(camera_id, width, height,
158 | " ".join(["{:.2f}".format(p) for p in params])))
159 |
160 | images_txt = os.path.join(pose_dir, "images.txt")
161 | # match image ids with filenames and export their extrinsics to images.txt
162 | images = db.execute("SELECT image_id, name, camera_id FROM images").fetchall()
163 | lines = list()
164 | for image in images:
165 | assert image[2] == camera_id
166 | image_id = image[0]
167 | image_name = image[1]
168 | image_idx = filenames.index(image_name)
169 | Rt = Rts[image_idx]
170 | R = Rt[:3,:3]
171 | t = Rt[:3,3]
172 | # convert R to quaternion
173 | from scipy.spatial.transform import Rotation
174 | Q = Rotation.from_matrix(R).as_quat()
175 | # from x,y,z,w to w,x,y,z
176 | line = " ".join(["{:.6f}".format(x) for x in [Q[3],Q[0],Q[1],Q[2],t[0],t[1],t[2]]])
177 | line = "{} ".format(image_id)+line+" {} {}".format(camera_id, image_name)
178 | lines.append(line)
179 | lines.append("") # empty line for 3d points to be triangulated
180 | with open(images_txt, "w") as f:
181 | f.write("\n".join(lines)+"\n")
182 |
183 | # create empty points3D.txt
184 | points3D_txt = os.path.join(pose_dir, "points3D.txt")
185 | open(points3D_txt, "w").close()
186 |
187 | # run point_triangulator
188 | out_dir = os.path.join(dst_dir, "sparse")
189 | os.makedirs(out_dir)
190 | cmd = ["colmap", "point_triangulator",
191 | "--database_path", database_path,
192 | "--image_path", image_path,
193 | "--input_path", pose_dir,
194 | "--output_path", out_dir]
195 | result = subprocess.run(cmd)
196 | if result.returncode != 0:
197 | print(f"Triangulation failed for {vidid}!")
198 | failed.append(vidid)
199 |
200 | print("Failed sequences:")
201 | print("\n".join(failed))
202 | print(f"Could not create sparse models for {len(failed)} sequences.")
203 | return len(txts)
204 |
205 |
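For reference, a self-contained sketch (not part of the script) of the images.txt pose line built in the loop above, using an identity rotation, zero translation and dummy image/camera ids:

import numpy as np
from scipy.spatial.transform import Rotation

Q = Rotation.from_matrix(np.eye(3)).as_quat()      # scipy returns (x, y, z, w)
vals = [Q[3], Q[0], Q[1], Q[2], 0.0, 0.0, 0.0]     # COLMAP expects QW QX QY QZ TX TY TZ
print("1 " + " ".join("{:.6f}".format(v) for v in vals) + " 1 0.png")
# -> 1 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 1 0.png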
206 | if __name__ == "__main__":
207 | parser = argparse.ArgumentParser(description='Create sparse COLMAP models from RealEstate10K-format camera poses.')
208 | parser.add_argument('--txt_src', type=str,
209 | help='path to directory containing <vidid>.txt files of realestate format')
210 | parser.add_argument('--img_src', type=str,
211 | help='path to directory containing <vidid>/<timestamp>.png frames')
212 | parser.add_argument('--spa_dst', type=str,
213 | help='path to directory to write sparse models into')
214 | parser.add_argument('--DEBUG', action="store_true",
215 | help='for quick development')
216 | parser.add_argument('--worker_idx', type=int,
217 | help='if world_size is specified, should be 0<=worker_idx<world_size')
218 | parser.add_argument('--world_size', type=int,
219 | help='number of workers')
220 |
221 | opt = parser.parse_args()
222 | if opt.world_size is not None:
223 | assert opt.world_size > 1
224 | assert opt.worker_idx is not None
225 | assert 0<=opt.worker_idx<opt.world_size
226 |
227 | pose_to_sparse(opt.txt_src, opt.img_src, opt.spa_dst, DEBUG=opt.DEBUG,
228 | worker_idx=opt.worker_idx, world_size=opt.world_size)
229 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
19 | 'omegaconf>=2.0.0',
20 | 'pytorch-lightning>=1.0.8',
21 | 'pygame',
22 | 'splatting @ git+https://github.com/pesser/splatting@1427d7c4204282d117403b35698d489e0324287f#egg=splatting',
23 | 'einops',
24 | 'importlib-resources',
25 | 'imageio',
26 | 'imageio-ffmpeg',
27 | 'test-tube'
28 | ],
29 | )
30 |
--------------------------------------------------------------------------------