├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assets
│   └── demo.gif
├── configs
│   ├── base.yaml
│   ├── infer.yaml
│   └── render
│       ├── cathedral.hdr
│       ├── cathedral.xml
│       ├── common.xml
│       ├── integrator_path.xml
│       ├── scene.xml
│       └── sensors.xml
├── dataLoader
│   ├── __init__.py
│   ├── gobjverse.py
│   ├── google_scanned_objects.py
│   ├── instant3d.py
│   ├── mipnerf.py
│   ├── mvgen.py
│   └── utils.py
├── environment.yml
├── eval_all.py
├── evaluation.py
├── lightning
│   ├── loss.py
│   ├── network.py
│   ├── renderer.py
│   ├── renderer_2dgs.py
│   ├── system.py
│   ├── utils.py
│   └── vis.py
├── third_party
│   └── image_generator
│       ├── .github
│       │   └── workflows
│       │       ├── black.yml
│       │       ├── test-build.yaml
│       │       └── test-inference.yml
│       ├── .gitignore
│       ├── CODEOWNERS
│       ├── LICENSE-CODE
│       ├── README.md
│       ├── assets
│       │   ├── 000.jpg
│       │   ├── sv3d.gif
│       │   └── tile.gif
│       ├── configs
│       │   ├── example_training
│       │   │   ├── autoencoder
│       │   │   │   └── kl-f4
│       │   │   │       ├── imagenet-attnfree-logvar.yaml
│       │   │   │       └── imagenet-kl_f8_8chn.yaml
│       │   │   ├── imagenet-f8_cond.yaml
│       │   │   ├── toy
│       │   │   │   ├── cifar10_cond.yaml
│       │   │   │   ├── mnist.yaml
│       │   │   │   ├── mnist_cond.yaml
│       │   │   │   ├── mnist_cond_discrete_eps.yaml
│       │   │   │   ├── mnist_cond_l1_loss.yaml
│       │   │   │   └── mnist_cond_with_ema.yaml
│       │   │   ├── txt2img-clipl-legacy-ucg-training.yaml
│       │   │   └── txt2img-clipl.yaml
│       │   ├── inference
│       │   │   ├── sd_2_1.yaml
│       │   │   ├── sd_2_1_768.yaml
│       │   │   ├── sd_xl_base.yaml
│       │   │   ├── sd_xl_refiner.yaml
│       │   │   ├── sv3d_p.yaml
│       │   │   ├── sv3d_u.yaml
│       │   │   ├── svd.yaml
│       │   │   └── svd_image_decoder.yaml
│       │   └── sd_xl_base.yaml
│       ├── data
│       │   └── DejaVuSans.ttf
│       ├── generator.py
│       ├── main.py
│       ├── model_licenses
│       │   ├── LICENCE-SD-Turbo
│       │   ├── LICENSE-SDXL-Turbo
│       │   ├── LICENSE-SDXL0.9
│       │   ├── LICENSE-SDXL1.0
│       │   ├── LICENSE-SV3D
│       │   └── LICENSE-SVD
│       ├── pyproject.toml
│       ├── pytest.ini
│       ├── scripts
│       │   ├── __init__.py
│       │   ├── demo
│       │   │   ├── __init__.py
│       │   │   ├── detect.py
│       │   │   ├── discretization.py
│       │   │   ├── gradio_app.py
│       │   │   ├── sampling.py
│       │   │   ├── streamlit_helpers.py
│       │   │   ├── sv3d_helpers.py
│       │   │   ├── turbo.py
│       │   │   └── video_sampling.py
│       │   ├── sampling
│       │   │   ├── configs
│       │   │   │   ├── sv3d_p.yaml
│       │   │   │   ├── sv3d_u.yaml
│       │   │   │   ├── svd.yaml
│       │   │   │   ├── svd_image_decoder.yaml
│       │   │   │   ├── svd_xt.yaml
│       │   │   │   ├── svd_xt_1_1.yaml
│       │   │   │   └── svd_xt_image_decoder.yaml
│       │   │   └── simple_video_sample.py
│       │   ├── tests
│       │   │   └── attention.py
│       │   └── util
│       │       ├── __init__.py
│       │       └── detection
│       │           ├── __init__.py
│       │           ├── nsfw_and_watermark_dectection.py
│       │           ├── p_head_v1.npz
│       │           └── w_head_v1.npz
│       ├── sgm
│       │   ├── __init__.py
│       │   ├── data
│       │   │   ├── __init__.py
│       │   │   ├── cifar10.py
│       │   │   ├── dataset.py
│       │   │   └── mnist.py
│       │   ├── inference
│       │   │   ├── api.py
│       │   │   └── helpers.py
│       │   ├── lr_scheduler.py
│       │   ├── models
│       │   │   ├── __init__.py
│       │   │   ├── autoencoder.py
│       │   │   └── diffusion.py
│       │   ├── modules
│       │   │   ├── __init__.py
│       │   │   ├── attention.py
│       │   │   ├── autoencoding
│       │   │   │   ├── __init__.py
│       │   │   │   ├── losses
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── discriminator_loss.py
│       │   │   │   │   └── lpips.py
│       │   │   │   ├── lpips
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── loss
│       │   │   │   │   │   ├── .gitignore
│       │   │   │   │   │   ├── LICENSE
│       │   │   │   │   │   ├── __init__.py
│       │   │   │   │   │   └── lpips.py
│       │   │   │   │   ├── model
│       │   │   │   │   │   ├── LICENSE
│       │   │   │   │   │   ├── __init__.py
│       │   │   │   │   │   └── model.py
│       │   │   │   │   ├── util.py
│       │   │   │   │   └── vqperceptual.py
│       │   │   │   ├── regularizers
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── base.py
│       │   │   │   │   └── quantize.py
│       │   │   │   └── temporal_ae.py
│       │   │   ├── diffusionmodules
│       │   │   │   ├── __init__.py
│       │   │   │   ├── denoiser.py
│       │   │   │   ├── denoiser_scaling.py
│       │   │   │   ├── denoiser_weighting.py
│       │   │   │   ├── discretizer.py
│       │   │   │   ├── guiders.py
│       │   │   │   ├── loss.py
│       │   │   │   ├── loss_weighting.py
│       │   │   │   ├── model.py
│       │   │   │   ├── openaimodel.py
│       │   │   │   ├── sampling.py
│       │   │   │   ├── sampling_utils.py
│       │   │   │   ├── sigma_sampling.py
│       │   │   │   ├── util.py
│       │   │   │   ├── video_model.py
│       │   │   │   └── wrappers.py
│       │   │   ├── distributions
│       │   │   │   ├── __init__.py
│       │   │   │   └── distributions.py
│       │   │   ├── ema.py
│       │   │   ├── encoders
│       │   │   │   ├── __init__.py
│       │   │   │   └── modules.py
│       │   │   └── video_attention.py
│       │   └── util.py
│       └── tests
│           └── inference
│               └── test_inference.py
├── tools
│   ├── camera.py
│   ├── camera_utils.py
│   ├── depth.py
│   ├── download_dataset.py
│   ├── download_objaverse.py
│   ├── gen_video_path.py
│   ├── hdf5_split_merge.py
│   ├── img_utils.py
│   ├── meshExtractor.py
│   ├── meshRender.py
│   ├── prepare_dataset_co3d.py
│   ├── prepare_dataset_objaverse.py
│   └── rsh.py
└── train_lightning.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .idea/
3 | .ipynb_checkpoints/
4 | *.py[cod]
5 | *.so
6 | *.orig
7 | *.o
8 | *.json
9 | *.pth
10 | *.npy
11 | *.ipynb
12 | *.png
13 | logs/*
14 | outputs/*
15 | ckpts/*
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/diff-surfel-rasterization"]
2 | path = third_party/diff-surfel-rasterization
3 | url = git@github.com:hbb1/diff-surfel-rasterization.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Florent Bartoccioni and Eloi Zablocki and Andrei Bursuc and Patrick Perez and Matthieu Cord and Karteek Alahari
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LaRa: Efficient Large-Baseline Radiance Fields
2 |
3 | [Project page](https://apchenstu.github.io/LaRa/) | [Paper](https://arxiv.org/abs/2407.04699) | [Data](https://huggingface.co/apchen/LaRa/tree/main/dataset) | [Checkpoint](https://huggingface.co/apchen/LaRa/tree/main/ckpts) |
4 |
5 | 
6 |
7 | ## ⭐ New Features
8 | - 2024/04/05: Important update -
9 | our method now supports half-precision training, achieving over **100% faster** convergence and roughly **1.5 dB** higher PSNR with fewer iterations!
10 |
11 | | Model | PSNR ↑ | SSIM ↑ | Abs err (Geo) ↓ | Epochs | Time (days) | ckpt |
12 | | ------ | ------ | ------ | ------ | ------ | ------ | ------ |
13 | | Paper | 27.65 | 0.951 | 0.0654 | 50 | 3.5 | ------ |
14 | | bf16 | 29.15 | 0.956 | 0.0574 | 30 | 1.5 | [Download](https://huggingface.co/apchen/LaRa/tree/main/ckpts/) |
15 |
16 | Please download the pre-trained checkpoint from the provided link and place it in the `ckpts` folder.
17 |
18 | # Installation
19 |
20 | ```
21 | git clone https://github.com/autonomousvision/LaRa.git --recursive
22 | conda env create --file environment.yml
23 | conda activate lara
24 | ```
25 |
26 |
27 | # Dataset
28 | We used the processed [gobjaverse dataset](https://aigc3d.github.io/gobjaverse/) for training. A download script `tools/download_dataset.py` is provided to automatically download the datasets.
29 |
30 | ```
31 | python tools/download_dataset.py all
32 | ```
33 | Note: The GObjaverse dataset requires approximately 1.4 TB of storage. You can also download a subset of the dataset. Please refer to the provided script for details. Please manually delete the `_temp` folder after completing the download.
34 |
35 | If you would like to process the data yourself, we provide preprocessing scripts for the gobjaverse and Co3D datasets; please check `tools/prepare_dataset_*`.
36 | You can also download our preprocessed data and put it in the `dataset` folder:
37 | * [gobjaverse](#gobjaverse)
38 | * [Google Scanned Objects](#GSO)
39 | * [Co3D](#Co3D)
40 | * Instant3D - Please contact the authors of Instant3D if you wish to obtain the data for comparison.
41 | # Training
42 | ```
43 | python train_lightning.py
44 | ```
45 | **note:** You can configure the GPU ids and other parameters in `configs/base.yaml`.
46 |
47 | # Evaluation
48 | Our method supports the reconstruction of radiance fields from **multi-view**, **text**, and **single view** inputs. We provide a pre-trained checkpoint at [ckpt](https://huggingface.co/apchen/LaRa/resolve/main/ckpts/epoch%3D29.ckpt).
49 |
50 | ## multi-view to 3D
51 | To reproduce the table results, you can simply use:
52 | ```
53 | python eval_all.py
54 | ```
55 | **note:**
56 | - Please double-check that the paths inside the script are correct for your specific case.
57 | - Please set the `video_frames` and `save_mesh` [flags](https://github.com/autonomousvision/LaRa/blob/main/eval_all.py#L11) if you would like to output meshes or videos during evaluation.
58 |
59 | ## text to 3D
60 | ```
61 | python evaluation.py configs/infer.yaml \
62 |     infer.ckpt_path=ckpts/epoch=29.ckpt \
63 |     infer.save_folder=outputs/prompts/ \
64 |     infer.dataset.generator_type=xxx \
65 |     infer.dataset.prompts=["a car made out of sushi","a beautiful rainbow fish"]
66 | ```
67 | **note:** This part is currently unavailable due to a permissions issue. I will look for an alternative text-to-multi-view generator later next week.
68 |
69 |
70 | ## single view to 3D
71 | ```
72 | python evaluation.py configs/infer.yaml \
73 |     infer.ckpt_path=ckpts/epoch=29.ckpt \
74 |     infer.save_folder=outputs/single-view/ \
75 |     infer.dataset.generator_type="zero123plus-v1" \
76 |     infer.dataset.image_pathes=\["assets/examples/13_realfusion_cherry_1.png"\]
77 | ```
78 | **note:** It supports the generator types `zero123plus-v1.1` and `zero123plus-v1`.
79 |
80 |
81 |
82 | ## Acknowledgements
83 | Our renderer is built upon [2DGS](https://github.com/hbb1/2d-gaussian-splatting). The data preprocessing code for the Co3D dataset is partially borrowed from [Splatter-Image](https://github.com/szymanowiczs/splatter-image/blob/main/data_preprocessing/preprocess_co3d.py). Additionally, the script for generating multi-view images from text and single-view images is sourced from [GRM](https://github.com/justimyhxu/grm). We thank all the authors for their great repos.
84 |
85 | ## Citation
86 | If you find our code or paper helpful, please consider citing:
87 | ```bibtex
88 | @inproceedings{LaRa,
89 | author = {Anpei Chen and Haofei Xu and Stefano Esposito and Siyu Tang and Andreas Geiger},
90 | title = {LaRa: Efficient Large-Baseline Radiance Fields},
91 | booktitle = {European Conference on Computer Vision (ECCV)},
92 | year = {2024}
93 | }
94 | ```
95 |
96 |
97 |
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/assets/demo.gif
--------------------------------------------------------------------------------
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | gpu_id: [4,5,6,7]
2 |
3 | exp_name: LaRa/release-test
4 | n_views: 4
5 |
6 | model:
7 |
8 | encoder_backbone: 'vit_base_patch16_224.dino' # ['vit_small_patch16_224.dino','vit_base_patch16_224.dino']
9 |
10 | n_groups: [16] # n_groups for local attention
11 | n_offset_groups: 32 # offset radius of 1/n_offset_groups of the scene size
12 |
13 | K: 2 # primitives per-voxel
14 | sh_degree: 1 # view dependent color
15 |
16 | num_layers: 12
17 | num_heads: 16
18 |
19 | view_embed_dim: 32
20 | embedding_dim: 256
21 |
22 | vol_feat_reso: 16
23 | vol_embedding_reso: 32
24 |
25 | vol_embedding_out_dim: 80
26 |
27 | ckpt_path: null # specify a ckpt path if you want to continue training
28 |
29 | train_dataset:
30 | dataset_name: gobjeverse
31 | data_root: dataset/gobjaverse/gobjaverse.h5
32 |
33 | split: train
34 | img_size: [512,512] # image resolution
35 | n_group: ${n_views} # number of input views (follows n_views)
36 | n_scenes: 3000000
37 | load_normal: True
38 |
39 |
40 |
41 | test_dataset:
42 | dataset_name: gobjeverse
43 | data_root: dataset/gobjaverse/gobjaverse.h5
44 |
45 | split: test
46 | img_size: [512,512]
47 | n_group: ${n_views}
48 | n_scenes: 3000000
49 | load_normal: True
50 |
51 | train:
52 | batch_size: 3
53 | lr: 4e-4
54 | beta1: 0.9
55 | beta2: 0.95
56 | weight_decay: 0.05
57 | # betas: [0.9, 0.95]
58 | warmup_iters: 1000
59 | n_epoch: 30
60 | limit_train_batches: 0.2
61 | limit_val_batches: 0.02
62 | check_val_every_n_epoch: 1
63 | start_fine: 5000
64 | use_rand_views: False
65 | test:
66 | batch_size: 3
67 |
68 | logger:
69 | name: tensorboard
70 | dir: logs/${exp_name}
71 |
--------------------------------------------------------------------------------
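Values such as `n_group: ${n_views}` above are OmegaConf interpolations resolved against the top-level `n_views` key. Below is a minimal sketch of loading and resolving this config; it only assumes OmegaConf (pinned in `environment.yml`) used in the standard way. The exact override handling inside `evaluation.py`/`train_lightning.py` is not shown in this dump, so the `from_dotlist` merge is purely illustrative.

```python
# Minimal sketch (not repo code): load configs/base.yaml and resolve ${n_views}.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/base.yaml")

# eval_all.py passes dot-list overrides on the command line; merging them by hand
# would look roughly like this (illustrative only):
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["n_views=4", "train.batch_size=2"]))

resolved = OmegaConf.to_container(cfg.train_dataset, resolve=True)
print(resolved["n_group"])  # -> 4, pulled from the top-level n_views
```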
/configs/infer.yaml:
--------------------------------------------------------------------------------
1 | n_views: 4
2 |
3 | infer:
4 | dataset:
5 | # dataset_name: gobjeverse
6 | # data_root: dataset/gobjaverse_280k/gobjaverse_280k.hdf5
7 | # data_root: dataset/Co3D/co3d_teddybear.hdf5
8 | # data_root: dataset/Co3D/co3d_hydrant.hdf5
9 |
10 | # dataset_name: GSO
11 | # data_root: dataset/google_scanned_objects
12 |
13 | # dataset_name: instant3d
14 | # data_root: dataset/instant3D
15 |
16 | # text to 3D
17 | dataset_name: mvgen
18 | generator_type: instant3d
19 | prompts: ["a car made out of sushi"]
20 | image_pathes: []
21 |
22 | ## single view to 3D
23 | # dataset_name: mvgen
24 | # generator_type: zero123plus-v1.1 # zero123plus-v1.1,zero123plus-v1.2,sv3d
25 | # prompts: []
26 | # image_pathes: ['examples/19_dalle3_stump1.png']
27 |
28 | # # unposed inputs
29 | # dataset_name: unposed
30 | # image_pathes: examples/unposed/*.png
31 |
32 | split: test
33 | img_size: [512,512]
34 | n_group: 4
35 | n_scenes: 30000
36 | num_workers: 0
37 | batch_size: 1
38 |
39 | load_normal: False
40 |
41 | ckpt_path: ckpts/lara.ckpt
42 |
43 | eval_novel_view_only: True
44 | eval_depth: []
45 | metric_path: None
46 |
47 | save_folder: outputs/video_vis/mvgen
48 | video_frames: 120
49 | mesh_video_frames: 0
50 |
51 | save_mesh: True
52 | aabb: [-0.5,-0.5,-0.5,0.5,0.5,0.5]
53 |
54 | finetuning:
55 | with_ft: False
56 | steps: 500
57 |
58 | # lr
59 | position_lr: 0.000016
60 | feature_lr: 0.0025
61 | opacity_lr: 0.05
62 | scaling_lr: 0.005
63 | rotation_lr: 0.001
64 |
65 |
66 |
--------------------------------------------------------------------------------
/configs/render/cathedral.hdr:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/configs/render/cathedral.hdr
--------------------------------------------------------------------------------
/configs/render/cathedral.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/configs/render/common.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/configs/render/integrator_path.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/configs/render/scene.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/configs/render/sensors.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/dataLoader/__init__.py:
--------------------------------------------------------------------------------
1 | from dataLoader.gobjverse import gobjverse
2 | from dataLoader.google_scanned_objects import GoogleObjsDataset
3 | from dataLoader.instant3d import Instant3DObjsDataset
4 | from dataLoader.mipnerf import MipNeRF360Dataset
5 | from dataLoader.mvgen import MVGenDataset
6 |
7 | dataset_dict = {'gobjeverse': gobjverse,
8 | 'GSO': GoogleObjsDataset,
9 | 'instant3d': Instant3DObjsDataset,
10 | 'mipnerf360': MipNeRF360Dataset,
11 | 'mvgen': MVGenDataset,
12 | }
--------------------------------------------------------------------------------
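The `dataset_dict` registry above maps the `dataset_name` strings used in `configs/base.yaml` and `configs/infer.yaml` to dataset classes. A hedged usage sketch follows, assuming each dataset class takes its dataset sub-config (as `Instant3DObjsDataset` below does); the actual wiring lives in `train_lightning.py`/`evaluation.py`, which are not reproduced in this section.

```python
# Illustrative sketch (not repo code): build a test loader from the registry.
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from dataLoader import dataset_dict

cfg = OmegaConf.load("configs/base.yaml")
dataset_cls = dataset_dict[cfg.test_dataset.dataset_name]   # e.g. 'gobjeverse'
test_set = dataset_cls(cfg.test_dataset)                     # each class consumes its sub-config
test_loader = DataLoader(test_set, batch_size=cfg.test.batch_size,
                         shuffle=False, num_workers=4)
```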
/dataLoader/instant3d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from glob import glob
4 | import imageio
5 | import tqdm
6 | from multiprocessing import Pool
7 | import copy
8 | import cv2
9 | import random
10 | from PIL import Image
11 | import torch
12 | import json
13 | from dataLoader.utils import build_rays
14 | from scipy.spatial.transform import Rotation as R
15 | from dataLoader.utils import intrinsic_to_fov, KMean
16 |
17 | class Instant3DObjsDataset(torch.utils.data.Dataset):
18 | def __init__(self, cfg):
19 | super(Instant3DObjsDataset, self).__init__()
20 | self.data_root = cfg.data_root
21 |
22 | self.img_size = np.array(cfg.img_size)
23 |
24 | scenes_name = np.array([f for f in sorted(os.listdir(self.data_root)) if f.endswith('png')])
25 | self.scenes_name = scenes_name
26 | print(len(self.scenes_name))
27 |
28 | self.build_camera()
29 | self.bg_color = 1.0
30 |
31 | def build_camera(self):
32 | scene_info = {'c2ws':[],'w2cs':[],'ixts':[]}
33 | json_info = json.load(open(os.path.join(self.data_root, f'opencv_cameras.json')))
34 |
35 | for i in range(4):
36 | frame = json_info['frames'][i]
37 | w2c = np.array(frame['w2c'])
38 | c2w = np.linalg.inv(w2c)
39 | c2w[:3,3] /= 1.7
40 | w2c = np.linalg.inv(c2w)
41 | scene_info['c2ws'].append(c2w)
42 | scene_info['w2cs'].append(w2c)
43 |
44 | ixt = np.eye(3)
45 | ixt[[0,1],[0,1]] = np.array([frame['fx'],frame['fy']])
46 | ixt[[0,1],[2,2]] = np.array([frame['cx'],frame['cy']])
47 | scene_info['ixts'].append(ixt)
48 |
49 | scene_info['c2ws'] = np.stack(scene_info['c2ws']).astype(np.float32)
50 | scene_info['w2cs'] = np.stack(scene_info['w2cs']).astype(np.float32)
51 | scene_info['ixts'] = np.stack(scene_info['ixts']).astype(np.float32)
52 |
53 | self.scene_info = scene_info
54 |
55 | def __getitem__(self, index):
56 |
57 |
58 | scenes_name = self.scenes_name[index]
59 | # src_view_id = list(range(4))
60 | # tar_views = src_view_id + list(range(4))
61 |
62 | #np.random.rand(3)
63 | tar_img = self.read_image(scenes_name)
64 | tar_c2ws = self.scene_info['c2ws']
65 | tar_w2cs = self.scene_info['w2cs']
66 | tar_ixts = self.scene_info['ixts']
67 |
68 | # align cameras using the first view
69 | # no inverse operation
70 | r = np.linalg.norm(tar_c2ws[0,:3,3])
71 | ref_c2w = np.eye(4, dtype=np.float32).reshape(1,4,4)
72 | ref_w2c = np.eye(4, dtype=np.float32).reshape(1,4,4)
73 | ref_c2w[:,2,3], ref_w2c[:,2,3] = -r, r
74 | transform_mats = ref_c2w @ tar_w2cs[:1]
75 | tar_w2cs = tar_w2cs.copy() @ tar_c2ws[:1] @ ref_w2c
76 | tar_c2ws = transform_mats @ tar_c2ws.copy()
77 |
78 | fov_x, fov_y = intrinsic_to_fov(tar_ixts[0],w=512,h=512)
79 |
80 | ret = {'fovx':fov_x,
81 | 'fovy':fov_y,
82 | }
83 | H, W = self.img_size
84 |
85 | ret.update({'tar_c2w': tar_c2ws,
86 | 'tar_w2c': tar_w2cs,
87 | 'tar_ixt': tar_ixts,
88 | 'tar_rgb': tar_img.transpose(1,0,2,3).reshape(H,4*W,3),
89 | 'transform_mats': transform_mats
90 | })
91 | near_far = np.array([r-1.0, r+1.0]).astype(np.float32)
92 | ret.update({'near_far': np.array(near_far).astype(np.float32)})
93 | ret.update({'meta': {'scene':scenes_name,f'tar_h': int(H), f'tar_w': int(W)}})
94 |
95 | rays = build_rays(tar_c2ws, tar_ixts.copy(), H, W, 1.0)
96 | ret.update({f'tar_rays': rays})
97 | rays_down = build_rays(tar_c2ws, tar_ixts.copy(), H, W, 1.0/16)
98 | ret.update({f'tar_rays_down': rays_down})
99 | return ret
100 |
101 |
102 | def read_image(self, scenes_name):
103 |
104 | img = imageio.imread(f'{self.data_root}/{scenes_name}')
105 | img = img.astype(np.float32) / 255.
106 | if img.shape[-1] == 4:
107 | img = (img[..., :3] * img[..., -1:] + self.bg_color*(1 - img[..., -1:])).astype(np.float32)
108 |
109 | # split images
110 | row_chunks = np.array_split(img, 2)
111 | imgs = np.stack([np.array_split(chunk, 2, axis=1) for chunk in row_chunks]).reshape(4,512,512,-1)
112 | return imgs
113 |
114 |
115 | def __len__(self):
116 | return len(self.scenes_name)
--------------------------------------------------------------------------------
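`read_image` above assumes each PNG is a 2x2 tile of 512x512 views (the layout produced by Instant3D-style multi-view generators) and splits it with `np.array_split`. A standalone illustration of just that split:

```python
# Standalone illustration of the 2x2 tile split used in read_image above.
import numpy as np

tile = np.zeros((1024, 1024, 3), dtype=np.float32)   # four 512x512 views in a 2x2 grid
rows = np.array_split(tile, 2)                        # two 512x1024 halves
views = np.stack([np.array_split(r, 2, axis=1) for r in rows]).reshape(4, 512, 512, -1)
print(views.shape)  # (4, 512, 512, 3)
```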
/environment.yml:
--------------------------------------------------------------------------------
1 | name: lara
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - pip=24.0
8 | - python=3.9.19
9 | - zlib=1.2.13
10 | - pytorch=2.1.2
11 | - torchaudio=2.1.2
12 | - torchvision=0.16.2
13 | - pip:
14 | - open3d==0.18.0
15 | - diffusers==0.27.2
16 | - einops==0.7.0
17 | - fire==0.6.0
18 | - h5py==3.10.0
19 | - imageio==2.34.0
20 | - kornia==0.7.2
21 | - pytorch-lightning==2.1.3
22 | - lpips==0.1.4
23 | - matplotlib==3.8.4
24 | - numpy==1.26.3
25 | - omegaconf==2.3.0
26 | - opencv-python==4.9.0.80
27 | - pillow==10.2.0
28 | - pytorch-msssim==1.0.0
29 | - pyyaml==6.0.1
30 | - rembg==2.0.56
31 | - scikit-image==0.21.0
32 | - scikit-learn==1.4.0
33 | - scipy==1.13.0
34 | - tensorboardx==2.6.2.2
35 | - timm==0.9.12
36 | - torchmetrics==1.4.0
37 | - torchtyping==0.1.4
38 | - tqdm==4.66.2
39 | - transformers==4.25.1
40 | - trimesh==4.3.2
41 | - wandb==0.16.2
42 | - xformers==0.0.23.post1
43 | - tensorboard==2.16.2
44 | - open_clip_torch==2.24.0
45 | - streamlit==1.35.0
46 | - invisible-watermark==0.2.0
47 | - git+https://github.com/openai/CLIP.git
48 | - third_party/diff-surfel-rasterization
--------------------------------------------------------------------------------
/eval_all.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | gpu_id = 0
4 | name = 'release'
5 | ckpt_path = f'ckpts/epoch=29.ckpt'
6 |
7 | for n_views in [4]:
8 | cmd = f'CUDA_VISIBLE_DEVICES={gpu_id} python evaluation.py configs/infer.yaml n_views={n_views} infer.eval_novel_view_only=True ' \
9 | f'infer.ckpt_path={ckpt_path} infer.metric_path=outputs/metrics/{name}_GSO_{n_views}_views.json ' \
10 | f'infer.dataset.dataset_name=GSO infer.dataset.data_root=dataset/google_scanned_objects infer.eval_depth=[0.005,0.01,0.02] ' \
11 | f'infer.video_frames=0 infer.save_mesh=False ' \
12 | f'infer.save_folder=outputs/image_vis/{name}_GSO_{n_views}_views infer.dataset.n_group={n_views} '
13 | os.system(cmd)
14 |
15 | cmd = f'CUDA_VISIBLE_DEVICES={gpu_id} python evaluation.py configs/infer.yaml n_views={n_views} infer.eval_novel_view_only=True ' \
16 | f'infer.ckpt_path={ckpt_path} infer.metric_path=outputs/metrics/{name}_gobjeverse_{n_views}_views.json ' \
17 | f'infer.dataset.dataset_name=gobjeverse infer.dataset.data_root=dataset/gobjaverse/gobjaverse.h5 ' \
18 | f'infer.video_frames=0 infer.save_mesh=False ' \
19 | f'infer.save_folder=outputs/image_vis/{name}_gobjaverse_{n_views}_views infer.dataset.n_group={n_views} '
20 | os.system(cmd)
21 |
22 | cmd = f'CUDA_VISIBLE_DEVICES={gpu_id} python evaluation.py configs/infer.yaml n_views={n_views} infer.eval_novel_view_only=True ' \
23 | f'infer.ckpt_path={ckpt_path} infer.metric_path=outputs/metrics/{name}_co3d_teddybear_{n_views}_views.json ' \
24 | f'infer.dataset.dataset_name=gobjeverse infer.dataset.data_root=dataset/Co3D/co3d_teddybear.h5 ' \
25 | f'infer.video_frames=0 infer.save_mesh=False ' \
26 | f'infer.save_folder=outputs/image_vis/{name}_co3d_teddybear infer.dataset.n_group={n_views} '
27 | os.system(cmd)
28 |
29 | cmd = f'CUDA_VISIBLE_DEVICES={gpu_id} python evaluation.py configs/infer.yaml n_views={n_views} infer.eval_novel_view_only=True ' \
30 | f'infer.ckpt_path={ckpt_path} infer.metric_path=outputs/metrics/{name}_co3d_hydrant_{n_views}_views.json ' \
31 | f'infer.dataset.dataset_name=gobjeverse infer.dataset.data_root=dataset/Co3D/co3d_hydrant.h5 ' \
32 | f'infer.video_frames=0 infer.save_mesh=False ' \
33 | f'infer.save_folder=outputs/image_vis/{name}_co3d_hydrant infer.dataset.n_group={n_views} '
34 | os.system(cmd)
35 |
36 |
37 |
--------------------------------------------------------------------------------
/lightning/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_msssim import MS_SSIM
4 | from torch.nn import functional as F
5 |
6 | from torch.cuda.amp import autocast
7 |
8 | class Losses(nn.Module):
9 | def __init__(self):
10 | super(Losses, self).__init__()
11 |
12 | self.color_crit = nn.MSELoss(reduction='mean')
13 | self.mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
14 |
15 | self.ssim = MS_SSIM(data_range=1.0, size_average=True, channel=3)
16 |
17 | def forward(self, batch, output, iter):
18 |
19 | scalar_stats = {}
20 | loss = 0
21 |
22 | B,V,H,W = batch['tar_rgb'].shape[:-1]
23 |
24 | tar_rgb = batch['tar_rgb'].permute(0,2,1,3,4).reshape(B,H,V*W,3)
25 |
26 |
27 | if 'image' in output:
28 |
29 | for prex in ['','_fine']:
30 |
31 |
32 | if prex=='_fine' and f'acc_map{prex}' not in output:
33 | continue
34 |
35 | color_loss_all = (output[f'image{prex}']-tar_rgb)**2
36 | loss += color_loss_all.mean()
37 |
38 | psnr = -10. * torch.log(color_loss_all.detach().mean()) / \
39 | torch.log(torch.Tensor([10.]).to(color_loss_all.device))
40 | scalar_stats.update({f'mse{prex}': color_loss_all.mean().detach()})
41 | scalar_stats.update({f'psnr{prex}': psnr})
42 |
43 |
44 | with autocast(enabled=False):
45 | ssim_val = self.ssim(output[f'image{prex}'].permute(0,3,1,2), tar_rgb.permute(0,3,1,2))
46 | scalar_stats.update({f'ssim{prex}': ssim_val.detach()})
47 | loss += 0.5 * (1-ssim_val)
48 |
49 | if f'rend_dist{prex}' in output and iter>1000 and prex!='_fine':
50 | distortion = output[f"rend_dist{prex}"].mean()
51 | scalar_stats.update({f'distortion{prex}': distortion.detach()})
52 | loss += distortion*1000
53 |
54 | rend_normal = output[f'rend_normal{prex}']
55 | depth_normal = output[f'depth_normal{prex}']
56 | acc_map = output[f'acc_map{prex}'].detach()
57 |
58 | normal_error = ((1 - (rend_normal * depth_normal).sum(dim=-1))*acc_map).mean()
59 | scalar_stats.update({f'normal{prex}': normal_error.detach()})
60 | loss += normal_error*0.2
61 |
62 | return loss, scalar_stats
63 |
64 |
--------------------------------------------------------------------------------
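The loss weights are all visible in the code above: MSE plus `0.5 * (1 - MS-SSIM)`, with the distortion term scaled by 1000 and the normal-consistency term by 0.2 once `iter > 1000`, and PSNR derived directly from the MSE. The toy recomputation below applies the same weighting to random tensors; the distortion and normal terms are stand-in scalars, not real renderer outputs.

```python
# Toy recomputation of the loss weighting above on dummy data (not repo code).
import torch
from pytorch_msssim import MS_SSIM

pred = torch.rand(1, 3, 256, 256)
gt = torch.rand(1, 3, 256, 256)

mse = ((pred - gt) ** 2).mean()
psnr = -10.0 * torch.log10(mse)                         # same as -10*log(x)/log(10)
ssim_val = MS_SSIM(data_range=1.0, size_average=True, channel=3)(pred, gt)

rend_dist = torch.tensor(0.001)                         # stand-in for output['rend_dist'].mean()
normal_error = torch.tensor(0.05)                       # stand-in for the normal-consistency term

loss = mse + 0.5 * (1 - ssim_val) + 1000 * rend_dist + 0.2 * normal_error
print(float(loss), float(psnr))
```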
/lightning/system.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | from lightning.loss import Losses
5 | import pytorch_lightning as L
6 |
7 | import torch.nn as nn
8 | from lightning.vis import vis_images
9 | from pytorch_lightning.loggers import TensorBoardLogger
10 | from lightning.utils import CosineWarmupScheduler
11 |
12 | from lightning.network import Network
13 |
14 | class system(L.LightningModule):
15 | def __init__(self, cfg):
16 | super().__init__()
17 |
18 | self.cfg = cfg
19 | self.loss = Losses()
20 | self.net = Network(cfg)
21 |
22 | self.validation_step_outputs = []
23 |
24 | def training_step(self, batch, batch_idx):
25 |
26 | output = self.net(batch, with_fine=self.global_step>self.cfg.train.start_fine)
27 | loss, scalar_stats = self.loss(batch, output, self.global_step)
28 | for key, value in scalar_stats.items():
29 | prog_bar = True if key in ['psnr','mask','depth'] else False
30 | self.log(f'train/{key}', value, prog_bar=prog_bar)
31 | self.log('lr',self.trainer.optimizers[0].param_groups[0]['lr'])
32 |
33 | if 0 == self.trainer.global_step % 3000 and (self.trainer.local_rank == 0):
34 | self.vis_results(output, batch, prex='train')
35 |
36 | return loss
37 |
38 | def validation_step(self, batch, batch_idx):
39 | self.net.eval()
40 | output = self.net(batch, with_fine=self.global_step>self.cfg.train.start_fine)
41 | loss, scalar_stats = self.loss(batch, output, self.global_step)
42 | if batch_idx == 0 and (self.trainer.local_rank == 0):
43 | self.vis_results(output, batch, prex='val')
44 | self.validation_step_outputs.append(scalar_stats)
45 | return loss
46 |
47 | def on_validation_epoch_end(self):
48 | keys = self.validation_step_outputs[0]
49 | for key in keys:
50 | prog_bar = True if key in ['psnr','mask','depth'] else False
51 | metric_mean = torch.stack([x[key] for x in self.validation_step_outputs]).mean()
52 | self.log(f'val/{key}', metric_mean, prog_bar=prog_bar, sync_dist=True)
53 |
54 | self.validation_step_outputs.clear() # free memory
55 | torch.cuda.empty_cache()
56 |
57 | def vis_results(self, output, batch, prex):
58 | output_vis = vis_images(output, batch)
59 | for key, value in output_vis.items():
60 | if isinstance(self.logger, TensorBoardLogger):
61 | B,h,w = value.shape[:3]
62 | value = value.reshape(1,B*h,w,3).transpose(0,3,1,2)
63 | self.logger.experiment.add_images(f'{prex}/{key}', value, self.global_step)
64 | else:
65 | imgs = [np.concatenate([img for img in value],axis=0)]
66 | self.logger.log_image(f'{prex}/{key}', imgs, step=self.global_step)
67 | self.net.train()
68 |
69 | def num_steps(self) -> int:
70 | """Get number of steps"""
71 | # Accessing _data_source is flaky and might break
72 | dataset = self.trainer.fit_loop._data_source.dataloader()
73 | dataset_size = len(dataset)
74 | num_devices = max(1, self.trainer.num_devices)
75 | num_steps = dataset_size * self.trainer.max_epochs * self.cfg.train.limit_train_batches // (self.trainer.accumulate_grad_batches * num_devices)
76 | return int(num_steps)
77 |
78 | def configure_optimizers(self):
79 | decay_params, no_decay_params = [], []
80 |
81 | # add all bias and LayerNorm params to no_decay_params
82 | for name, module in self.named_modules():
83 | if isinstance(module, nn.LayerNorm):
84 | no_decay_params.extend([p for p in module.parameters()])
85 | elif hasattr(module, 'bias') and module.bias is not None:
86 | no_decay_params.append(module.bias)
87 |
88 | # add remaining parameters to decay_params
89 | _no_decay_ids = set(map(id, no_decay_params))
90 | decay_params = [p for p in self.parameters() if id(p) not in _no_decay_ids]
91 |
92 | # filter out parameters with no grad
93 | decay_params = list(filter(lambda p: p.requires_grad, decay_params))
94 | no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params))
95 |
96 | # Optimizer
97 | opt_groups = [
98 | {'params': decay_params, 'weight_decay': self.cfg.train.weight_decay},
99 | {'params': no_decay_params, 'weight_decay': 0.0},
100 | ]
101 | optimizer = torch.optim.AdamW(
102 | opt_groups,
103 | lr=self.cfg.train.lr,
104 | betas=(self.cfg.train.beta1, self.cfg.train.beta2),
105 | )
106 |
107 | total_global_batches = self.num_steps()
108 | scheduler = CosineWarmupScheduler(
109 | optimizer=optimizer,
110 | warmup_iters=self.cfg.train.warmup_iters,
111 | max_iters=total_global_batches,
112 | )
113 |
114 | return {"optimizer": optimizer,
115 | "lr_scheduler": {
116 | 'scheduler': scheduler,
117 | 'interval': 'step' # or 'epoch' for epoch-level updates
118 | }}
--------------------------------------------------------------------------------
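`train_lightning.py` itself is not included in this section, so the sketch below is only an assumption about how the `system` module might be handed to a `pytorch_lightning.Trainer`, using fields that `configs/base.yaml` clearly provides (`gpu_id`, `train.n_epoch`, `train.limit_train_batches`, ...). The `bf16-mixed` precision flag echoes the half-precision training mentioned in the README but is not confirmed by anything shown here.

```python
# Hedged sketch of a training entry point (train_lightning.py is not shown here).
import pytorch_lightning as pl
from omegaconf import OmegaConf

from lightning.system import system

cfg = OmegaConf.load("configs/base.yaml")
model = system(cfg)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=list(cfg.gpu_id),                            # e.g. [4, 5, 6, 7] in base.yaml
    max_epochs=cfg.train.n_epoch,
    precision="bf16-mixed",                              # assumption, see the README note on bf16
    limit_train_batches=cfg.train.limit_train_batches,
    limit_val_batches=cfg.train.limit_val_batches,
    check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
)
# trainer.fit(model, train_loader, val_loader)           # loaders built via dataLoader.dataset_dict
```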
/lightning/utils.py:
--------------------------------------------------------------------------------
1 | import torch, os, json, math
2 | import numpy as np
3 | from torch.optim.lr_scheduler import LRScheduler
4 |
5 | def getProjectionMatrix(znear, zfar, fovX, fovY):
6 |
7 | tanHalfFovY = torch.tan((fovY / 2))
8 | tanHalfFovX = torch.tan((fovX / 2))
9 |
10 | P = torch.zeros(4, 4)
11 |
12 | z_sign = 1.0
13 |
14 | P[0, 0] = 1 / tanHalfFovX
15 | P[1, 1] = 1 / tanHalfFovY
16 | P[3, 2] = z_sign
17 | P[2, 2] = z_sign * zfar / (zfar - znear)
18 | P[2, 3] = -(zfar * znear) / (zfar - znear)
19 | return P
20 |
21 |
22 | class MiniCam:
23 | def __init__(self, c2w, width, height, fovy, fovx, znear, zfar, device):
24 | # c2w (pose) should be in NeRF convention.
25 |
26 | self.image_width = width
27 | self.image_height = height
28 | self.FoVy = fovy
29 | self.FoVx = fovx
30 | self.znear = znear
31 | self.zfar = zfar
32 |
33 | w2c = torch.inverse(c2w)
34 |
35 | # rectify...
36 | # w2c[1:3, :3] *= -1
37 | # w2c[:3, 3] *= -1
38 |
39 | self.world_view_transform = w2c.transpose(0, 1).to(device)
40 | self.projection_matrix = (
41 | getProjectionMatrix(
42 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
43 | )
44 | .transpose(0, 1)
45 | .to(device)
46 | )
47 | self.full_proj_transform = (self.world_view_transform @ self.projection_matrix).float()
48 | self.camera_center = -c2w[:3, 3].to(device)
49 |
50 |
51 | def rotation_matrix_to_quaternion(R):
52 | tr = R[0, 0] + R[1, 1] + R[2, 2]
53 | if tr > 0:
54 | S = torch.sqrt(tr + 1.0) * 2.0
55 | qw = 0.25 * S
56 | qx = (R[2, 1] - R[1, 2]) / S
57 | qy = (R[0, 2] - R[2, 0]) / S
58 | qz = (R[1, 0] - R[0, 1]) / S
59 | elif (R[0, 0] > R[1, 1]) and (R[0, 0] > R[2, 2]):
60 | S = torch.sqrt(1.0 + R[0, 0] - R[1, 1] - R[2, 2]) * 2.0
61 | qw = (R[2, 1] - R[1, 2]) / S
62 | qx = 0.25 * S
63 | qy = (R[0, 1] + R[1, 0]) / S
64 | qz = (R[0, 2] + R[2, 0]) / S
65 | elif R[1, 1] > R[2, 2]:
66 | S = torch.sqrt(1.0 + R[1, 1] - R[0, 0] - R[2, 2]) * 2.0
67 | qw = (R[0, 2] - R[2, 0]) / S
68 | qx = (R[0, 1] + R[1, 0]) / S
69 | qy = 0.25 * S
70 | qz = (R[1, 2] + R[2, 1]) / S
71 | else:
72 | S = torch.sqrt(1.0 + R[2, 2] - R[0, 0] - R[1, 1]) * 2.0
73 | qw = (R[1, 0] - R[0, 1]) / S
74 | qx = (R[0, 2] + R[2, 0]) / S
75 | qy = (R[1, 2] + R[2, 1]) / S
76 | qz = 0.25 * S
77 | return torch.stack([qw, qx, qy, qz], dim=1)
78 |
79 | def rotate_quaternions(q, R):
80 | # Convert quaternions to rotation matrices
81 | q = torch.cat([q[:, :1], -q[:, 1:]], dim=1)
82 | q = torch.cat([q[:, :3], q[:, 3:] * -1], dim=1)
83 | rotated_R = torch.matmul(torch.matmul(q, R), q.inverse())
84 |
85 | # Convert the rotated rotation matrices back to quaternions
86 | return rotation_matrix_to_quaternion(rotated_R)
87 |
88 | # this function is borrowed from OpenLRM
89 | class CosineWarmupScheduler(LRScheduler):
90 | def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1):
91 | self.warmup_iters = warmup_iters
92 | self.max_iters = max_iters
93 | self.initial_lr = initial_lr
94 | super().__init__(optimizer, last_iter)
95 |
96 | def get_lr(self):
97 |
98 | if self._step_count <= self.warmup_iters:
99 | return [
100 | self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters
101 | for base_lr in self.base_lrs]
102 | else:
103 | cos_iter = self._step_count - self.warmup_iters
104 | cos_max_iter = self.max_iters - self.warmup_iters
105 | cos_theta = cos_iter / cos_max_iter * math.pi
106 | cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs]
107 | return cos_lr
--------------------------------------------------------------------------------
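`CosineWarmupScheduler` ramps the learning rate linearly from `initial_lr` to the optimizer's base LR over `warmup_iters` steps, then decays it with a half-cosine toward zero at `max_iters`. A quick sketch that steps it against a dummy optimizer to show the shape:

```python
# Quick look at the warmup-then-cosine schedule produced by CosineWarmupScheduler.
import torch
from lightning.utils import CosineWarmupScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=4e-4)           # lr matches train.lr in base.yaml
scheduler = CosineWarmupScheduler(optimizer, warmup_iters=1000, max_iters=10000)

lrs = []
for _ in range(10000):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(lrs[0], max(lrs), lrs[-1])   # tiny at the start, ~4e-4 after warmup, ~0 at the end
```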
/lightning/vis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tools.img_utils import visualize_depth_numpy
4 |
5 |
6 |
7 | def vis_appearance_depth(output, batch):
8 | outputs = {}
9 | B, V, H, W = batch['tar_rgb'].shape[:-1]
10 |
11 | pred_rgb = output[f'image'].detach().cpu().numpy()
12 | pred_depth = output[f'depth'].detach().cpu().numpy()
13 | gt_rgb = batch[f'tar_rgb'].permute(0,2,1,3,4).reshape(B, H, V*W, 3).detach().cpu().numpy()
14 |
15 | near_far = batch['near_far'][0].tolist()
16 | pred_depth_colorlized = np.stack([visualize_depth_numpy(_depth, near_far) for _depth in pred_depth]).astype('float32')/255
17 | outputs.update({f"gt_rgb":gt_rgb, f"pred_rgb":pred_rgb, f"pred_depth":pred_depth_colorlized})
18 |
19 |
20 | if 'rend_normal' in output:
21 | rend_normal = torch.nn.functional.normalize(output[f'rend_normal'].detach(),dim=-1)
22 | rend_normal = rend_normal.cpu().numpy()
23 | outputs.update({f"rend_normal":(rend_normal+1)/2})
24 |
25 | depth_normal = output[f'depth_normal'].detach().cpu().numpy()
26 | outputs.update({f"depth_normal":(depth_normal+1)/2})
27 |
28 | if 'tar_nrm' in batch:
29 | normal_gt = batch['tar_nrm'].cpu().numpy()
30 | outputs.update({f"normal_gt":(normal_gt+1)/2})
31 |
32 |
33 | if 'img_tri' in output:
34 | img_tri = output['img_tri'].detach().cpu().permute(0,2,3,1).numpy()
35 | outputs.update({f"img_tri": img_tri})
36 | if 'feats_tri' in output:
37 | feats_tri = output['feats_tri'].detach().cpu().permute(0,2,3,1).numpy()
38 | outputs.update({f"feats_tri": feats_tri})
39 |
40 | if 'image_fine' in output:
41 | rgb_fine = output[f'image_fine'].detach().cpu().numpy()
42 | outputs.update({f"rgb_fine":rgb_fine})
43 |
44 | pred_depth_fine = output[f'depth_fine'].detach().cpu().numpy()
45 | pred_depth_fine_colorlized = np.stack([visualize_depth_numpy(_depth, near_far) for _depth in pred_depth_fine]).astype('float32')/255
46 | outputs.update({f"pred_depth_fine":pred_depth_fine_colorlized})
47 |
48 | if 'rend_normal_fine' in output:
49 | rend_normal_fine = torch.nn.functional.normalize(output[f'rend_normal_fine'].detach(),dim=-1)
50 | rend_normal_fine = rend_normal_fine.cpu().numpy()
51 | outputs.update({f"rend_normal_fine":(rend_normal_fine+1)/2})
52 |
53 | if 'depth_normal_fine' in output:
54 | depth_normal_fine = output[f'depth_normal_fine'].detach().cpu().numpy()
55 | outputs.update({f"depth_normal_fine":(depth_normal_fine+1)/2})
56 |
57 | return outputs
58 |
59 | def vis_depth(output, batch):
60 |
61 | outputs = {}
62 | B, S, _, H, W = batch['src_inps'].shape
63 | h, w = batch['src_deps'].shape[-2:]
64 |
65 | near_far = batch['near_far'][0].tolist()
66 | gt_src_depth = batch['src_deps'].reshape(B,-1, h, w).cpu().permute(0,2,1,3).numpy().reshape(B,h,-1)
67 | mask = gt_src_depth > 0
68 | pred_src_depth = output['pred_src_depth'].reshape(B,-1, h, w).detach().cpu().permute(0,2,1,3).numpy().reshape(B,h,-1)
69 | pred_src_depth[~mask] = 0.0
70 | depth_err = np.abs(gt_src_depth-pred_src_depth)*2
71 | gt_src_depth_colorlized = np.stack([visualize_depth_numpy(_depth, near_far) for _depth in gt_src_depth]).astype('float32')/255
72 | pred_src_depth_colorlized = np.stack([visualize_depth_numpy(_depth, near_far) for _depth in pred_src_depth]).astype('float32')/255
73 | depth_err_colorlized = np.stack([visualize_depth_numpy(_err, near_far) for _err in depth_err]).astype('float32')/255
74 | rgb_source = batch['src_inps'].reshape(B,S, 3, H, W).detach().cpu().permute(0,3,1,4,2).numpy().reshape(B,H,-1,3)
75 |
76 | outputs.update({f"rgb_source": rgb_source, "gt_src_depth": gt_src_depth_colorlized,
77 | "pred_src_depth":pred_src_depth_colorlized, "depth_err":depth_err_colorlized})
78 |
79 | return outputs
80 |
81 | def vis_images(output, batch):
82 | if 'image' in output:
83 | return vis_appearance_depth(output, batch)
84 | else:
85 | return vis_depth(output, batch)
86 |
--------------------------------------------------------------------------------
/third_party/image_generator/.github/workflows/black.yml:
--------------------------------------------------------------------------------
1 | name: Run black
2 | on: [pull_request]
3 |
4 | jobs:
5 | lint:
6 | runs-on: ubuntu-latest
7 | steps:
8 | - uses: actions/checkout@v3
9 | - name: Install venv
10 | run: |
11 | sudo apt-get -y install python3.10-venv
12 | - uses: psf/black@stable
13 | with:
14 | options: "--check --verbose -l88"
15 | src: "./sgm ./scripts ./main.py"
16 |
--------------------------------------------------------------------------------
/third_party/image_generator/.github/workflows/test-build.yaml:
--------------------------------------------------------------------------------
1 | name: Build package
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 |
8 | jobs:
9 | build:
10 | name: Build
11 | runs-on: ubuntu-latest
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | python-version: ["3.8", "3.10"]
16 | requirements-file: ["pt2", "pt13"]
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install -r requirements/${{ matrix.requirements-file }}.txt
27 | pip install .
--------------------------------------------------------------------------------
/third_party/image_generator/.github/workflows/test-inference.yml:
--------------------------------------------------------------------------------
1 | name: Test inference
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - main
8 |
9 | jobs:
10 | test:
11 | name: "Test inference"
12 | # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
13 | if: github.repository == 'stability-ai/generative-models'
14 | runs-on: [self-hosted, slurm, g40]
15 | steps:
16 | - uses: actions/checkout@v3
17 | - name: "Symlink checkpoints"
18 | run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints
19 | - name: "Setup python"
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: "3.10"
23 | - name: "Install Hatch"
24 | run: pip install hatch
25 | - name: "Run inference tests"
26 | run: hatch run ci:test-inference --junit-xml test-results.xml
27 | - name: Surface failing tests
28 | if: always()
29 | uses: pmeier/pytest-results-action@main
30 | with:
31 | path: test-results.xml
32 | summary: true
33 | display-options: fEX
34 | fail-on-empty: true
35 |
--------------------------------------------------------------------------------
/third_party/image_generator/.gitignore:
--------------------------------------------------------------------------------
1 | # extensions
2 | *.egg-info
3 | *.py[cod]
4 |
5 | # envs
6 | .pt13
7 | .pt2
8 |
9 | # directories
10 | /checkpoints
11 | /dist
12 | /outputs
13 | /build
14 | /src
--------------------------------------------------------------------------------
/third_party/image_generator/CODEOWNERS:
--------------------------------------------------------------------------------
1 | .github @Stability-AI/infrastructure
--------------------------------------------------------------------------------
/third_party/image_generator/LICENSE-CODE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Stability AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/third_party/image_generator/assets/000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/assets/000.jpg
--------------------------------------------------------------------------------
/third_party/image_generator/assets/sv3d.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/assets/sv3d.gif
--------------------------------------------------------------------------------
/third_party/image_generator/assets/tile.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/assets/tile.gif
--------------------------------------------------------------------------------
/third_party/image_generator/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: sgm.models.autoencoder.AutoencodingEngine
4 | params:
5 | input_key: jpg
6 | monitor: val/rec_loss
7 |
8 | loss_config:
9 | target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
10 | params:
11 | perceptual_weight: 0.25
12 | disc_start: 20001
13 | disc_weight: 0.5
14 | learn_logvar: True
15 |
16 | regularization_weights:
17 | kl_loss: 1.0
18 |
19 | regularizer_config:
20 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
21 |
22 | encoder_config:
23 | target: sgm.modules.diffusionmodules.model.Encoder
24 | params:
25 | attn_type: none
26 | double_z: True
27 | z_channels: 4
28 | resolution: 256
29 | in_channels: 3
30 | out_ch: 3
31 | ch: 128
32 | ch_mult: [1, 2, 4]
33 | num_res_blocks: 4
34 | attn_resolutions: []
35 | dropout: 0.0
36 |
37 | decoder_config:
38 | target: sgm.modules.diffusionmodules.model.Decoder
39 | params: ${model.params.encoder_config.params}
40 |
41 | data:
42 | target: sgm.data.dataset.StableDataModuleFromConfig
43 | params:
44 | train:
45 | datapipeline:
46 | urls:
47 | - DATA-PATH
48 | pipeline_config:
49 | shardshuffle: 10000
50 | sample_shuffle: 10000
51 |
52 | decoders:
53 | - pil
54 |
55 | postprocessors:
56 | - target: sdata.mappers.TorchVisionImageTransforms
57 | params:
58 | key: jpg
59 | transforms:
60 | - target: torchvision.transforms.Resize
61 | params:
62 | size: 256
63 | interpolation: 3
64 | - target: torchvision.transforms.ToTensor
65 | - target: sdata.mappers.Rescaler
66 | - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
67 | params:
68 | h_key: height
69 | w_key: width
70 |
71 | loader:
72 | batch_size: 8
73 | num_workers: 4
74 |
75 |
76 | lightning:
77 | strategy:
78 | target: pytorch_lightning.strategies.DDPStrategy
79 | params:
80 | find_unused_parameters: True
81 |
82 | modelcheckpoint:
83 | params:
84 | every_n_train_steps: 5000
85 |
86 | callbacks:
87 | metrics_over_trainsteps_checkpoint:
88 | params:
89 | every_n_train_steps: 50000
90 |
91 | image_logger:
92 | target: main.ImageLogger
93 | params:
94 | enable_autocast: False
95 | batch_frequency: 1000
96 | max_images: 8
97 | increase_log_steps: True
98 |
99 | trainer:
100 | devices: 0,
101 | limit_val_batches: 50
102 | benchmark: True
103 | accumulate_grad_batches: 1
104 | val_check_interval: 10000
--------------------------------------------------------------------------------
/third_party/image_generator/configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: sgm.models.autoencoder.AutoencodingEngine
4 | params:
5 | input_key: jpg
6 | monitor: val/loss/rec
7 | disc_start_iter: 0
8 |
9 | encoder_config:
10 | target: sgm.modules.diffusionmodules.model.Encoder
11 | params:
12 | attn_type: vanilla-xformers
13 | double_z: true
14 | z_channels: 8
15 | resolution: 256
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | num_res_blocks: 2
21 | attn_resolutions: []
22 | dropout: 0.0
23 |
24 | decoder_config:
25 | target: sgm.modules.diffusionmodules.model.Decoder
26 | params: ${model.params.encoder_config.params}
27 |
28 | regularizer_config:
29 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
30 |
31 | loss_config:
32 | target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
33 | params:
34 | perceptual_weight: 0.25
35 | disc_start: 20001
36 | disc_weight: 0.5
37 | learn_logvar: True
38 |
39 | regularization_weights:
40 | kl_loss: 1.0
41 |
42 | data:
43 | target: sgm.data.dataset.StableDataModuleFromConfig
44 | params:
45 | train:
46 | datapipeline:
47 | urls:
48 | - DATA-PATH
49 | pipeline_config:
50 | shardshuffle: 10000
51 | sample_shuffle: 10000
52 |
53 | decoders:
54 | - pil
55 |
56 | postprocessors:
57 | - target: sdata.mappers.TorchVisionImageTransforms
58 | params:
59 | key: jpg
60 | transforms:
61 | - target: torchvision.transforms.Resize
62 | params:
63 | size: 256
64 | interpolation: 3
65 | - target: torchvision.transforms.ToTensor
66 | - target: sdata.mappers.Rescaler
67 | - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
68 | params:
69 | h_key: height
70 | w_key: width
71 |
72 | loader:
73 | batch_size: 8
74 | num_workers: 4
75 |
76 |
77 | lightning:
78 | strategy:
79 | target: pytorch_lightning.strategies.DDPStrategy
80 | params:
81 | find_unused_parameters: True
82 |
83 | modelcheckpoint:
84 | params:
85 | every_n_train_steps: 5000
86 |
87 | callbacks:
88 | metrics_over_trainsteps_checkpoint:
89 | params:
90 | every_n_train_steps: 50000
91 |
92 | image_logger:
93 | target: main.ImageLogger
94 | params:
95 | enable_autocast: False
96 | batch_frequency: 1000
97 | max_images: 8
98 | increase_log_steps: True
99 |
100 | trainer:
101 | devices: 0,
102 | limit_val_batches: 50
103 | benchmark: True
104 | accumulate_grad_batches: 1
105 | val_check_interval: 10000
106 |
--------------------------------------------------------------------------------
/third_party/image_generator/configs/example_training/toy/cifar10_cond.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: sgm.models.diffusion.DiffusionEngine
4 | params:
5 | denoiser_config:
6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
7 | params:
8 | scaling_config:
9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10 | params:
11 | sigma_data: 1.0
12 |
13 | network_config:
14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15 | params:
16 | in_channels: 3
17 | out_channels: 3
18 | model_channels: 32
19 | attention_resolutions: []
20 | num_res_blocks: 4
21 | channel_mult: [1, 2, 2]
22 | num_head_channels: 32
23 | num_classes: sequential
24 | adm_in_channels: 128
25 |
26 | conditioner_config:
27 | target: sgm.modules.GeneralConditioner
28 | params:
29 | emb_models:
30 | - is_trainable: True
31 | input_key: cls
32 | ucg_rate: 0.2
33 | target: sgm.modules.encoders.modules.ClassEmbedder
34 | params:
35 | embed_dim: 128
36 | n_classes: 10
37 |
38 | first_stage_config:
39 | target: sgm.models.autoencoder.IdentityFirstStage
40 |
41 | loss_fn_config:
42 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43 | params:
44 | loss_weighting_config:
45 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
46 | params:
47 | sigma_data: 1.0
48 | sigma_sampler_config:
49 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
50 |
51 | sampler_config:
52 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
53 | params:
54 | num_steps: 50
55 |
56 | discretization_config:
57 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
58 |
59 | guider_config:
60 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG
61 | params:
62 | scale: 3.0
63 |
64 | data:
65 | target: sgm.data.cifar10.CIFAR10Loader
66 | params:
67 | batch_size: 512
68 | num_workers: 1
69 |
70 | lightning:
71 | modelcheckpoint:
72 | params:
73 | every_n_train_steps: 5000
74 |
75 | callbacks:
76 | metrics_over_trainsteps_checkpoint:
77 | params:
78 | every_n_train_steps: 25000
79 |
80 | image_logger:
81 | target: main.ImageLogger
82 | params:
83 | disabled: False
84 | batch_frequency: 1000
85 | max_images: 64
86 | increase_log_steps: True
87 | log_first_step: False
88 | log_images_kwargs:
89 | use_ema_scope: False
90 | N: 64
91 | n_rows: 8
92 |
93 | trainer:
94 | devices: 0,
95 | benchmark: True
96 | num_sanity_val_steps: 0
97 | accumulate_grad_batches: 1
98 | max_epochs: 20
--------------------------------------------------------------------------------
/third_party/image_generator/configs/example_training/toy/mnist.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: sgm.models.diffusion.DiffusionEngine
4 | params:
5 | denoiser_config:
6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
7 | params:
8 | scaling_config:
9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10 | params:
11 | sigma_data: 1.0
12 |
13 | network_config:
14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15 | params:
16 | in_channels: 1
17 | out_channels: 1
18 | model_channels: 32
19 | attention_resolutions: []
20 | num_res_blocks: 4
21 | channel_mult: [1, 2, 2]
22 | num_head_channels: 32
23 |
24 | first_stage_config:
25 | target: sgm.models.autoencoder.IdentityFirstStage
26 |
27 | loss_fn_config:
28 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
29 | params:
30 | loss_weighting_config:
31 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
32 | params:
33 | sigma_data: 1.0
34 | sigma_sampler_config:
35 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
36 |
37 | sampler_config:
38 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
39 | params:
40 | num_steps: 50
41 |
42 | discretization_config:
43 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
44 |
45 | data:
46 | target: sgm.data.mnist.MNISTLoader
47 | params:
48 | batch_size: 512
49 | num_workers: 1
50 |
51 | lightning:
52 | modelcheckpoint:
53 | params:
54 | every_n_train_steps: 5000
55 |
56 | callbacks:
57 | metrics_over_trainsteps_checkpoint:
58 | params:
59 | every_n_train_steps: 25000
60 |
61 | image_logger:
62 | target: main.ImageLogger
63 | params:
64 | disabled: False
65 | batch_frequency: 1000
66 | max_images: 64
67 | increase_log_steps: False
68 | log_first_step: False
69 | log_images_kwargs:
70 | use_ema_scope: False
71 | N: 64
72 | n_rows: 8
73 |
74 | trainer:
75 | devices: 0,
76 | benchmark: True
77 | num_sanity_val_steps: 0
78 | accumulate_grad_batches: 1
79 | max_epochs: 10
--------------------------------------------------------------------------------
/third_party/image_generator/configs/example_training/toy/mnist_cond.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: sgm.models.diffusion.DiffusionEngine
4 | params:
5 | denoiser_config:
6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
7 | params:
8 | scaling_config:
9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10 | params:
11 | sigma_data: 1.0
12 |
13 | network_config:
14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15 | params:
16 | in_channels: 1
17 | out_channels: 1
18 | model_channels: 32
19 | attention_resolutions: []
20 | num_res_blocks: 4
21 | channel_mult: [1, 2, 2]
22 | num_head_channels: 32
23 | num_classes: sequential
24 | adm_in_channels: 128
25 |
26 | conditioner_config:
27 | target: sgm.modules.GeneralConditioner
28 | params:
29 | emb_models:
30 | - is_trainable: True
31 | input_key: cls
32 | ucg_rate: 0.2
33 | target: sgm.modules.encoders.modules.ClassEmbedder
34 | params:
35 | embed_dim: 128
36 | n_classes: 10
37 |
38 | first_stage_config:
39 | target: sgm.models.autoencoder.IdentityFirstStage
40 |
41 | loss_fn_config:
42 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43 | params:
44 | loss_weighting_config:
45 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
46 | params:
47 | sigma_data: 1.0
48 | sigma_sampler_config:
49 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
50 |
51 | sampler_config:
52 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
53 | params:
54 | num_steps: 50
55 |
56 | discretization_config:
57 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
58 |
59 | guider_config:
60 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG
61 | params:
62 | scale: 3.0
63 |
64 | data:
65 | target: sgm.data.mnist.MNISTLoader
66 | params:
67 | batch_size: 512
68 | num_workers: 1
69 |
70 | lightning:
71 | modelcheckpoint:
72 | params:
73 | every_n_train_steps: 5000
74 |
75 | callbacks:
76 | metrics_over_trainsteps_checkpoint:
77 | params:
78 | every_n_train_steps: 25000
79 |
80 | image_logger:
81 | target: main.ImageLogger
82 | params:
83 | disabled: False
84 | batch_frequency: 1000
85 | max_images: 16
86 | increase_log_steps: True
87 | log_first_step: False
88 | log_images_kwargs:
89 | use_ema_scope: False
90 | N: 16
91 | n_rows: 4
92 |
93 | trainer:
94 | devices: 0,
95 | benchmark: True
96 | num_sanity_val_steps: 0
97 | accumulate_grad_batches: 1
98 | max_epochs: 20
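
Relative to mnist.yaml, this config adds class conditioning: `ucg_rate: 0.2` drops the label for roughly 20% of training samples, which is what enables classifier-free guidance, and the `VanillaCFG` guider with `scale: 3.0` blends the conditional and unconditional predictions at sampling time. A minimal sketch of that standard CFG combination (illustrative; the actual guider lives in sgm.modules.diffusionmodules.guiders):

    import torch


    def cfg_combine(pred_uncond: torch.Tensor, pred_cond: torch.Tensor, scale: float = 3.0) -> torch.Tensor:
        # Classifier-free guidance: move the prediction away from the
        # unconditional branch by a factor of `scale`.
        return pred_uncond + scale * (pred_cond - pred_uncond)


    pred_uncond = torch.randn(4, 1, 28, 28)  # stand-in unconditional denoiser output
    pred_cond = torch.randn(4, 1, 28, 28)    # stand-in class-conditional denoiser output
    guided = cfg_combine(pred_uncond, pred_cond, scale=3.0)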
--------------------------------------------------------------------------------
/third_party/image_generator/configs/example_training/toy/mnist_cond_discrete_eps.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: sgm.models.diffusion.DiffusionEngine
4 | params:
5 | denoiser_config:
6 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
7 | params:
8 | num_idx: 1000
9 |
10 | scaling_config:
11 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
12 | discretization_config:
13 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
14 |
15 | network_config:
16 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | in_channels: 1
19 | out_channels: 1
20 | model_channels: 32
21 | attention_resolutions: []
22 | num_res_blocks: 4
23 | channel_mult: [1, 2, 2]
24 | num_head_channels: 32
25 | num_classes: sequential
26 | adm_in_channels: 128
27 |
28 | conditioner_config:
29 | target: sgm.modules.GeneralConditioner
30 | params:
31 | emb_models:
32 | - is_trainable: True
33 | input_key: cls
34 | ucg_rate: 0.2
35 | target: sgm.modules.encoders.modules.ClassEmbedder
36 | params:
37 | embed_dim: 128
38 | n_classes: 10
39 |
40 | first_stage_config:
41 | target: sgm.models.autoencoder.IdentityFirstStage
42 |
43 | loss_fn_config:
44 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
45 | params:
46 | loss_weighting_config:
47 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
48 | sigma_sampler_config:
49 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
50 | params:
51 | num_idx: 1000
52 |
53 | discretization_config:
54 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
55 |
56 | sampler_config:
57 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
58 | params:
59 | num_steps: 50
60 |
61 | discretization_config:
62 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
63 |
64 | guider_config:
65 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG
66 | params:
67 | scale: 5.0
68 |
69 | data:
70 | target: sgm.data.mnist.MNISTLoader
71 | params:
72 | batch_size: 512
73 | num_workers: 1
74 |
75 | lightning:
76 | modelcheckpoint:
77 | params:
78 | every_n_train_steps: 5000
79 |
80 | callbacks:
81 | metrics_over_trainsteps_checkpoint:
82 | params:
83 | every_n_train_steps: 25000
84 |
85 | image_logger:
86 | target: main.ImageLogger
87 | params:
88 | disabled: False
89 | batch_frequency: 1000
90 | max_images: 16
91 | increase_log_steps: True
92 | log_first_step: False
93 | log_images_kwargs:
94 | use_ema_scope: False
95 | N: 16
96 | n_rows: 4
97 |
98 | trainer:
99 | devices: 0,
100 | benchmark: True
101 | num_sanity_val_steps: 0
102 | accumulate_grad_batches: 1
103 | max_epochs: 20
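
This variant trades the continuous EDM noise schedule for `num_idx: 1000` discrete DDPM steps via `LegacyDDPMDiscretization` and `DiscreteSampling`. A sketch of how DDPM-style sigmas are conventionally derived from a linear beta schedule follows; the schedule endpoints below are assumptions for illustration, and the real defaults live in sgm.modules.diffusionmodules.discretizer:

    import numpy as np


    def ddpm_sigmas(num_idx: int = 1000, linear_start: float = 0.00085, linear_end: float = 0.0120):
        # "Linear" schedule in sqrt(beta), as in the original latent-diffusion code.
        betas = np.linspace(linear_start**0.5, linear_end**0.5, num_idx, dtype=np.float64) ** 2
        alphas_cumprod = np.cumprod(1.0 - betas)
        # EDM-style sigma for each discrete timestep index.
        return np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)


    sigmas = ddpm_sigmas()
    print(sigmas[0], sigmas[-1])  # small sigma at index 0, large sigma at index 999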
--------------------------------------------------------------------------------
/third_party/image_generator/configs/example_training/toy/mnist_cond_l1_loss.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: sgm.models.diffusion.DiffusionEngine
4 | params:
5 | denoiser_config:
6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
7 | params:
8 | scaling_config:
9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10 | params:
11 | sigma_data: 1.0
12 |
13 | network_config:
14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15 | params:
16 | in_channels: 1
17 | out_channels: 1
18 | model_channels: 32
19 | attention_resolutions: []
20 | num_res_blocks: 4
21 | channel_mult: [1, 2, 2]
22 | num_head_channels: 32
23 | num_classes: sequential
24 | adm_in_channels: 128
25 |
26 | conditioner_config:
27 | target: sgm.modules.GeneralConditioner
28 | params:
29 | emb_models:
30 | - is_trainable: True
31 | input_key: cls
32 | ucg_rate: 0.2
33 | target: sgm.modules.encoders.modules.ClassEmbedder
34 | params:
35 | embed_dim: 128
36 | n_classes: 10
37 |
38 | first_stage_config:
39 | target: sgm.models.autoencoder.IdentityFirstStage
40 |
41 | loss_fn_config:
42 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43 | params:
44 | loss_type: l1
45 | loss_weighting_config:
46 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
47 | params:
48 | sigma_data: 1.0
49 | sigma_sampler_config:
50 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
51 |
52 | sampler_config:
53 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
54 | params:
55 | num_steps: 50
56 |
57 | discretization_config:
58 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
59 |
60 | guider_config:
61 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG
62 | params:
63 | scale: 3.0
64 |
65 | data:
66 | target: sgm.data.mnist.MNISTLoader
67 | params:
68 | batch_size: 512
69 | num_workers: 1
70 |
71 | lightning:
72 | modelcheckpoint:
73 | params:
74 | every_n_train_steps: 5000
75 |
76 | callbacks:
77 | metrics_over_trainsteps_checkpoint:
78 | params:
79 | every_n_train_steps: 25000
80 |
81 | image_logger:
82 | target: main.ImageLogger
83 | params:
84 | disabled: False
85 | batch_frequency: 1000
86 | max_images: 64
87 | increase_log_steps: True
88 | log_first_step: False
89 | log_images_kwargs:
90 | use_ema_scope: False
91 | N: 64
92 | n_rows: 8
93 |
94 | trainer:
95 | devices: 0,
96 | benchmark: True
97 | num_sanity_val_steps: 0
98 | accumulate_grad_batches: 1
99 | max_epochs: 20
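
The only change from mnist_cond.yaml is `loss_type: l1`: the per-pixel residual is penalized with an absolute value instead of a square before the sigma-dependent weighting is applied. Schematically (a sketch with stand-in tensors, not the exact StandardDiffusionLoss code):

    import torch

    pred = torch.randn(4, 1, 28, 28)    # stand-in denoiser output
    target = torch.randn(4, 1, 28, 28)  # stand-in training target
    w = torch.rand(4, 1, 1, 1)          # stand-in sigma-dependent weight

    loss_l2 = (w * (pred - target) ** 2).mean()   # default
    loss_l1 = (w * (pred - target).abs()).mean()  # loss_type: l1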
--------------------------------------------------------------------------------
/third_party/image_generator/configs/example_training/toy/mnist_cond_with_ema.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | target: sgm.models.diffusion.DiffusionEngine
4 | params:
5 | use_ema: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
9 | params:
10 | scaling_config:
11 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
12 | params:
13 | sigma_data: 1.0
14 |
15 | network_config:
16 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | in_channels: 1
19 | out_channels: 1
20 | model_channels: 32
21 | attention_resolutions: []
22 | num_res_blocks: 4
23 | channel_mult: [1, 2, 2]
24 | num_head_channels: 32
25 | num_classes: sequential
26 | adm_in_channels: 128
27 |
28 | conditioner_config:
29 | target: sgm.modules.GeneralConditioner
30 | params:
31 | emb_models:
32 | - is_trainable: True
33 | input_key: cls
34 | ucg_rate: 0.2
35 | target: sgm.modules.encoders.modules.ClassEmbedder
36 | params:
37 | embed_dim: 128
38 | n_classes: 10
39 |
40 | first_stage_config:
41 | target: sgm.models.autoencoder.IdentityFirstStage
42 |
43 | loss_fn_config:
44 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
45 | params:
46 | loss_weighting_config:
47 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
48 | params:
49 | sigma_data: 1.0
50 | sigma_sampler_config:
51 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
52 |
53 | sampler_config:
54 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
55 | params:
56 | num_steps: 50
57 |
58 | discretization_config:
59 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
60 |
61 | guider_config:
62 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG
63 | params:
64 | scale: 3.0
65 |
66 | data:
67 | target: sgm.data.mnist.MNISTLoader
68 | params:
69 | batch_size: 512
70 | num_workers: 1
71 |
72 | lightning:
73 | modelcheckpoint:
74 | params:
75 | every_n_train_steps: 5000
76 |
77 | callbacks:
78 | metrics_over_trainsteps_checkpoint:
79 | params:
80 | every_n_train_steps: 25000
81 |
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | disabled: False
86 | batch_frequency: 1000
87 | max_images: 64
88 | increase_log_steps: True
89 | log_first_step: False
90 | log_images_kwargs:
91 | use_ema_scope: False
92 | N: 64
93 | n_rows: 8
94 |
95 | trainer:
96 | devices: 0,
97 | benchmark: True
98 | num_sanity_val_steps: 0
99 | accumulate_grad_batches: 1
100 | max_epochs: 20
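
`use_ema: True` keeps an exponential moving average of the UNet weights next to the trained ones, and sampling can then run under the smoother EMA copy (cf. `use_ema_scope` in the image logger). The update rule assumed here is the standard one:

    import torch


    @torch.no_grad()
    def ema_update(ema_params, model_params, decay: float = 0.9999):
        # shadow <- decay * shadow + (1 - decay) * current
        for shadow, current in zip(ema_params, model_params):
            shadow.mul_(decay).add_(current, alpha=1.0 - decay)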
--------------------------------------------------------------------------------
/third_party/image_generator/configs/inference/sd_2_1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9 | params:
10 | num_idx: 1000
11 |
12 | scaling_config:
13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14 | discretization_config:
15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16 |
17 | network_config:
18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | use_checkpoint: True
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 320
24 | attention_resolutions: [4, 2, 1]
25 | num_res_blocks: 2
26 | channel_mult: [1, 2, 4, 4]
27 | num_head_channels: 64
28 | use_linear_in_transformer: True
29 | transformer_depth: 1
30 | context_dim: 1024
31 |
32 | conditioner_config:
33 | target: sgm.modules.GeneralConditioner
34 | params:
35 | emb_models:
36 | - is_trainable: False
37 | input_key: txt
38 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
39 | params:
40 | freeze: true
41 | layer: penultimate
42 |
43 | first_stage_config:
44 | target: sgm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: val/rec_loss
48 | ddconfig:
49 | double_z: true
50 | z_channels: 4
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult: [1, 2, 4, 4]
56 | num_res_blocks: 2
57 | attn_resolutions: []
58 | dropout: 0.0
59 | lossconfig:
60 | target: torch.nn.Identity
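
From here on the configs describe latent diffusion models: the UNet operates on the 4-channel latents of the AutoencoderKL first stage, and `scale_factor: 0.18215` rescales those latents to roughly unit variance. A sketch of the convention (the DiffusionEngine applies this internally; the helper names here are illustrative):

    import torch

    SCALE_FACTOR = 0.18215

    def to_model_space(latent: torch.Tensor) -> torch.Tensor:
        # Encoder output -> the scale the diffusion UNet expects.
        return SCALE_FACTOR * latent

    def to_decoder_space(z: torch.Tensor) -> torch.Tensor:
        # Undo the scaling before handing z back to the first-stage decoder.
        return z / SCALE_FACTOR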
--------------------------------------------------------------------------------
/third_party/image_generator/configs/inference/sd_2_1_768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9 | params:
10 | num_idx: 1000
11 |
12 | scaling_config:
13 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
14 | discretization_config:
15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16 |
17 | network_config:
18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | use_checkpoint: True
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 320
24 | attention_resolutions: [4, 2, 1]
25 | num_res_blocks: 2
26 | channel_mult: [1, 2, 4, 4]
27 | num_head_channels: 64
28 | use_linear_in_transformer: True
29 | transformer_depth: 1
30 | context_dim: 1024
31 |
32 | conditioner_config:
33 | target: sgm.modules.GeneralConditioner
34 | params:
35 | emb_models:
36 | - is_trainable: False
37 | input_key: txt
38 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
39 | params:
40 | freeze: true
41 | layer: penultimate
42 |
43 | first_stage_config:
44 | target: sgm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: val/rec_loss
48 | ddconfig:
49 | double_z: true
50 | z_channels: 4
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult: [1, 2, 4, 4]
56 | num_res_blocks: 2
57 | attn_resolutions: []
58 | dropout: 0.0
59 | lossconfig:
60 | target: torch.nn.Identity
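
The only difference from sd_2_1.yaml is `VScaling` instead of `EpsScaling`: the 768-px SD 2.1 checkpoint is a v-prediction model. A sketch of the two preconditionings in the usual `denoised = c_skip * x + c_out * model(c_in * x, c_noise)` form (illustrative; the exact expressions live in sgm.modules.diffusionmodules.denoiser_scaling):

    import torch


    def eps_scaling(sigma: torch.Tensor):
        c_skip = torch.ones_like(sigma)
        c_out = -sigma
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        return c_skip, c_out, c_in


    def v_scaling(sigma: torch.Tensor):
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        return c_skip, c_out, c_in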
--------------------------------------------------------------------------------
/third_party/image_generator/configs/inference/sd_xl_base.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.13025
5 | disable_first_stage_autocast: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9 | params:
10 | num_idx: 1000
11 |
12 | scaling_config:
13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14 | discretization_config:
15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16 |
17 | network_config:
18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | adm_in_channels: 2816
21 | num_classes: sequential
22 | use_checkpoint: True
23 | in_channels: 4
24 | out_channels: 4
25 | model_channels: 320
26 | attention_resolutions: [4, 2]
27 | num_res_blocks: 2
28 | channel_mult: [1, 2, 4]
29 | num_head_channels: 64
30 | use_linear_in_transformer: True
31 | transformer_depth: [1, 2, 10]
32 | context_dim: 2048
33 | spatial_transformer_attn_type: softmax-xformers
34 |
35 | conditioner_config:
36 | target: sgm.modules.GeneralConditioner
37 | params:
38 | emb_models:
39 | - is_trainable: False
40 | input_key: txt
41 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
42 | params:
43 | layer: hidden
44 | layer_idx: 11
45 |
46 | - is_trainable: False
47 | input_key: txt
48 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
49 | params:
50 | arch: ViT-bigG-14
51 | version: laion2b_s39b_b160k
52 | freeze: True
53 | layer: penultimate
54 | always_return_pooled: True
55 | legacy: False
56 |
57 | - is_trainable: False
58 | input_key: original_size_as_tuple
59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60 | params:
61 | outdim: 256
62 |
63 | - is_trainable: False
64 | input_key: crop_coords_top_left
65 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
66 | params:
67 | outdim: 256
68 |
69 | - is_trainable: False
70 | input_key: target_size_as_tuple
71 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
72 | params:
73 | outdim: 256
74 |
75 | first_stage_config:
76 | target: sgm.models.autoencoder.AutoencoderKL
77 | params:
78 | embed_dim: 4
79 | monitor: val/rec_loss
80 | ddconfig:
81 | attn_type: vanilla-xformers
82 | double_z: true
83 | z_channels: 4
84 | resolution: 256
85 | in_channels: 3
86 | out_ch: 3
87 | ch: 128
88 | ch_mult: [1, 2, 4, 4]
89 | num_res_blocks: 2
90 | attn_resolutions: []
91 | dropout: 0.0
92 | lossconfig:
93 | target: torch.nn.Identity
94 |
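
The conditioner makes `adm_in_channels: 2816` easy to verify: the pooled ViT-bigG text embedding contributes 1280 features, and each 2-component size/crop tuple goes through a `ConcatTimestepEmbedderND` with `outdim: 256`. The refiner config that follows uses the same arithmetic with a scalar `aesthetic_score` in place of `target_size_as_tuple`. A quick consistency check (the 1280-dim pooled embedding is our reading of the ViT-bigG text width):

    # SDXL base: pooled text emb + (original_size, crop_coords, target_size) micro-conditions
    assert 1280 + 3 * 2 * 256 == 2816
    # SDXL refiner: pooled text emb + (original_size, crop_coords, aesthetic_score)
    assert 1280 + (2 + 2 + 1) * 256 == 2560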
--------------------------------------------------------------------------------
/third_party/image_generator/configs/inference/sd_xl_refiner.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.13025
5 | disable_first_stage_autocast: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9 | params:
10 | num_idx: 1000
11 |
12 | scaling_config:
13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14 | discretization_config:
15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16 |
17 | network_config:
18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | adm_in_channels: 2560
21 | num_classes: sequential
22 | use_checkpoint: True
23 | in_channels: 4
24 | out_channels: 4
25 | model_channels: 384
26 | attention_resolutions: [4, 2]
27 | num_res_blocks: 2
28 | channel_mult: [1, 2, 4, 4]
29 | num_head_channels: 64
30 | use_linear_in_transformer: True
31 | transformer_depth: 4
32 | context_dim: [1280, 1280, 1280, 1280]
33 | spatial_transformer_attn_type: softmax-xformers
34 |
35 | conditioner_config:
36 | target: sgm.modules.GeneralConditioner
37 | params:
38 | emb_models:
39 | - is_trainable: False
40 | input_key: txt
41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
42 | params:
43 | arch: ViT-bigG-14
44 | version: laion2b_s39b_b160k
45 | legacy: False
46 | freeze: True
47 | layer: penultimate
48 | always_return_pooled: True
49 |
50 | - is_trainable: False
51 | input_key: original_size_as_tuple
52 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
53 | params:
54 | outdim: 256
55 |
56 | - is_trainable: False
57 | input_key: crop_coords_top_left
58 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
59 | params:
60 | outdim: 256
61 |
62 | - is_trainable: False
63 | input_key: aesthetic_score
64 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
65 | params:
66 | outdim: 256
67 |
68 | first_stage_config:
69 | target: sgm.models.autoencoder.AutoencoderKL
70 | params:
71 | embed_dim: 4
72 | monitor: val/rec_loss
73 | ddconfig:
74 | attn_type: vanilla-xformers
75 | double_z: true
76 | z_channels: 4
77 | resolution: 256
78 | in_channels: 3
79 | out_ch: 3
80 | ch: 128
81 | ch_mult: [1, 2, 4, 4]
82 | num_res_blocks: 2
83 | attn_resolutions: []
84 | dropout: 0.0
85 | lossconfig:
86 | target: torch.nn.Identity
87 |
--------------------------------------------------------------------------------
/third_party/image_generator/configs/inference/sv3d_p.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
9 | params:
10 | scaling_config:
11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
12 |
13 | network_config:
14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet
15 | params:
16 | adm_in_channels: 1280
17 | num_classes: sequential
18 | use_checkpoint: True
19 | in_channels: 8
20 | out_channels: 4
21 | model_channels: 320
22 | attention_resolutions: [4, 2, 1]
23 | num_res_blocks: 2
24 | channel_mult: [1, 2, 4, 4]
25 | num_head_channels: 64
26 | use_linear_in_transformer: True
27 | transformer_depth: 1
28 | context_dim: 1024
29 | spatial_transformer_attn_type: softmax-xformers
30 | extra_ff_mix_layer: True
31 | use_spatial_context: True
32 | merge_strategy: learned_with_images
33 | video_kernel_size: [3, 1, 1]
34 |
35 | conditioner_config:
36 | target: sgm.modules.GeneralConditioner
37 | params:
38 | emb_models:
39 | - input_key: cond_frames_without_noise
40 | is_trainable: False
41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
42 | params:
43 | n_cond_frames: 1
44 | n_copies: 1
45 | open_clip_embedding_config:
46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
47 | params:
48 | freeze: True
49 |
50 | - input_key: cond_frames
51 | is_trainable: False
52 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
53 | params:
54 | disable_encoder_autocast: True
55 | n_cond_frames: 1
56 | n_copies: 1
57 | is_ae: True
58 | encoder_config:
59 | target: sgm.models.autoencoder.AutoencoderKLModeOnly
60 | params:
61 | embed_dim: 4
62 | monitor: val/rec_loss
63 | ddconfig:
64 | attn_type: vanilla-xformers
65 | double_z: True
66 | z_channels: 4
67 | resolution: 256
68 | in_channels: 3
69 | out_ch: 3
70 | ch: 128
71 | ch_mult: [1, 2, 4, 4]
72 | num_res_blocks: 2
73 | attn_resolutions: []
74 | dropout: 0.0
75 | lossconfig:
76 | target: torch.nn.Identity
77 |
78 | - input_key: cond_aug
79 | is_trainable: False
80 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
81 | params:
82 | outdim: 256
83 |
84 | - input_key: polars_rad
85 | is_trainable: False
86 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
87 | params:
88 | outdim: 512
89 |
90 | - input_key: azimuths_rad
91 | is_trainable: False
92 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
93 | params:
94 | outdim: 512
95 |
96 | first_stage_config:
97 | target: sgm.models.autoencoder.AutoencodingEngine
98 | params:
99 | loss_config:
100 | target: torch.nn.Identity
101 | regularizer_config:
102 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
103 | encoder_config:
104 | target: torch.nn.Identity
105 | decoder_config:
106 | target: sgm.modules.diffusionmodules.model.Decoder
107 | params:
108 | attn_type: vanilla-xformers
109 | double_z: True
110 | z_channels: 4
111 | resolution: 256
112 | in_channels: 3
113 | out_ch: 3
114 | ch: 128
115 | ch_mult: [ 1, 2, 4, 4 ]
116 | num_res_blocks: 2
117 | attn_resolutions: [ ]
118 | dropout: 0.0
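
Two numbers in this SV3D config are worth decoding. `in_channels: 8` follows the SVD convention of channel-concatenating the noisy 4-channel video latent with the 4-channel latent of the noise-augmented conditioning frame (the `cond_frames` embedder). `adm_in_channels: 1280` matches the vector conditioners: 256 for `cond_aug` plus 512 each for the per-frame `polars_rad` and `azimuths_rad` camera angles; the pose-free sv3d_u config below drops the angles, hence its `adm_in_channels: 256`. As a quick check:

    assert 4 + 4 == 8               # noisy latent + conditioning-frame latent (concat)
    assert 256 + 512 + 512 == 1280  # cond_aug + polars_rad + azimuths_rad (sv3d_p)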
--------------------------------------------------------------------------------
/third_party/image_generator/configs/inference/sv3d_u.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
9 | params:
10 | scaling_config:
11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
12 |
13 | network_config:
14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet
15 | params:
16 | adm_in_channels: 256
17 | num_classes: sequential
18 | use_checkpoint: True
19 | in_channels: 8
20 | out_channels: 4
21 | model_channels: 320
22 | attention_resolutions: [4, 2, 1]
23 | num_res_blocks: 2
24 | channel_mult: [1, 2, 4, 4]
25 | num_head_channels: 64
26 | use_linear_in_transformer: True
27 | transformer_depth: 1
28 | context_dim: 1024
29 | spatial_transformer_attn_type: softmax-xformers
30 | extra_ff_mix_layer: True
31 | use_spatial_context: True
32 | merge_strategy: learned_with_images
33 | video_kernel_size: [3, 1, 1]
34 |
35 | conditioner_config:
36 | target: sgm.modules.GeneralConditioner
37 | params:
38 | emb_models:
39 | - input_key: cond_frames_without_noise
40 | is_trainable: False
41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
42 | params:
43 | n_cond_frames: 1
44 | n_copies: 1
45 | open_clip_embedding_config:
46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
47 | params:
48 | freeze: True
49 |
50 | - input_key: cond_frames
51 | is_trainable: False
52 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
53 | params:
54 | disable_encoder_autocast: True
55 | n_cond_frames: 1
56 | n_copies: 1
57 | is_ae: True
58 | encoder_config:
59 | target: sgm.models.autoencoder.AutoencoderKLModeOnly
60 | params:
61 | embed_dim: 4
62 | monitor: val/rec_loss
63 | ddconfig:
64 | attn_type: vanilla-xformers
65 | double_z: True
66 | z_channels: 4
67 | resolution: 256
68 | in_channels: 3
69 | out_ch: 3
70 | ch: 128
71 | ch_mult: [1, 2, 4, 4]
72 | num_res_blocks: 2
73 | attn_resolutions: []
74 | dropout: 0.0
75 | lossconfig:
76 | target: torch.nn.Identity
77 |
78 | - input_key: cond_aug
79 | is_trainable: False
80 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
81 | params:
82 | outdim: 256
83 |
84 | first_stage_config:
85 | target: sgm.models.autoencoder.AutoencodingEngine
86 | params:
87 | loss_config:
88 | target: torch.nn.Identity
89 | regularizer_config:
90 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
91 | encoder_config:
92 | target: torch.nn.Identity
93 | decoder_config:
94 | target: sgm.modules.diffusionmodules.model.Decoder
95 | params:
96 | attn_type: vanilla-xformers
97 | double_z: True
98 | z_channels: 4
99 | resolution: 256
100 | in_channels: 3
101 | out_ch: 3
102 | ch: 128
103 | ch_mult: [ 1, 2, 4, 4 ]
104 | num_res_blocks: 2
105 | attn_resolutions: [ ]
106 | dropout: 0.0
--------------------------------------------------------------------------------
/third_party/image_generator/configs/inference/svd.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
9 | params:
10 | scaling_config:
11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
12 |
13 | network_config:
14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet
15 | params:
16 | adm_in_channels: 768
17 | num_classes: sequential
18 | use_checkpoint: True
19 | in_channels: 8
20 | out_channels: 4
21 | model_channels: 320
22 | attention_resolutions: [4, 2, 1]
23 | num_res_blocks: 2
24 | channel_mult: [1, 2, 4, 4]
25 | num_head_channels: 64
26 | use_linear_in_transformer: True
27 | transformer_depth: 1
28 | context_dim: 1024
29 | spatial_transformer_attn_type: softmax-xformers
30 | extra_ff_mix_layer: True
31 | use_spatial_context: True
32 | merge_strategy: learned_with_images
33 | video_kernel_size: [3, 1, 1]
34 |
35 | conditioner_config:
36 | target: sgm.modules.GeneralConditioner
37 | params:
38 | emb_models:
39 | - is_trainable: False
40 | input_key: cond_frames_without_noise
41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
42 | params:
43 | n_cond_frames: 1
44 | n_copies: 1
45 | open_clip_embedding_config:
46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
47 | params:
48 | freeze: True
49 |
50 | - input_key: fps_id
51 | is_trainable: False
52 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
53 | params:
54 | outdim: 256
55 |
56 | - input_key: motion_bucket_id
57 | is_trainable: False
58 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
59 | params:
60 | outdim: 256
61 |
62 | - input_key: cond_frames
63 | is_trainable: False
64 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
65 | params:
66 | disable_encoder_autocast: True
67 | n_cond_frames: 1
68 | n_copies: 1
69 | is_ae: True
70 | encoder_config:
71 | target: sgm.models.autoencoder.AutoencoderKLModeOnly
72 | params:
73 | embed_dim: 4
74 | monitor: val/rec_loss
75 | ddconfig:
76 | attn_type: vanilla-xformers
77 | double_z: True
78 | z_channels: 4
79 | resolution: 256
80 | in_channels: 3
81 | out_ch: 3
82 | ch: 128
83 | ch_mult: [1, 2, 4, 4]
84 | num_res_blocks: 2
85 | attn_resolutions: []
86 | dropout: 0.0
87 | lossconfig:
88 | target: torch.nn.Identity
89 |
90 | - input_key: cond_aug
91 | is_trainable: False
92 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
93 | params:
94 | outdim: 256
95 |
96 | first_stage_config:
97 | target: sgm.models.autoencoder.AutoencodingEngine
98 | params:
99 | loss_config:
100 | target: torch.nn.Identity
101 | regularizer_config:
102 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
103 | encoder_config:
104 | target: sgm.modules.diffusionmodules.model.Encoder
105 | params:
106 | attn_type: vanilla
107 | double_z: True
108 | z_channels: 4
109 | resolution: 256
110 | in_channels: 3
111 | out_ch: 3
112 | ch: 128
113 | ch_mult: [1, 2, 4, 4]
114 | num_res_blocks: 2
115 | attn_resolutions: []
116 | dropout: 0.0
117 | decoder_config:
118 | target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
119 | params:
120 | attn_type: vanilla
121 | double_z: True
122 | z_channels: 4
123 | resolution: 256
124 | in_channels: 3
125 | out_ch: 3
126 | ch: 128
127 | ch_mult: [1, 2, 4, 4]
128 | num_res_blocks: 2
129 | attn_resolutions: []
130 | dropout: 0.0
131 | video_kernel_size: [3, 1, 1]
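
For SVD the vector conditioning is `fps_id`, `motion_bucket_id` and `cond_aug`, each embedded to 256 features, which is where `adm_in_channels: 768` comes from; the first stage pairs an image encoder with a temporal `VideoDecoder` (the svd_image_decoder.yaml variant below keeps a plain AutoencoderKL image decoder instead). A sketch of the value dict such a conditioner consumes, with made-up shapes and values (the real batch is assembled by the sampling scripts):

    import torch

    image = torch.randn(1, 3, 576, 1024)  # stand-in conditioning frame
    cond_aug = 0.02

    value_dict = {
        "cond_frames_without_noise": image,
        "cond_frames": image + cond_aug * torch.randn_like(image),  # noise-augmented copy
        "fps_id": torch.tensor([5]),             # frame-rate conditioning
        "motion_bucket_id": torch.tensor([127]),
        "cond_aug": torch.tensor([cond_aug]),
    }
    assert 3 * 256 == 768  # three ConcatTimestepEmbedderND outputs -> adm_in_channels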
--------------------------------------------------------------------------------
/third_party/image_generator/configs/inference/svd_image_decoder.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
9 | params:
10 | scaling_config:
11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
12 |
13 | network_config:
14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet
15 | params:
16 | adm_in_channels: 768
17 | num_classes: sequential
18 | use_checkpoint: True
19 | in_channels: 8
20 | out_channels: 4
21 | model_channels: 320
22 | attention_resolutions: [4, 2, 1]
23 | num_res_blocks: 2
24 | channel_mult: [1, 2, 4, 4]
25 | num_head_channels: 64
26 | use_linear_in_transformer: True
27 | transformer_depth: 1
28 | context_dim: 1024
29 | spatial_transformer_attn_type: softmax-xformers
30 | extra_ff_mix_layer: True
31 | use_spatial_context: True
32 | merge_strategy: learned_with_images
33 | video_kernel_size: [3, 1, 1]
34 |
35 | conditioner_config:
36 | target: sgm.modules.GeneralConditioner
37 | params:
38 | emb_models:
39 | - is_trainable: False
40 | input_key: cond_frames_without_noise
41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
42 | params:
43 | n_cond_frames: 1
44 | n_copies: 1
45 | open_clip_embedding_config:
46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
47 | params:
48 | freeze: True
49 |
50 | - input_key: fps_id
51 | is_trainable: False
52 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
53 | params:
54 | outdim: 256
55 |
56 | - input_key: motion_bucket_id
57 | is_trainable: False
58 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
59 | params:
60 | outdim: 256
61 |
62 | - input_key: cond_frames
63 | is_trainable: False
64 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
65 | params:
66 | disable_encoder_autocast: True
67 | n_cond_frames: 1
68 | n_copies: 1
69 | is_ae: True
70 | encoder_config:
71 | target: sgm.models.autoencoder.AutoencoderKLModeOnly
72 | params:
73 | embed_dim: 4
74 | monitor: val/rec_loss
75 | ddconfig:
76 | attn_type: vanilla-xformers
77 | double_z: True
78 | z_channels: 4
79 | resolution: 256
80 | in_channels: 3
81 | out_ch: 3
82 | ch: 128
83 | ch_mult: [1, 2, 4, 4]
84 | num_res_blocks: 2
85 | attn_resolutions: []
86 | dropout: 0.0
87 | lossconfig:
88 | target: torch.nn.Identity
89 |
90 | - input_key: cond_aug
91 | is_trainable: False
92 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
93 | params:
94 | outdim: 256
95 |
96 | first_stage_config:
97 | target: sgm.models.autoencoder.AutoencoderKL
98 | params:
99 | embed_dim: 4
100 | monitor: val/rec_loss
101 | ddconfig:
102 | attn_type: vanilla-xformers
103 | double_z: True
104 | z_channels: 4
105 | resolution: 256
106 | in_channels: 3
107 | out_ch: 3
108 | ch: 128
109 | ch_mult: [1, 2, 4, 4]
110 | num_res_blocks: 2
111 | attn_resolutions: []
112 | dropout: 0.0
113 | lossconfig:
114 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/third_party/image_generator/configs/sd_xl_base.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.13025
5 | disable_first_stage_autocast: True
6 |
7 | denoiser_config:
8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9 | params:
10 | num_idx: 1000
11 |
12 | scaling_config:
13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14 | discretization_config:
15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16 |
17 | network_config:
18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | adm_in_channels: 2816
21 | num_classes: sequential
22 | use_checkpoint: True
23 | in_channels: 4
24 | out_channels: 4
25 | model_channels: 320
26 | attention_resolutions: [4, 2]
27 | num_res_blocks: 2
28 | channel_mult: [1, 2, 4]
29 | num_head_channels: 64
30 | use_linear_in_transformer: True
31 | transformer_depth: [1, 2, 10]
32 | context_dim: 2048
33 | spatial_transformer_attn_type: softmax-xformers
34 |
35 | conditioner_config:
36 | target: sgm.modules.GeneralConditioner
37 | params:
38 | emb_models:
39 | - is_trainable: False
40 | input_key: txt
41 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
42 | params:
43 | layer: hidden
44 | layer_idx: 11
45 |
46 | - is_trainable: False
47 | input_key: txt
48 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
49 | params:
50 | arch: ViT-bigG-14
51 | version: laion2b_s39b_b160k
52 | freeze: True
53 | layer: penultimate
54 | always_return_pooled: True
55 | legacy: False
56 |
57 | - is_trainable: False
58 | input_key: original_size_as_tuple
59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60 | params:
61 | outdim: 256
62 |
63 | - is_trainable: False
64 | input_key: crop_coords_top_left
65 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
66 | params:
67 | outdim: 256
68 |
69 | - is_trainable: False
70 | input_key: target_size_as_tuple
71 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
72 | params:
73 | outdim: 256
74 |
75 | first_stage_config:
76 | target: sgm.models.autoencoder.AutoencoderKL
77 | params:
78 | embed_dim: 4
79 | monitor: val/rec_loss
80 | ddconfig:
81 | attn_type: vanilla-xformers
82 | double_z: true
83 | z_channels: 4
84 | resolution: 256
85 | in_channels: 3
86 | out_ch: 3
87 | ch: 128
88 | ch_mult: [1, 2, 4, 4]
89 | num_res_blocks: 2
90 | attn_resolutions: []
91 | dropout: 0.0
92 | lossconfig:
93 | target: torch.nn.Identity
94 |
--------------------------------------------------------------------------------
/third_party/image_generator/data/DejaVuSans.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/data/DejaVuSans.ttf
--------------------------------------------------------------------------------
/third_party/image_generator/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "sgm"
7 | dynamic = ["version"]
8 | description = "Stability Generative Models"
9 | readme = "README.md"
10 | license-files = { paths = ["LICENSE-CODE"] }
11 | requires-python = ">=3.8"
12 |
13 | [project.urls]
14 | Homepage = "https://github.com/Stability-AI/generative-models"
15 |
16 | [tool.hatch.version]
17 | path = "sgm/__init__.py"
18 |
19 | [tool.hatch.build]
20 | # This needs to be explicitly set so the configuration files
21 | # grafted into the `sgm` directory get included in the wheel's
22 | # RECORD file.
23 | include = [
24 | "sgm",
25 | ]
 26 | # The force-include configurations below make Hatch copy
 27 | # the configs/ directory (containing the YAML files required by the
 28 | # generative models) into the source distribution and the wheel.
29 |
30 | [tool.hatch.build.targets.sdist.force-include]
31 | "./configs" = "sgm/configs"
32 |
33 | [tool.hatch.build.targets.wheel.force-include]
34 | "./configs" = "sgm/configs"
35 |
36 | [tool.hatch.envs.ci]
37 | skip-install = false
38 |
39 | dependencies = [
40 | "pytest"
41 | ]
42 |
43 | [tool.hatch.envs.ci.scripts]
44 | test-inference = [
45 | "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
46 | "pip install -r requirements/pt2.txt",
47 | "pytest -v tests/inference/test_inference.py {args}",
48 | ]
49 |
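
For a local run, the `ci` environment's script above can be invoked as `hatch run ci:test-inference`; the `inference` marker registered in pytest.ini below can be deselected with `pytest -m "not inference"` when the GPU-bound tests should be skipped.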
--------------------------------------------------------------------------------
/third_party/image_generator/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | markers =
3 | inference: mark as inference test (deselect with '-m "not inference"')
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/demo/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/demo/discretization.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from third_party.image_generator.sgm.modules.diffusionmodules.discretizer import Discretization
4 |
5 |
6 | class Img2ImgDiscretizationWrapper:
7 | """
8 | wraps a discretizer, and prunes the sigmas
9 | params:
10 | strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
11 | """
12 |
13 | def __init__(self, discretization: Discretization, strength: float = 1.0):
14 | self.discretization = discretization
15 | self.strength = strength
16 | assert 0.0 <= self.strength <= 1.0
17 |
18 | def __call__(self, *args, **kwargs):
 19 |         # sigmas start large and then decrease
 20 |         sigmas = self.discretization(*args, **kwargs)
 21 |         print("sigmas after discretization, before pruning img2img:", sigmas)
 22 |         sigmas = torch.flip(sigmas, (0,))
 23 |         print("prune index:", max(int(self.strength * len(sigmas)), 1))
 24 |         sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
 25 |         sigmas = torch.flip(sigmas, (0,))
 26 |         print("sigmas after pruning:", sigmas)
27 | return sigmas
28 |
29 |
30 | class Txt2NoisyDiscretizationWrapper:
31 | """
32 | wraps a discretizer, and prunes the sigmas
33 | params:
34 | strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
35 | """
36 |
37 | def __init__(
38 | self, discretization: Discretization, strength: float = 0.0, original_steps=None
39 | ):
40 | self.discretization = discretization
41 | self.strength = strength
42 | self.original_steps = original_steps
43 | assert 0.0 <= self.strength <= 1.0
44 |
45 | def __call__(self, *args, **kwargs):
 46 |         # sigmas start large and then decrease
 47 |         sigmas = self.discretization(*args, **kwargs)
 48 |         print("sigmas after discretization, before pruning:", sigmas)
49 | sigmas = torch.flip(sigmas, (0,))
50 | if self.original_steps is None:
51 | steps = len(sigmas)
52 | else:
53 | steps = self.original_steps + 1
54 | prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
55 | sigmas = sigmas[prune_index:]
56 | print("prune index:", prune_index)
57 | sigmas = torch.flip(sigmas, (0,))
 58 |         print("sigmas after pruning:", sigmas)
59 | return sigmas
60 |
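
A usage sketch for the wrapper above, assuming the repository root is on PYTHONPATH so the sgm package resolves the way the imports at the top of this file expect. With `strength=0.4`, only the lowest 40% of the noise levels are kept, which is how an img2img run at strength 0.4 shortens the schedule:

    from third_party.image_generator.scripts.demo.discretization import (
        Img2ImgDiscretizationWrapper,
    )
    from third_party.image_generator.sgm.modules.diffusionmodules.discretizer import (
        EDMDiscretization,
    )

    disc = Img2ImgDiscretizationWrapper(EDMDiscretization(), strength=0.4)
    sigmas = disc(10)  # 10 steps requested; the wrapper prints and returns the pruned tail
    print(len(sigmas))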
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/demo/sv3d_helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 |
6 |
7 | def generate_dynamic_cycle_xy_values(
8 | length=21,
9 | init_elev=0,
10 | num_components=84,
11 | frequency_range=(1, 5),
12 | amplitude_range=(0.5, 10),
13 | step_range=(0, 2),
14 | ):
15 | # Y values generation
16 | y_sequence = np.ones(length) * init_elev
17 | for _ in range(num_components):
18 | # Choose a frequency that will complete whole cycles in the sequence
19 | frequency = np.random.randint(*frequency_range) * (2 * np.pi / length)
20 | amplitude = np.random.uniform(*amplitude_range)
21 | phase_shift = np.random.choice([0, np.pi]) # np.random.uniform(0, 2 * np.pi)
22 | angles = (
23 | np.linspace(0, frequency * length, length, endpoint=False) + phase_shift
24 | )
25 | y_sequence += np.sin(angles) * amplitude
26 | # X values generation
27 | # Generate length - 1 steps since the last step is back to start
28 | steps = np.random.uniform(*step_range, length - 1)
29 | total_step_sum = np.sum(steps)
30 | # Calculate the scale factor to scale total steps to just under 360
31 | scale_factor = (
32 | 360 - ((360 / length) * np.random.uniform(*step_range))
33 | ) / total_step_sum
34 | # Apply the scale factor and generate the sequence of X values
35 | x_values = np.cumsum(steps * scale_factor)
36 | # Ensure the sequence starts at 0 and add the final step to complete the loop
37 | x_values = np.insert(x_values, 0, 0)
38 | return x_values, y_sequence
39 |
40 |
41 | def smooth_data(data, window_size):
42 | # Extend data at both ends by wrapping around to create a continuous loop
43 | pad_size = window_size
44 | padded_data = np.concatenate((data[-pad_size:], data, data[:pad_size]))
45 |
46 | # Apply smoothing
47 | kernel = np.ones(window_size) / window_size
48 | smoothed_data = np.convolve(padded_data, kernel, mode="same")
49 |
50 | # Extract the smoothed data corresponding to the original sequence
51 | # Adjust the indices to account for the larger padding
52 | start_index = pad_size
53 | end_index = -pad_size if pad_size != 0 else None
54 | smoothed_original_data = smoothed_data[start_index:end_index]
55 | return smoothed_original_data
56 |
57 |
58 | # Function to generate and process the data
59 | def gen_dynamic_loop(length=21, elev_deg=0):
60 | while True:
61 | # Generate the combined X and Y values using the new function
62 | azim_values, elev_values = generate_dynamic_cycle_xy_values(
63 | length=84, init_elev=elev_deg
64 | )
65 | # Smooth the Y values directly
66 | smoothed_elev_values = smooth_data(elev_values, 5)
67 | max_magnitude = np.max(np.abs(smoothed_elev_values))
68 | if max_magnitude < 90:
69 | break
70 | subsample = 84 // length
71 | azim_rad = np.deg2rad(azim_values[::subsample])
72 | elev_rad = np.deg2rad(smoothed_elev_values[::subsample])
73 | # Make cond frame the last one
74 | return np.roll(azim_rad, -1), np.roll(elev_rad, -1)
75 |
76 |
77 | def plot_3D(azim, polar, save_path, dynamic=True):
78 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
79 | elev = np.deg2rad(90) - polar
80 | fig = plt.figure(figsize=(5, 5))
81 | ax = fig.add_subplot(projection="3d")
82 | cm = plt.get_cmap("Greys")
83 | col_line = [cm(i) for i in np.linspace(0.3, 1, len(azim) + 1)]
84 | cm = plt.get_cmap("cool")
85 | col = [cm(float(i) / (len(azim))) for i in np.arange(len(azim))]
86 | xs = np.cos(elev) * np.cos(azim)
87 | ys = np.cos(elev) * np.sin(azim)
88 | zs = np.sin(elev)
89 | ax.scatter(xs[0], ys[0], zs[0], s=100, color=col[0])
90 | xs_d, ys_d, zs_d = (xs[1:] - xs[:-1]), (ys[1:] - ys[:-1]), (zs[1:] - zs[:-1])
91 | for i in range(len(xs) - 1):
92 | if dynamic:
93 | ax.quiver(
94 | xs[i], ys[i], zs[i], xs_d[i], ys_d[i], zs_d[i], lw=2, color=col_line[i]
95 | )
96 | else:
97 | ax.plot(xs[i : i + 2], ys[i : i + 2], zs[i : i + 2], lw=2, c=col_line[i])
98 | ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1])
99 | ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k")
100 | ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k")
101 | ax.view_init(elev=30, azim=-20, roll=0)
102 | plt.savefig(save_path, bbox_inches="tight")
103 | plt.clf()
104 | plt.close()
105 |
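
A small driver for the helpers above: generate a 21-frame dynamic camera orbit and save a 3D visualization of it (import path and output file are illustrative; the repository root is assumed to be on PYTHONPATH):

    import numpy as np

    from third_party.image_generator.scripts.demo.sv3d_helpers import gen_dynamic_loop, plot_3D

    azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=10)
    polar_rad = np.deg2rad(90) - elev_rad  # plot_3D expects polar, not elevation, angles
    plot_3D(azim_rad, polar_rad, save_path="outputs/sv3d_orbit.png", dynamic=True)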
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/sampling/configs/sv3d_p.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 | ckpt_path: checkpoints/sv3d_p.safetensors
7 |
8 | denoiser_config:
9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
10 | params:
11 | scaling_config:
12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13 |
14 | network_config:
15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet
16 | params:
17 | adm_in_channels: 1280
18 | num_classes: sequential
19 | use_checkpoint: True
20 | in_channels: 8
21 | out_channels: 4
22 | model_channels: 320
23 | attention_resolutions: [4, 2, 1]
24 | num_res_blocks: 2
25 | channel_mult: [1, 2, 4, 4]
26 | num_head_channels: 64
27 | use_linear_in_transformer: True
28 | transformer_depth: 1
29 | context_dim: 1024
30 | spatial_transformer_attn_type: softmax-xformers
31 | extra_ff_mix_layer: True
32 | use_spatial_context: True
33 | merge_strategy: learned_with_images
34 | video_kernel_size: [3, 1, 1]
35 |
36 | conditioner_config:
37 | target: sgm.modules.GeneralConditioner
38 | params:
39 | emb_models:
40 | - input_key: cond_frames_without_noise
41 | is_trainable: False
42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43 | params:
44 | n_cond_frames: 1
45 | n_copies: 1
46 | open_clip_embedding_config:
47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48 | params:
49 | freeze: True
50 |
51 | - input_key: cond_frames
52 | is_trainable: False
53 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
54 | params:
55 | disable_encoder_autocast: True
56 | n_cond_frames: 1
57 | n_copies: 1
58 | is_ae: True
59 | encoder_config:
60 | target: sgm.models.autoencoder.AutoencoderKLModeOnly
61 | params:
62 | embed_dim: 4
63 | monitor: val/rec_loss
64 | ddconfig:
65 | attn_type: vanilla-xformers
66 | double_z: True
67 | z_channels: 4
68 | resolution: 256
69 | in_channels: 3
70 | out_ch: 3
71 | ch: 128
72 | ch_mult: [1, 2, 4, 4]
73 | num_res_blocks: 2
74 | attn_resolutions: []
75 | dropout: 0.0
76 | lossconfig:
77 | target: torch.nn.Identity
78 |
79 | - input_key: cond_aug
80 | is_trainable: False
81 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
82 | params:
83 | outdim: 256
84 |
85 | - input_key: polars_rad
86 | is_trainable: False
87 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
88 | params:
89 | outdim: 512
90 |
91 | - input_key: azimuths_rad
92 | is_trainable: False
93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94 | params:
95 | outdim: 512
96 |
97 | first_stage_config:
98 | target: sgm.models.autoencoder.AutoencodingEngine
99 | params:
100 | loss_config:
101 | target: torch.nn.Identity
102 | regularizer_config:
103 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
104 | encoder_config:
105 | target: torch.nn.Identity
106 | decoder_config:
107 | target: sgm.modules.diffusionmodules.model.Decoder
108 | params:
109 | attn_type: vanilla-xformers
110 | double_z: True
111 | z_channels: 4
112 | resolution: 256
113 | in_channels: 3
114 | out_ch: 3
115 | ch: 128
116 | ch_mult: [ 1, 2, 4, 4 ]
117 | num_res_blocks: 2
118 | attn_resolutions: [ ]
119 | dropout: 0.0
120 |
121 | sampler_config:
122 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
123 | params:
124 | discretization_config:
125 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
126 | params:
127 | sigma_max: 700.0
128 |
129 | guider_config:
130 | target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
131 | params:
132 | max_scale: 2.5
133 |
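
Compared with the inference config of the same name, this sampling config adds a checkpoint path, an explicit EDM discretization with `sigma_max: 700.0`, and a frame-aware guider. Reading the parameter names: `TrianglePredictionGuider` ramps the guidance scale up towards `max_scale` for views far from the conditioning frame and back down again, while the `LinearPredictionGuider` in the SVD sampling configs further below ramps linearly from `min_scale` to `max_scale` across the frames. A sketch of such frame-wise scales (illustrative only; the exact shapes live in sgm.modules.diffusionmodules.guiders):

    import torch

    num_frames = 21

    # Linear ramp, as suggested by LinearPredictionGuider's min_scale/max_scale params.
    linear_scales = torch.linspace(1.0, 2.5, num_frames)

    # Triangle ramp: 1.0 at the conditioning view, peaking at max_scale half-way around.
    rise = torch.linspace(1.0, 2.5, num_frames // 2 + 1)
    triangle_scales = torch.cat([rise, rise.flip(0)[1 : num_frames - rise.numel() + 1]])
    assert triangle_scales.numel() == num_frames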
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/sampling/configs/sv3d_u.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 | ckpt_path: checkpoints/sv3d_u.safetensors
7 |
8 | denoiser_config:
9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
10 | params:
11 | scaling_config:
12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13 |
14 | network_config:
15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet
16 | params:
17 | adm_in_channels: 256
18 | num_classes: sequential
19 | use_checkpoint: True
20 | in_channels: 8
21 | out_channels: 4
22 | model_channels: 320
23 | attention_resolutions: [4, 2, 1]
24 | num_res_blocks: 2
25 | channel_mult: [1, 2, 4, 4]
26 | num_head_channels: 64
27 | use_linear_in_transformer: True
28 | transformer_depth: 1
29 | context_dim: 1024
30 | spatial_transformer_attn_type: softmax-xformers
31 | extra_ff_mix_layer: True
32 | use_spatial_context: True
33 | merge_strategy: learned_with_images
34 | video_kernel_size: [3, 1, 1]
35 |
36 | conditioner_config:
37 | target: sgm.modules.GeneralConditioner
38 | params:
39 | emb_models:
40 | - is_trainable: False
41 | input_key: cond_frames_without_noise
42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43 | params:
44 | n_cond_frames: 1
45 | n_copies: 1
46 | open_clip_embedding_config:
47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48 | params:
49 | freeze: True
50 |
51 | - input_key: cond_frames
52 | is_trainable: False
53 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
54 | params:
55 | disable_encoder_autocast: True
56 | n_cond_frames: 1
57 | n_copies: 1
58 | is_ae: True
59 | encoder_config:
60 | target: sgm.models.autoencoder.AutoencoderKLModeOnly
61 | params:
62 | embed_dim: 4
63 | monitor: val/rec_loss
64 | ddconfig:
65 | attn_type: vanilla-xformers
66 | double_z: True
67 | z_channels: 4
68 | resolution: 256
69 | in_channels: 3
70 | out_ch: 3
71 | ch: 128
72 | ch_mult: [1, 2, 4, 4]
73 | num_res_blocks: 2
74 | attn_resolutions: []
75 | dropout: 0.0
76 | lossconfig:
77 | target: torch.nn.Identity
78 |
79 | - input_key: cond_aug
80 | is_trainable: False
81 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
82 | params:
83 | outdim: 256
84 |
85 | first_stage_config:
86 | target: sgm.models.autoencoder.AutoencodingEngine
87 | params:
88 | loss_config:
89 | target: torch.nn.Identity
90 | regularizer_config:
91 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
92 | encoder_config:
93 | target: torch.nn.Identity
94 | decoder_config:
95 | target: sgm.modules.diffusionmodules.model.Decoder
96 | params:
97 | attn_type: vanilla-xformers
98 | double_z: True
99 | z_channels: 4
100 | resolution: 256
101 | in_channels: 3
102 | out_ch: 3
103 | ch: 128
104 | ch_mult: [ 1, 2, 4, 4 ]
105 | num_res_blocks: 2
106 | attn_resolutions: [ ]
107 | dropout: 0.0
108 |
109 | sampler_config:
110 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
111 | params:
112 | discretization_config:
113 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
114 | params:
115 | sigma_max: 700.0
116 |
117 | guider_config:
118 | target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
119 | params:
120 | max_scale: 2.5
121 |
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/sampling/configs/svd_image_decoder.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 | ckpt_path: checkpoints/svd_image_decoder.safetensors
7 |
8 | denoiser_config:
9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
10 | params:
11 | scaling_config:
12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13 |
14 | network_config:
15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet
16 | params:
17 | adm_in_channels: 768
18 | num_classes: sequential
19 | use_checkpoint: True
20 | in_channels: 8
21 | out_channels: 4
22 | model_channels: 320
23 | attention_resolutions: [4, 2, 1]
24 | num_res_blocks: 2
25 | channel_mult: [1, 2, 4, 4]
26 | num_head_channels: 64
27 | use_linear_in_transformer: True
28 | transformer_depth: 1
29 | context_dim: 1024
30 | spatial_transformer_attn_type: softmax-xformers
31 | extra_ff_mix_layer: True
32 | use_spatial_context: True
33 | merge_strategy: learned_with_images
34 | video_kernel_size: [3, 1, 1]
35 |
36 | conditioner_config:
37 | target: sgm.modules.GeneralConditioner
38 | params:
39 | emb_models:
40 | - is_trainable: False
41 | input_key: cond_frames_without_noise
42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43 | params:
44 | n_cond_frames: 1
45 | n_copies: 1
46 | open_clip_embedding_config:
47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48 | params:
49 | freeze: True
50 |
51 | - input_key: fps_id
52 | is_trainable: False
53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54 | params:
55 | outdim: 256
56 |
57 | - input_key: motion_bucket_id
58 | is_trainable: False
59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60 | params:
61 | outdim: 256
62 |
63 | - input_key: cond_frames
64 | is_trainable: False
65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66 | params:
67 | disable_encoder_autocast: True
68 | n_cond_frames: 1
69 | n_copies: 1
70 | is_ae: True
71 | encoder_config:
72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly
73 | params:
74 | embed_dim: 4
75 | monitor: val/rec_loss
76 | ddconfig:
77 | attn_type: vanilla-xformers
78 | double_z: True
79 | z_channels: 4
80 | resolution: 256
81 | in_channels: 3
82 | out_ch: 3
83 | ch: 128
84 | ch_mult: [1, 2, 4, 4]
85 | num_res_blocks: 2
86 | attn_resolutions: []
87 | dropout: 0.0
88 | lossconfig:
89 | target: torch.nn.Identity
90 |
91 | - input_key: cond_aug
92 | is_trainable: False
93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94 | params:
95 | outdim: 256
96 |
97 | first_stage_config:
98 | target: sgm.models.autoencoder.AutoencoderKL
99 | params:
100 | embed_dim: 4
101 | monitor: val/rec_loss
102 | ddconfig:
103 | attn_type: vanilla-xformers
104 | double_z: True
105 | z_channels: 4
106 | resolution: 256
107 | in_channels: 3
108 | out_ch: 3
109 | ch: 128
110 | ch_mult: [1, 2, 4, 4]
111 | num_res_blocks: 2
112 | attn_resolutions: []
113 | dropout: 0.0
114 | lossconfig:
115 | target: torch.nn.Identity
116 |
117 | sampler_config:
118 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
119 | params:
120 | discretization_config:
121 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
122 | params:
123 | sigma_max: 700.0
124 |
125 | guider_config:
126 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
127 | params:
128 | max_scale: 2.5
129 | min_scale: 1.0
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/sampling/configs/svd_xt_image_decoder.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: sgm.models.diffusion.DiffusionEngine
3 | params:
4 | scale_factor: 0.18215
5 | disable_first_stage_autocast: True
6 | ckpt_path: checkpoints/svd_xt_image_decoder.safetensors
7 |
8 | denoiser_config:
9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser
10 | params:
11 | scaling_config:
12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13 |
14 | network_config:
15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet
16 | params:
17 | adm_in_channels: 768
18 | num_classes: sequential
19 | use_checkpoint: True
20 | in_channels: 8
21 | out_channels: 4
22 | model_channels: 320
23 | attention_resolutions: [4, 2, 1]
24 | num_res_blocks: 2
25 | channel_mult: [1, 2, 4, 4]
26 | num_head_channels: 64
27 | use_linear_in_transformer: True
28 | transformer_depth: 1
29 | context_dim: 1024
30 | spatial_transformer_attn_type: softmax-xformers
31 | extra_ff_mix_layer: True
32 | use_spatial_context: True
33 | merge_strategy: learned_with_images
34 | video_kernel_size: [3, 1, 1]
35 |
36 | conditioner_config:
37 | target: sgm.modules.GeneralConditioner
38 | params:
39 | emb_models:
40 | - is_trainable: False
41 | input_key: cond_frames_without_noise
42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43 | params:
44 | n_cond_frames: 1
45 | n_copies: 1
46 | open_clip_embedding_config:
47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48 | params:
49 | freeze: True
50 |
51 | - input_key: fps_id
52 | is_trainable: False
53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54 | params:
55 | outdim: 256
56 |
57 | - input_key: motion_bucket_id
58 | is_trainable: False
59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60 | params:
61 | outdim: 256
62 |
63 | - input_key: cond_frames
64 | is_trainable: False
65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66 | params:
67 | disable_encoder_autocast: True
68 | n_cond_frames: 1
69 | n_copies: 1
70 | is_ae: True
71 | encoder_config:
72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly
73 | params:
74 | embed_dim: 4
75 | monitor: val/rec_loss
76 | ddconfig:
77 | attn_type: vanilla-xformers
78 | double_z: True
79 | z_channels: 4
80 | resolution: 256
81 | in_channels: 3
82 | out_ch: 3
83 | ch: 128
84 | ch_mult: [1, 2, 4, 4]
85 | num_res_blocks: 2
86 | attn_resolutions: []
87 | dropout: 0.0
88 | lossconfig:
89 | target: torch.nn.Identity
90 |
91 | - input_key: cond_aug
92 | is_trainable: False
93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94 | params:
95 | outdim: 256
96 |
97 | first_stage_config:
98 | target: sgm.models.autoencoder.AutoencoderKL
99 | params:
100 | embed_dim: 4
101 | monitor: val/rec_loss
102 | ddconfig:
103 | attn_type: vanilla-xformers
104 | double_z: True
105 | z_channels: 4
106 | resolution: 256
107 | in_channels: 3
108 | out_ch: 3
109 | ch: 128
110 | ch_mult: [1, 2, 4, 4]
111 | num_res_blocks: 2
112 | attn_resolutions: []
113 | dropout: 0.0
114 | lossconfig:
115 | target: torch.nn.Identity
116 |
117 | sampler_config:
118 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
119 | params:
120 | discretization_config:
121 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
122 | params:
123 | sigma_max: 700.0
124 |
125 | guider_config:
126 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
127 | params:
128 | max_scale: 3.0
129 | min_scale: 1.5
--------------------------------------------------------------------------------
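
The `sampler_config` in both SVD sampling configs above pairs an Euler EDM sampler (`sigma_max: 700.0`) with a `LinearPredictionGuider`, which ramps the classifier-free guidance scale linearly across the generated frames instead of using a single global scale. A minimal sketch of the per-frame scales this implies, assuming 25 frames (the frame count is supplied at sampling time, not stored in the YAML):

import torch

# min_scale applies to the first frame, max_scale to the last; the values come from
# the svd_xt_image_decoder.yaml guider_config above. 25 frames is an assumption for
# illustration only.
min_scale, max_scale, num_frames = 1.5, 3.0, 25
per_frame_scale = torch.linspace(min_scale, max_scale, num_frames)
print(per_frame_scale[0].item(), per_frame_scale[-1].item())  # 1.5 3.0
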
/third_party/image_generator/scripts/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/util/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/util/detection/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/util/detection/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/util/detection/nsfw_and_watermark_dectection.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import clip
4 | import numpy as np
5 | import torch
6 | import torchvision.transforms as T
7 | from PIL import Image
8 |
9 | RESOURCES_ROOT = "dataLoader/image_generator/scripts/util/detection/"
10 |
11 |
12 | def predict_proba(X, weights, biases):
13 | logits = X @ weights.T + biases
14 | proba = np.where(
15 | logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
16 | )
17 | return proba.T
18 |
19 |
20 | def load_model_weights(path: str):
21 | model_weights = np.load(path)
22 | return model_weights["weights"], model_weights["biases"]
23 |
24 |
25 | def clip_process_images(images: torch.Tensor) -> torch.Tensor:
26 | min_size = min(images.shape[-2:])
27 | return T.Compose(
28 | [
29 | T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
30 | T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
31 | T.Normalize(
32 | (0.48145466, 0.4578275, 0.40821073),
33 | (0.26862954, 0.26130258, 0.27577711),
34 | ),
35 | ]
36 | )(images)
37 |
38 |
39 | class DeepFloydDataFiltering(object):
40 | def __init__(
41 | self, verbose: bool = False, device: torch.device = torch.device("cpu")
42 | ):
43 | super().__init__()
44 | self.verbose = verbose
45 | self._device = None
46 | self.clip_model, _ = clip.load("ViT-L/14", device=device)
47 | self.clip_model.eval()
48 |
49 | self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
50 | os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
51 | )
52 | self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
53 | os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
54 | )
55 | self.w_threshold, self.p_threshold = 0.5, 0.5
56 |
57 | @torch.inference_mode()
58 | def __call__(self, images: torch.Tensor) -> torch.Tensor:
59 | imgs = clip_process_images(images)
60 | if self._device is None:
61 | self._device = next(p for p in self.clip_model.parameters()).device
62 | image_features = self.clip_model.encode_image(imgs.to(self._device))
63 | image_features = image_features.detach().cpu().numpy().astype(np.float16)
64 | p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
65 | w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
66 | print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
67 | query = p_pred > self.p_threshold
68 | if query.sum() > 0:
69 | print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
70 | images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
71 | query = w_pred > self.w_threshold
72 | if query.sum() > 0:
73 | print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
74 | images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
75 | return images
76 |
77 |
78 | def load_img(path: str) -> torch.Tensor:
79 | image = Image.open(path)
80 | if not image.mode == "RGB":
81 | image = image.convert("RGB")
82 | image_transforms = T.Compose(
83 | [
84 | T.ToTensor(),
85 | ]
86 | )
87 | return image_transforms(image)[None, ...]
88 |
89 |
90 | def test(root):
91 | from einops import rearrange
92 |
93 | filter = DeepFloydDataFiltering(verbose=True)
94 |     for p in os.listdir(root):
95 | print(f"running on {p}...")
96 | img = load_img(os.path.join(root, p))
97 | filtered_img = filter(img)
98 | filtered_img = rearrange(
99 | 255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
100 | ).astype(np.uint8)
101 | Image.fromarray(filtered_img).save(
102 | os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
103 | )
104 |
105 |
106 | if __name__ == "__main__":
107 | import fire
108 |
109 | fire.Fire(test)
110 | print("done.")
111 |
--------------------------------------------------------------------------------
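
A minimal sketch of applying the filter above to one image, mirroring the `test()` helper: it assumes the `clip` package is installed, the two `.npz` heads are reachable under `RESOURCES_ROOT`, the repository root is on `PYTHONPATH`, and `example.jpg` is a hypothetical input path.

from PIL import Image

# Hypothetical dotted import path; adjust to however the repository root is exposed on PYTHONPATH.
from third_party.image_generator.scripts.util.detection.nsfw_and_watermark_dectection import (
    DeepFloydDataFiltering,
    load_img,
)

data_filter = DeepFloydDataFiltering(verbose=True)  # loads CLIP ViT-L/14 plus the two logistic heads
img = load_img("example.jpg")                       # (1, 3, H, W) float tensor in [0, 1]
filtered = data_filter(img)                         # images that trip either 0.5 threshold come back blurred
out = (255.0 * filtered[0].numpy()).transpose(1, 2, 0).astype("uint8")
Image.fromarray(out).save("example-filtered.jpg")
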
/third_party/image_generator/scripts/util/detection/p_head_v1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/util/detection/p_head_v1.npz
--------------------------------------------------------------------------------
/third_party/image_generator/scripts/util/detection/w_head_v1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/util/detection/w_head_v1.npz
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import AutoencodingEngine, DiffusionEngine
2 | from .util import get_configs_path, instantiate_from_config
3 |
4 | __version__ = "0.1.0"
5 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import StableDataModuleFromConfig
2 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/data/cifar10.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torchvision
3 | from torch.utils.data import DataLoader, Dataset
4 | from torchvision import transforms
5 |
6 |
7 | class CIFAR10DataDictWrapper(Dataset):
8 | def __init__(self, dset):
9 | super().__init__()
10 | self.dset = dset
11 |
12 | def __getitem__(self, i):
13 | x, y = self.dset[i]
14 | return {"jpg": x, "cls": y}
15 |
16 | def __len__(self):
17 | return len(self.dset)
18 |
19 |
20 | class CIFAR10Loader(pl.LightningDataModule):
21 | def __init__(self, batch_size, num_workers=0, shuffle=True):
22 | super().__init__()
23 |
24 | transform = transforms.Compose(
25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26 | )
27 |
28 | self.batch_size = batch_size
29 | self.num_workers = num_workers
30 | self.shuffle = shuffle
31 | self.train_dataset = CIFAR10DataDictWrapper(
32 | torchvision.datasets.CIFAR10(
33 | root=".data/", train=True, download=True, transform=transform
34 | )
35 | )
36 | self.test_dataset = CIFAR10DataDictWrapper(
37 | torchvision.datasets.CIFAR10(
38 | root=".data/", train=False, download=True, transform=transform
39 | )
40 | )
41 |
42 | def prepare_data(self):
43 | pass
44 |
45 | def train_dataloader(self):
46 | return DataLoader(
47 | self.train_dataset,
48 | batch_size=self.batch_size,
49 | shuffle=self.shuffle,
50 | num_workers=self.num_workers,
51 | )
52 |
53 | def test_dataloader(self):
54 | return DataLoader(
55 | self.test_dataset,
56 | batch_size=self.batch_size,
57 | shuffle=self.shuffle,
58 | num_workers=self.num_workers,
59 | )
60 |
61 | def val_dataloader(self):
62 | return DataLoader(
63 | self.test_dataset,
64 | batch_size=self.batch_size,
65 | shuffle=self.shuffle,
66 | num_workers=self.num_workers,
67 | )
68 |
--------------------------------------------------------------------------------
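
A minimal sketch of what the data module above yields: dict batches with images rescaled to [-1, 1] under the key "jpg" and integer labels under "cls" (the dotted import path assumes the repository root is on `PYTHONPATH`).

from third_party.image_generator.sgm.data.cifar10 import CIFAR10Loader

dm = CIFAR10Loader(batch_size=16, num_workers=0)             # downloads CIFAR-10 into .data/ on first use
batch = next(iter(dm.train_dataloader()))
print(batch["jpg"].shape)                                    # torch.Size([16, 3, 32, 32])
print(batch["jpg"].min().item(), batch["jpg"].max().item())  # roughly -1.0 and 1.0
print(batch["cls"][:4])                                      # integer class labels
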
/third_party/image_generator/sgm/data/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torchdata.datapipes.iter
4 | import webdataset as wds
5 | from omegaconf import DictConfig
6 | from pytorch_lightning import LightningDataModule
7 |
8 | try:
9 | from sdata import create_dataset, create_dummy_dataset, create_loader
10 | except ImportError as e:
11 | print("#" * 100)
12 | print("Datasets not yet available")
13 | print("to enable, we need to add stable-datasets as a submodule")
14 | print("please use ``git submodule update --init --recursive``")
15 | print("and do ``pip install -e stable-datasets/`` from the root of this repo")
16 | print("#" * 100)
17 | exit(1)
18 |
19 |
20 | class StableDataModuleFromConfig(LightningDataModule):
21 | def __init__(
22 | self,
23 | train: DictConfig,
24 | validation: Optional[DictConfig] = None,
25 | test: Optional[DictConfig] = None,
26 | skip_val_loader: bool = False,
27 | dummy: bool = False,
28 | ):
29 | super().__init__()
30 | self.train_config = train
31 | assert (
32 | "datapipeline" in self.train_config and "loader" in self.train_config
33 | ), "train config requires the fields `datapipeline` and `loader`"
34 |
35 | self.val_config = validation
36 | if not skip_val_loader:
37 | if self.val_config is not None:
38 | assert (
39 | "datapipeline" in self.val_config and "loader" in self.val_config
40 | ), "validation config requires the fields `datapipeline` and `loader`"
41 | else:
42 | print(
43 |                 "Warning: no validation datapipeline defined, using the training datapipeline"
44 | )
45 | self.val_config = train
46 |
47 | self.test_config = test
48 | if self.test_config is not None:
49 | assert (
50 | "datapipeline" in self.test_config and "loader" in self.test_config
51 | ), "test config requires the fields `datapipeline` and `loader`"
52 |
53 | self.dummy = dummy
54 | if self.dummy:
55 | print("#" * 100)
56 | print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
57 | print("#" * 100)
58 |
59 | def setup(self, stage: str) -> None:
60 | print("Preparing datasets")
61 | if self.dummy:
62 | data_fn = create_dummy_dataset
63 | else:
64 | data_fn = create_dataset
65 |
66 | self.train_datapipeline = data_fn(**self.train_config.datapipeline)
67 | if self.val_config:
68 | self.val_datapipeline = data_fn(**self.val_config.datapipeline)
69 | if self.test_config:
70 | self.test_datapipeline = data_fn(**self.test_config.datapipeline)
71 |
72 | def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
73 | loader = create_loader(self.train_datapipeline, **self.train_config.loader)
74 | return loader
75 |
76 | def val_dataloader(self) -> wds.DataPipeline:
77 | return create_loader(self.val_datapipeline, **self.val_config.loader)
78 |
79 | def test_dataloader(self) -> wds.DataPipeline:
80 | return create_loader(self.test_datapipeline, **self.test_config.loader)
81 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/data/mnist.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torchvision
3 | from torch.utils.data import DataLoader, Dataset
4 | from torchvision import transforms
5 |
6 |
7 | class MNISTDataDictWrapper(Dataset):
8 | def __init__(self, dset):
9 | super().__init__()
10 | self.dset = dset
11 |
12 | def __getitem__(self, i):
13 | x, y = self.dset[i]
14 | return {"jpg": x, "cls": y}
15 |
16 | def __len__(self):
17 | return len(self.dset)
18 |
19 |
20 | class MNISTLoader(pl.LightningDataModule):
21 | def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
22 | super().__init__()
23 |
24 | transform = transforms.Compose(
25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26 | )
27 |
28 | self.batch_size = batch_size
29 | self.num_workers = num_workers
30 |         self.prefetch_factor = prefetch_factor if num_workers > 0 else None  # recent PyTorch expects None (not 0) when num_workers == 0
31 | self.shuffle = shuffle
32 | self.train_dataset = MNISTDataDictWrapper(
33 | torchvision.datasets.MNIST(
34 | root=".data/", train=True, download=True, transform=transform
35 | )
36 | )
37 | self.test_dataset = MNISTDataDictWrapper(
38 | torchvision.datasets.MNIST(
39 | root=".data/", train=False, download=True, transform=transform
40 | )
41 | )
42 |
43 | def prepare_data(self):
44 | pass
45 |
46 | def train_dataloader(self):
47 | return DataLoader(
48 | self.train_dataset,
49 | batch_size=self.batch_size,
50 | shuffle=self.shuffle,
51 | num_workers=self.num_workers,
52 | prefetch_factor=self.prefetch_factor,
53 | )
54 |
55 | def test_dataloader(self):
56 | return DataLoader(
57 | self.test_dataset,
58 | batch_size=self.batch_size,
59 | shuffle=self.shuffle,
60 | num_workers=self.num_workers,
61 | prefetch_factor=self.prefetch_factor,
62 | )
63 |
64 | def val_dataloader(self):
65 | return DataLoader(
66 | self.test_dataset,
67 | batch_size=self.batch_size,
68 | shuffle=self.shuffle,
69 | num_workers=self.num_workers,
70 | prefetch_factor=self.prefetch_factor,
71 | )
72 |
73 |
74 | if __name__ == "__main__":
75 | dset = MNISTDataDictWrapper(
76 | torchvision.datasets.MNIST(
77 | root=".data/",
78 | train=False,
79 | download=True,
80 | transform=transforms.Compose(
81 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
82 | ),
83 | )
84 | )
85 | ex = dset[0]
86 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 |
9 | def __init__(
10 | self,
11 | warm_up_steps,
12 | lr_min,
13 | lr_max,
14 | lr_start,
15 | max_decay_steps,
16 | verbosity_interval=0,
17 | ):
18 | self.lr_warm_up_steps = warm_up_steps
19 | self.lr_start = lr_start
20 | self.lr_min = lr_min
21 | self.lr_max = lr_max
22 | self.lr_max_decay_steps = max_decay_steps
23 | self.last_lr = 0.0
24 | self.verbosity_interval = verbosity_interval
25 |
26 | def schedule(self, n, **kwargs):
27 | if self.verbosity_interval > 0:
28 | if n % self.verbosity_interval == 0:
29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30 | if n < self.lr_warm_up_steps:
31 | lr = (
32 | self.lr_max - self.lr_start
33 | ) / self.lr_warm_up_steps * n + self.lr_start
34 | self.last_lr = lr
35 | return lr
36 | else:
37 | t = (n - self.lr_warm_up_steps) / (
38 | self.lr_max_decay_steps - self.lr_warm_up_steps
39 | )
40 | t = min(t, 1.0)
41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42 | 1 + np.cos(t * np.pi)
43 | )
44 | self.last_lr = lr
45 | return lr
46 |
47 | def __call__(self, n, **kwargs):
48 | return self.schedule(n, **kwargs)
49 |
50 |
51 | class LambdaWarmUpCosineScheduler2:
52 | """
53 | supports repeated iterations, configurable via lists
54 | note: use with a base_lr of 1.0.
55 | """
56 |
57 | def __init__(
58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59 | ):
60 | assert (
61 | len(warm_up_steps)
62 | == len(f_min)
63 | == len(f_max)
64 | == len(f_start)
65 | == len(cycle_lengths)
66 | )
67 | self.lr_warm_up_steps = warm_up_steps
68 | self.f_start = f_start
69 | self.f_min = f_min
70 | self.f_max = f_max
71 | self.cycle_lengths = cycle_lengths
72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73 | self.last_f = 0.0
74 | self.verbosity_interval = verbosity_interval
75 |
76 | def find_in_interval(self, n):
77 | interval = 0
78 | for cl in self.cum_cycles[1:]:
79 | if n <= cl:
80 | return interval
81 | interval += 1
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0:
88 | print(
89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90 | f"current cycle {cycle}"
91 | )
92 | if n < self.lr_warm_up_steps[cycle]:
93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94 | cycle
95 | ] * n + self.f_start[cycle]
96 | self.last_f = f
97 | return f
98 | else:
99 | t = (n - self.lr_warm_up_steps[cycle]) / (
100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101 | )
102 | t = min(t, 1.0)
103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104 | 1 + np.cos(t * np.pi)
105 | )
106 | self.last_f = f
107 | return f
108 |
109 | def __call__(self, n, **kwargs):
110 | return self.schedule(n, **kwargs)
111 |
112 |
113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114 | def schedule(self, n, **kwargs):
115 | cycle = self.find_in_interval(n)
116 | n = n - self.cum_cycles[cycle]
117 | if self.verbosity_interval > 0:
118 | if n % self.verbosity_interval == 0:
119 | print(
120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121 | f"current cycle {cycle}"
122 | )
123 |
124 | if n < self.lr_warm_up_steps[cycle]:
125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126 | cycle
127 | ] * n + self.f_start[cycle]
128 | self.last_f = f
129 | return f
130 | else:
131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132 | self.cycle_lengths[cycle] - n
133 | ) / (self.cycle_lengths[cycle])
134 | self.last_f = f
135 | return f
136 |
--------------------------------------------------------------------------------
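
Both schedulers above return learning-rate multipliers rather than learning rates (hence the "use with a base_lr of 1.0" note) and are meant to be wrapped in `torch.optim.lr_scheduler.LambdaLR`. A minimal sketch with `LambdaLinearScheduler`, using hypothetical warm-up and cycle lengths (the dotted import path assumes the repository root is on `PYTHONPATH`):

import torch

from third_party.image_generator.sgm.lr_scheduler import LambdaLinearScheduler

# One cycle: 1000 warm-up steps ramping the multiplier from 1e-6 to 1.0,
# then a linear decay towards f_min over the remaining 9000 steps.
sched_fn = LambdaLinearScheduler(
    warm_up_steps=[1000], f_min=[0.0], f_max=[1.0], f_start=[1e-6], cycle_lengths=[10000]
)

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)  # effective lr = 1e-4 * multiplier
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn)

for step in range(5):
    opt.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
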
/third_party/image_generator/sgm/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .autoencoder import AutoencodingEngine
2 | from .diffusion import DiffusionEngine
3 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoders.modules import GeneralConditioner
2 |
3 | UNCONDITIONAL_CONFIG = {
4 | "target": "sgm.modules.GeneralConditioner",
5 | "params": {"emb_models": []},
6 | }
7 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/autoencoding/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/losses/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "GeneralLPIPSWithDiscriminator",
3 | "LatentLPIPS",
4 | ]
5 |
6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator
7 | from .lpips import LatentLPIPS
8 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/losses/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ....util import default, instantiate_from_config
5 | from ..lpips.loss.lpips import LPIPS
6 |
7 |
8 | class LatentLPIPS(nn.Module):
9 | def __init__(
10 | self,
11 | decoder_config,
12 | perceptual_weight=1.0,
13 | latent_weight=1.0,
14 | scale_input_to_tgt_size=False,
15 | scale_tgt_to_input_size=False,
16 | perceptual_weight_on_inputs=0.0,
17 | ):
18 | super().__init__()
19 | self.scale_input_to_tgt_size = scale_input_to_tgt_size
20 | self.scale_tgt_to_input_size = scale_tgt_to_input_size
21 | self.init_decoder(decoder_config)
22 | self.perceptual_loss = LPIPS().eval()
23 | self.perceptual_weight = perceptual_weight
24 | self.latent_weight = latent_weight
25 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
26 |
27 | def init_decoder(self, config):
28 | self.decoder = instantiate_from_config(config)
29 | if hasattr(self.decoder, "encoder"):
30 | del self.decoder.encoder
31 |
32 | def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
33 | log = dict()
34 | loss = (latent_inputs - latent_predictions) ** 2
35 | log[f"{split}/latent_l2_loss"] = loss.mean().detach()
36 | image_reconstructions = None
37 | if self.perceptual_weight > 0.0:
38 | image_reconstructions = self.decoder.decode(latent_predictions)
39 | image_targets = self.decoder.decode(latent_inputs)
40 | perceptual_loss = self.perceptual_loss(
41 | image_targets.contiguous(), image_reconstructions.contiguous()
42 | )
43 | loss = (
44 | self.latent_weight * loss.mean()
45 | + self.perceptual_weight * perceptual_loss.mean()
46 | )
47 | log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
48 |
49 | if self.perceptual_weight_on_inputs > 0.0:
50 | image_reconstructions = default(
51 | image_reconstructions, self.decoder.decode(latent_predictions)
52 | )
53 | if self.scale_input_to_tgt_size:
54 | image_inputs = torch.nn.functional.interpolate(
55 | image_inputs,
56 | image_reconstructions.shape[2:],
57 | mode="bicubic",
58 | antialias=True,
59 | )
60 | elif self.scale_tgt_to_input_size:
61 | image_reconstructions = torch.nn.functional.interpolate(
62 | image_reconstructions,
63 | image_inputs.shape[2:],
64 | mode="bicubic",
65 | antialias=True,
66 | )
67 |
68 | perceptual_loss2 = self.perceptual_loss(
69 | image_inputs.contiguous(), image_reconstructions.contiguous()
70 | )
71 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
72 | log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
73 | return loss, log
74 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/lpips/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/autoencoding/lpips/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/lpips/loss/.gitignore:
--------------------------------------------------------------------------------
1 | vgg.pth
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/lpips/loss/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/lpips/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/autoencoding/lpips/loss/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/lpips/model/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------- LICENSE FOR pix2pix --------------------------------
27 | BSD License
28 |
29 | For pix2pix software
30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 | ----------------------------- LICENSE FOR DCGAN --------------------------------
44 | BSD License
45 |
46 | For dcgan.torch software
47 |
48 | Copyright (c) 2015, Facebook, Inc. All rights reserved.
49 |
50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51 |
52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53 |
54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55 |
56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57 |
58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/lpips/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/autoencoding/lpips/model/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/lpips/model/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch.nn as nn
4 |
5 | from ..util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find("BatchNorm") != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 |
22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23 | """Construct a PatchGAN discriminator
24 | Parameters:
25 | input_nc (int) -- the number of channels in input images
26 | ndf (int) -- the number of filters in the last conv layer
27 | n_layers (int) -- the number of conv layers in the discriminator
28 |             use_actnorm (bool) -- use ActNorm instead of BatchNorm2d as the normalization layer
29 | """
30 | super(NLayerDiscriminator, self).__init__()
31 | if not use_actnorm:
32 | norm_layer = nn.BatchNorm2d
33 | else:
34 | norm_layer = ActNorm
35 | if (
36 | type(norm_layer) == functools.partial
37 | ): # no need to use bias as BatchNorm2d has affine parameters
38 | use_bias = norm_layer.func != nn.BatchNorm2d
39 | else:
40 | use_bias = norm_layer != nn.BatchNorm2d
41 |
42 | kw = 4
43 | padw = 1
44 | sequence = [
45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46 | nn.LeakyReLU(0.2, True),
47 | ]
48 | nf_mult = 1
49 | nf_mult_prev = 1
50 | for n in range(1, n_layers): # gradually increase the number of filters
51 | nf_mult_prev = nf_mult
52 | nf_mult = min(2**n, 8)
53 | sequence += [
54 | nn.Conv2d(
55 | ndf * nf_mult_prev,
56 | ndf * nf_mult,
57 | kernel_size=kw,
58 | stride=2,
59 | padding=padw,
60 | bias=use_bias,
61 | ),
62 | norm_layer(ndf * nf_mult),
63 | nn.LeakyReLU(0.2, True),
64 | ]
65 |
66 | nf_mult_prev = nf_mult
67 | nf_mult = min(2**n_layers, 8)
68 | sequence += [
69 | nn.Conv2d(
70 | ndf * nf_mult_prev,
71 | ndf * nf_mult,
72 | kernel_size=kw,
73 | stride=1,
74 | padding=padw,
75 | bias=use_bias,
76 | ),
77 | norm_layer(ndf * nf_mult),
78 | nn.LeakyReLU(0.2, True),
79 | ]
80 |
81 | sequence += [
82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83 | ] # output 1 channel prediction map
84 | self.main = nn.Sequential(*sequence)
85 |
86 | def forward(self, input):
87 | """Standard forward."""
88 | return self.main(input)
89 |
--------------------------------------------------------------------------------
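
A minimal sketch showing the patch-level output of the discriminator above, with `weights_init` applied as in typical usage (the dotted import path assumes the repository root is on `PYTHONPATH`):

import torch

from third_party.image_generator.sgm.modules.autoencoding.lpips.model.model import (
    NLayerDiscriminator,
    weights_init,
)

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
x = torch.randn(2, 3, 256, 256)
logits = disc(x)
print(logits.shape)  # torch.Size([2, 1, 30, 30]) -- one real/fake logit per image patch
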
/third_party/image_generator/sgm/modules/autoencoding/lpips/util.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 |
4 | import requests
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 |
9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10 |
11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12 |
13 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14 |
15 |
16 | def download(url, local_path, chunk_size=1024):
17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18 | with requests.get(url, stream=True) as r:
19 | total_size = int(r.headers.get("content-length", 0))
20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21 | with open(local_path, "wb") as f:
22 | for data in r.iter_content(chunk_size=chunk_size):
23 | if data:
24 | f.write(data)
25 | pbar.update(chunk_size)
26 |
27 |
28 | def md5_hash(path):
29 | with open(path, "rb") as f:
30 | content = f.read()
31 | return hashlib.md5(content).hexdigest()
32 |
33 |
34 | def get_ckpt_path(name, root, check=False):
35 | assert name in URL_MAP
36 | path = os.path.join(root, CKPT_MAP[name])
37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39 | download(URL_MAP[name], path)
40 | md5 = md5_hash(path)
41 | assert md5 == MD5_MAP[name], md5
42 | return path
43 |
44 |
45 | class ActNorm(nn.Module):
46 | def __init__(
47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False
48 | ):
49 | assert affine
50 | super().__init__()
51 | self.logdet = logdet
52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
54 | self.allow_reverse_init = allow_reverse_init
55 |
56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
57 |
58 | def initialize(self, input):
59 | with torch.no_grad():
60 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
61 | mean = (
62 | flatten.mean(1)
63 | .unsqueeze(1)
64 | .unsqueeze(2)
65 | .unsqueeze(3)
66 | .permute(1, 0, 2, 3)
67 | )
68 | std = (
69 | flatten.std(1)
70 | .unsqueeze(1)
71 | .unsqueeze(2)
72 | .unsqueeze(3)
73 | .permute(1, 0, 2, 3)
74 | )
75 |
76 | self.loc.data.copy_(-mean)
77 | self.scale.data.copy_(1 / (std + 1e-6))
78 |
79 | def forward(self, input, reverse=False):
80 | if reverse:
81 | return self.reverse(input)
82 | if len(input.shape) == 2:
83 | input = input[:, :, None, None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | _, _, height, width = input.shape
89 |
90 | if self.training and self.initialized.item() == 0:
91 | self.initialize(input)
92 | self.initialized.fill_(1)
93 |
94 | h = self.scale * (input + self.loc)
95 |
96 | if squeeze:
97 | h = h.squeeze(-1).squeeze(-1)
98 |
99 | if self.logdet:
100 | log_abs = torch.log(torch.abs(self.scale))
101 | logdet = height * width * torch.sum(log_abs)
102 | logdet = logdet * torch.ones(input.shape[0]).to(input)
103 | return h, logdet
104 |
105 | return h
106 |
107 | def reverse(self, output):
108 | if self.training and self.initialized.item() == 0:
109 | if not self.allow_reverse_init:
110 | raise RuntimeError(
111 | "Initializing ActNorm in reverse direction is "
112 | "disabled by default. Use allow_reverse_init=True to enable."
113 | )
114 | else:
115 | self.initialize(output)
116 | self.initialized.fill_(1)
117 |
118 | if len(output.shape) == 2:
119 | output = output[:, :, None, None]
120 | squeeze = True
121 | else:
122 | squeeze = False
123 |
124 | h = output / self.scale - self.loc
125 |
126 | if squeeze:
127 | h = h.squeeze(-1).squeeze(-1)
128 | return h
129 |
--------------------------------------------------------------------------------
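
`ActNorm` initialises its per-channel shift and scale from the first batch it sees in training mode, so that batch comes out roughly zero-mean and unit-variance. A minimal sketch (assuming the repository root is on `PYTHONPATH` and `requests`/`tqdm` are installed, since the module imports them):

import torch

from third_party.image_generator.sgm.modules.autoencoding.lpips.util import ActNorm

act = ActNorm(num_features=8)
x = torch.randn(16, 8, 4, 4) * 3.0 + 1.0  # arbitrary per-batch scale and shift
y = act(x)                                # first call in training mode triggers the data-dependent init
print(round(y.mean().item(), 3), round(y.std().item(), 3))  # close to 0.0 and 1.0
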
/third_party/image_generator/sgm/modules/autoencoding/lpips/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def hinge_d_loss(logits_real, logits_fake):
6 | loss_real = torch.mean(F.relu(1.0 - logits_real))
7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8 | d_loss = 0.5 * (loss_real + loss_fake)
9 | return d_loss
10 |
11 |
12 | def vanilla_d_loss(logits_real, logits_fake):
13 | d_loss = 0.5 * (
14 | torch.mean(torch.nn.functional.softplus(-logits_real))
15 | + torch.mean(torch.nn.functional.softplus(logits_fake))
16 | )
17 | return d_loss
18 |
--------------------------------------------------------------------------------
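
Both helpers above implement the discriminator side of a GAN loss: the hinge variant clamps well-classified samples to zero loss, while the vanilla variant is the standard softplus (binary cross-entropy style) formulation. A tiny numeric sketch (import path assumes the repository root is on `PYTHONPATH`):

import torch

from third_party.image_generator.sgm.modules.autoencoding.lpips.vqperceptual import (
    hinge_d_loss,
    vanilla_d_loss,
)

logits_real = torch.tensor([2.0, 0.5])    # discriminator scores for real samples
logits_fake = torch.tensor([-1.5, 0.3])   # discriminator scores for generated samples
print(hinge_d_loss(logits_real, logits_fake))    # 0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake))) = 0.45
print(vanilla_d_loss(logits_real, logits_fake))  # 0.5 * (mean(softplus(-real)) + mean(softplus(fake))) ~ 0.41
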
/third_party/image_generator/sgm/modules/autoencoding/regularizers/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from ....modules.distributions.distributions import \
9 | DiagonalGaussianDistribution
10 | from .base import AbstractRegularizer
11 |
12 |
13 | class DiagonalGaussianRegularizer(AbstractRegularizer):
14 | def __init__(self, sample: bool = True):
15 | super().__init__()
16 | self.sample = sample
17 |
18 | def get_trainable_parameters(self) -> Any:
19 | yield from ()
20 |
21 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
22 | log = dict()
23 | posterior = DiagonalGaussianDistribution(z)
24 | if self.sample:
25 | z = posterior.sample()
26 | else:
27 | z = posterior.mode()
28 | kl_loss = posterior.kl()
29 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
30 | log["kl_loss"] = kl_loss
31 | return z, log
32 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/autoencoding/regularizers/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any, Tuple
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 |
9 | class AbstractRegularizer(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
14 | raise NotImplementedError()
15 |
16 | @abstractmethod
17 | def get_trainable_parameters(self) -> Any:
18 | raise NotImplementedError()
19 |
20 |
21 | class IdentityRegularizer(AbstractRegularizer):
22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
23 | return z, dict()
24 |
25 | def get_trainable_parameters(self) -> Any:
26 | yield from ()
27 |
28 |
29 | def measure_perplexity(
30 | predicted_indices: torch.Tensor, num_centroids: int
31 | ) -> Tuple[torch.Tensor, torch.Tensor]:
32 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
33 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
34 | encodings = (
35 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
36 | )
37 | avg_probs = encodings.mean(0)
38 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
39 | cluster_use = torch.sum(avg_probs > 0)
40 | return perplexity, cluster_use
41 |
--------------------------------------------------------------------------------
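
`measure_perplexity` summarises how evenly a set of vector-quantisation code assignments uses the codebook: perplexity approaches `num_centroids` when usage is uniform, and `cluster_use` counts the distinct entries hit at least once. A minimal sketch with random assignments (import path assumes the repository root is on `PYTHONPATH`):

import torch

from third_party.image_generator.sgm.modules.autoencoding.regularizers.base import measure_perplexity

codes = torch.randint(0, 512, (4096,))  # fake code indices for 4096 latents
perplexity, cluster_use = measure_perplexity(codes, num_centroids=512)
print(perplexity)    # close to 512 for (near-)uniform usage
print(cluster_use)   # number of distinct codebook entries actually used
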
/third_party/image_generator/sgm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/diffusionmodules/denoiser.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ...util import append_dims, instantiate_from_config
7 | from .denoiser_scaling import DenoiserScaling
8 | from .discretizer import Discretization
9 |
10 |
11 | class Denoiser(nn.Module):
12 | def __init__(self, scaling_config: Dict):
13 | super().__init__()
14 |
15 | self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
16 |
17 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
18 | return sigma
19 |
20 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
21 | return c_noise
22 |
23 | def forward(
24 | self,
25 | network: nn.Module,
26 | input: torch.Tensor,
27 | sigma: torch.Tensor,
28 | cond: Dict,
29 | **additional_model_inputs,
30 | ) -> torch.Tensor:
31 | sigma = self.possibly_quantize_sigma(sigma)
32 | sigma_shape = sigma.shape
33 | sigma = append_dims(sigma, input.ndim)
34 | c_skip, c_out, c_in, c_noise = self.scaling(sigma)
35 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
36 | return (
37 | network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
38 | + input * c_skip
39 | )
40 |
41 |
42 | class DiscreteDenoiser(Denoiser):
43 | def __init__(
44 | self,
45 | scaling_config: Dict,
46 | num_idx: int,
47 | discretization_config: Dict,
48 | do_append_zero: bool = False,
49 | quantize_c_noise: bool = True,
50 | flip: bool = True,
51 | ):
52 | super().__init__(scaling_config)
53 | self.discretization: Discretization = instantiate_from_config(
54 | discretization_config
55 | )
56 | sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
57 | self.register_buffer("sigmas", sigmas)
58 | self.quantize_c_noise = quantize_c_noise
59 | self.num_idx = num_idx
60 |
61 | def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
62 | dists = sigma - self.sigmas[:, None]
63 | return dists.abs().argmin(dim=0).view(sigma.shape)
64 |
65 | def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
66 | return self.sigmas[idx]
67 |
68 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
69 | return self.idx_to_sigma(self.sigma_to_idx(sigma))
70 |
71 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
72 | if self.quantize_c_noise:
73 | return self.sigma_to_idx(c_noise)
74 | else:
75 | return c_noise
76 |
--------------------------------------------------------------------------------
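
The `Denoiser` above wraps the raw network with the preconditioning coefficients returned by its scaling module: it feeds the network `c_in * input` and `c_noise`, then returns `c_out * network_output + c_skip * input`. A minimal sketch with a stand-in network that always predicts zeros, assuming `third_party/image_generator` is on `PYTHONPATH` so the `sgm` package resolves (as the YAML configs above also expect):

import torch

from sgm.modules.diffusionmodules.denoiser import Denoiser

denoiser = Denoiser(
    scaling_config={
        "target": "sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise",
        "params": {},
    }
)


class ZeroNet(torch.nn.Module):
    """Stand-in network: predicting zeros leaves only the c_skip * input term."""

    def forward(self, x, c_noise, cond):
        return torch.zeros_like(x)


x = torch.randn(2, 4, 8, 8)
sigma = torch.ones(2)
out = denoiser(ZeroNet(), x, sigma, cond={})
# With v-scaling, c_skip = 1 / (sigma**2 + 1) = 0.5 at sigma = 1, so out == x / 2.
print(torch.allclose(out, x / 2))  # True
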
/third_party/image_generator/sgm/modules/diffusionmodules/denoiser_scaling.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Tuple
3 |
4 | import torch
5 |
6 |
7 | class DenoiserScaling(ABC):
8 | @abstractmethod
9 | def __call__(
10 | self, sigma: torch.Tensor
11 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
12 | pass
13 |
14 |
15 | class EDMScaling:
16 | def __init__(self, sigma_data: float = 0.5):
17 | self.sigma_data = sigma_data
18 |
19 | def __call__(
20 | self, sigma: torch.Tensor
21 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
22 |         c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)  # EDM preconditioning (Karras et al., 2022)
23 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
24 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
25 | c_noise = 0.25 * sigma.log()
26 | return c_skip, c_out, c_in, c_noise
27 |
28 |
29 | class EpsScaling:
30 | def __call__(
31 | self, sigma: torch.Tensor
32 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
33 | c_skip = torch.ones_like(sigma, device=sigma.device)
34 | c_out = -sigma
35 | c_in = 1 / (sigma**2 + 1.0) ** 0.5
36 | c_noise = sigma.clone()
37 | return c_skip, c_out, c_in, c_noise
38 |
39 |
40 | class VScaling:
41 | def __call__(
42 | self, sigma: torch.Tensor
43 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
44 | c_skip = 1.0 / (sigma**2 + 1.0)
45 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5
46 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
47 | c_noise = sigma.clone()
48 | return c_skip, c_out, c_in, c_noise
49 |
50 |
51 | class VScalingWithEDMcNoise(DenoiserScaling):
52 | def __call__(
53 | self, sigma: torch.Tensor
54 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
55 | c_skip = 1.0 / (sigma**2 + 1.0)
56 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5
57 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
58 | c_noise = 0.25 * sigma.log()
59 | return c_skip, c_out, c_in, c_noise
60 |
--------------------------------------------------------------------------------
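
These classes provide the (c_skip, c_out, c_in, c_noise) coefficients consumed by the `Denoiser` above: `EDMScaling` follows the Karras et al. (2022) preconditioning, while the V-scaling variants correspond to v-prediction, optionally combined with the EDM noise conditioning. A small numeric check (import path as in the previous sketch):

import torch

from sgm.modules.diffusionmodules.denoiser_scaling import EDMScaling

c_skip, c_out, c_in, c_noise = EDMScaling(sigma_data=0.5)(torch.tensor([0.5]))
print(c_skip.item())   # 0.5    -- at sigma == sigma_data the skip path carries half the prediction
print(c_out.item())    # ~0.354 -- sigma * sigma_data / sqrt(sigma**2 + sigma_data**2)
print(c_in.item())     # ~1.414 -- 1 / sqrt(sigma**2 + sigma_data**2)
print(c_noise.item())  # 0.25 * log(sigma)
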
/third_party/image_generator/sgm/modules/diffusionmodules/denoiser_weighting.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class UnitWeighting:
5 | def __call__(self, sigma):
6 | return torch.ones_like(sigma, device=sigma.device)
7 |
8 |
9 | class EDMWeighting:
10 | def __init__(self, sigma_data=0.5):
11 | self.sigma_data = sigma_data
12 |
13 | def __call__(self, sigma):
14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
15 |
16 |
17 | class VWeighting(EDMWeighting):
18 | def __init__(self):
19 | super().__init__(sigma_data=1.0)
20 |
21 |
22 | class EpsWeighting:
23 | def __call__(self, sigma):
24 | return sigma**-2.0
25 |
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/diffusionmodules/discretizer.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from functools import partial
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from ...modules.diffusionmodules.util import make_beta_schedule
8 | from ...util import append_zero
9 |
10 |
11 | def generate_roughly_equally_spaced_steps(
12 | num_substeps: int, max_step: int
13 | ) -> np.ndarray:
14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
15 |
16 |
17 | class Discretization:
18 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
19 | sigmas = self.get_sigmas(n, device=device)
20 | sigmas = append_zero(sigmas) if do_append_zero else sigmas
21 | return sigmas if not flip else torch.flip(sigmas, (0,))
22 |
23 | @abstractmethod
24 | def get_sigmas(self, n, device):
25 | pass
26 |
27 |
28 | class EDMDiscretization(Discretization):
29 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
30 | self.sigma_min = sigma_min
31 | self.sigma_max = sigma_max
32 | self.rho = rho
33 |
34 | def get_sigmas(self, n, device="cpu"):
35 | ramp = torch.linspace(0, 1, n, device=device)
36 | min_inv_rho = self.sigma_min ** (1 / self.rho)
37 | max_inv_rho = self.sigma_max ** (1 / self.rho)
38 |         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho  # rho-spaced noise schedule from Karras et al. (EDM)
39 | return sigmas
40 |
41 |
42 | class LegacyDDPMDiscretization(Discretization):
43 | def __init__(
44 | self,
45 | linear_start=0.00085,
46 | linear_end=0.0120,
47 | num_timesteps=1000,
48 | ):
49 | super().__init__()
50 | self.num_timesteps = num_timesteps
51 | betas = make_beta_schedule(
52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
53 | )
54 | alphas = 1.0 - betas
55 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
56 | self.to_torch = partial(torch.tensor, dtype=torch.float32)
57 |
58 | def get_sigmas(self, n, device="cpu"):
59 | if n < self.num_timesteps:
60 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
61 | alphas_cumprod = self.alphas_cumprod[timesteps]
62 | elif n == self.num_timesteps:
63 | alphas_cumprod = self.alphas_cumprod
64 | else:
65 | raise ValueError
66 |
67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
69 | return torch.flip(sigmas, (0,))
70 |
--------------------------------------------------------------------------------
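
`EDMDiscretization` produces the rho-spaced noise schedule from Karras et al.: sigmas run from `sigma_max` down to `sigma_min`, with an optional trailing zero appended for the final denoising step. A minimal sketch matching the `sigma_max: 700.0` used by the SVD configs above (import path as in the previous sketches):

from sgm.modules.diffusionmodules.discretizer import EDMDiscretization

disc = EDMDiscretization(sigma_min=0.002, sigma_max=700.0, rho=7.0)
sigmas = disc(5)   # do_append_zero=True by default, so 6 values come back
print(sigmas)      # tensor([700.0000, ..., 0.0020, 0.0000]), monotonically decreasing
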
/third_party/image_generator/sgm/modules/diffusionmodules/guiders.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from abc import ABC, abstractmethod
3 | from typing import Dict, List, Literal, Optional, Tuple, Union
4 |
5 | import torch
6 | from einops import rearrange, repeat
7 |
8 | from ...util import append_dims, default
9 |
10 | logpy = logging.getLogger(__name__)
11 |
12 |
13 | class Guider(ABC):
14 | @abstractmethod
15 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
16 | pass
17 |
18 | def prepare_inputs(
19 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict
20 | ) -> Tuple[torch.Tensor, float, Dict]:
21 | pass
22 |
23 |
24 | class VanillaCFG(Guider):
25 | def __init__(self, scale: float):
26 | self.scale = scale
27 |
28 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
29 | x_u, x_c = x.chunk(2)
30 | x_pred = x_u + self.scale * (x_c - x_u)
31 | return x_pred
32 |
33 | def prepare_inputs(self, x, s, c, uc):
34 | c_out = dict()
35 |
36 | for k in c:
37 | if k in ["vector", "crossattn", "concat"]:
38 | c_out[k] = torch.cat((uc[k], c[k]), 0)
39 | else:
40 | assert c[k] == uc[k]
41 | c_out[k] = c[k]
42 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out
43 |
44 |
45 | class IdentityGuider(Guider):
46 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
47 | return x
48 |
49 | def prepare_inputs(
50 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict
51 | ) -> Tuple[torch.Tensor, float, Dict]:
52 | c_out = dict()
53 |
54 | for k in c:
55 | c_out[k] = c[k]
56 |
57 | return x, s, c_out
58 |
59 |
60 | class LinearPredictionGuider(Guider):
61 | def __init__(
62 | self,
63 | max_scale: float,
64 | num_frames: int,
65 | min_scale: float = 1.0,
66 | additional_cond_keys: Optional[Union[List[str], str]] = None,
67 | ):
68 | self.min_scale = min_scale
69 | self.max_scale = max_scale
70 | self.num_frames = num_frames
71 |         self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)  # per-frame guidance scale, ramped linearly across the video
72 |
73 | additional_cond_keys = default(additional_cond_keys, [])
74 | if isinstance(additional_cond_keys, str):
75 | additional_cond_keys = [additional_cond_keys]
76 | self.additional_cond_keys = additional_cond_keys
77 |
78 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
79 | x_u, x_c = x.chunk(2)
80 |
81 | x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
82 | x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
83 | scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
84 | scale = append_dims(scale, x_u.ndim).to(x_u.device)
85 |
86 | return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
87 |
88 | def prepare_inputs(
89 | self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
90 | ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
91 | c_out = dict()
92 |
93 | for k in c:
94 | if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
95 | c_out[k] = torch.cat((uc[k], c[k]), 0)
96 | else:
97 | assert c[k] == uc[k]
98 | c_out[k] = c[k]
99 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out
100 |
101 |
102 | class TrianglePredictionGuider(LinearPredictionGuider):
103 | def __init__(
104 | self,
105 | max_scale: float,
106 | num_frames: int,
107 | min_scale: float = 1.0,
108 | period: Optional[float] = 1.0,
109 | period_fusing: Literal["mean", "multiply", "max"] = "max",
110 | additional_cond_keys: Optional[Union[List[str], str]] = None,
111 | ):
112 | super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
113 | values = torch.linspace(0, 1, num_frames)
114 | # Constructs a triangle wave
115 | if isinstance(period, float):
116 | period = [period]
117 |
118 | scales = []
119 | for p in period:
120 | scales.append(self.triangle_wave(values, p))
121 |
122 | if period_fusing == "mean":
123 | scale = sum(scales) / len(period)
124 | elif period_fusing == "multiply":
125 | scale = torch.prod(torch.stack(scales), dim=0)
126 | elif period_fusing == "max":
127 | scale = torch.max(torch.stack(scales), dim=0).values
128 | self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
129 |
130 | def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
131 | return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
132 |
--------------------------------------------------------------------------------
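
A minimal sketch of `LinearPredictionGuider`: `prepare_inputs` duplicates the batch into unconditional and conditional halves, and `__call__` blends the two model outputs with a guidance scale that ramps linearly over the frames (import path as in the previous sketches):

import torch

from sgm.modules.diffusionmodules.guiders import LinearPredictionGuider

num_frames = 5
guider = LinearPredictionGuider(max_scale=2.5, num_frames=num_frames, min_scale=1.0)

# Pretend denoiser output for the doubled batch: first half unconditional, second half conditional.
x_uc = torch.zeros(num_frames, 4, 8, 8)
x_c = torch.ones(num_frames, 4, 8, 8)
out = guider(torch.cat([x_uc, x_c]), sigma=torch.tensor(1.0))
print(out.shape)        # torch.Size([5, 4, 8, 8])
print(out[:, 0, 0, 0])  # per-frame guidance scales: 1.000, 1.375, 1.750, 2.125, 2.500
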
/third_party/image_generator/sgm/modules/diffusionmodules/loss.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS
7 | from ...modules.encoders.modules import GeneralConditioner
8 | from ...util import append_dims, instantiate_from_config
9 | from .denoiser import Denoiser
10 |
11 |
12 | class StandardDiffusionLoss(nn.Module):
13 | def __init__(
14 | self,
15 | sigma_sampler_config: dict,
16 | loss_weighting_config: dict,
17 | loss_type: str = "l2",
18 | offset_noise_level: float = 0.0,
19 | batch2model_keys: Optional[Union[str, List[str]]] = None,
20 | ):
21 | super().__init__()
22 |
23 | assert loss_type in ["l2", "l1", "lpips"]
24 |
25 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
26 | self.loss_weighting = instantiate_from_config(loss_weighting_config)
27 |
28 | self.loss_type = loss_type
29 | self.offset_noise_level = offset_noise_level
30 |
31 | if loss_type == "lpips":
32 | self.lpips = LPIPS().eval()
33 |
34 | if not batch2model_keys:
35 | batch2model_keys = []
36 |
37 | if isinstance(batch2model_keys, str):
38 | batch2model_keys = [batch2model_keys]
39 |
40 | self.batch2model_keys = set(batch2model_keys)
41 |
42 | def get_noised_input(
43 | self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
44 | ) -> torch.Tensor:
45 | noised_input = input + noise * sigmas_bc
46 | return noised_input
47 |
48 | def forward(
49 | self,
50 | network: nn.Module,
51 | denoiser: Denoiser,
52 | conditioner: GeneralConditioner,
53 | input: torch.Tensor,
54 | batch: Dict,
55 | ) -> torch.Tensor:
56 | cond = conditioner(batch)
57 | return self._forward(network, denoiser, cond, input, batch)
58 |
59 | def _forward(
60 | self,
61 | network: nn.Module,
62 | denoiser: Denoiser,
63 | cond: Dict,
64 | input: torch.Tensor,
65 | batch: Dict,
66 | ) -> Tuple[torch.Tensor, Dict]:
67 | additional_model_inputs = {
68 | key: batch[key] for key in self.batch2model_keys.intersection(batch)
69 | }
70 | sigmas = self.sigma_sampler(input.shape[0]).to(input)
71 |
72 | noise = torch.randn_like(input)
73 | if self.offset_noise_level > 0.0:
74 | offset_shape = (
75 | (input.shape[0], 1, input.shape[2])
76 |                 if getattr(self, "n_frames", None) is not None  # n_frames is not set by this base class
77 | else (input.shape[0], input.shape[1])
78 | )
79 | noise = noise + self.offset_noise_level * append_dims(
80 | torch.randn(offset_shape, device=input.device),
81 | input.ndim,
82 | )
83 | sigmas_bc = append_dims(sigmas, input.ndim)
84 | noised_input = self.get_noised_input(sigmas_bc, noise, input)
85 |
86 | model_output = denoiser(
87 | network, noised_input, sigmas, cond, **additional_model_inputs
88 | )
89 | w = append_dims(self.loss_weighting(sigmas), input.ndim)
90 | return self.get_loss(model_output, input, w)
91 |
92 | def get_loss(self, model_output, target, w):
93 | if self.loss_type == "l2":
94 | return torch.mean(
95 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
96 | )
97 | elif self.loss_type == "l1":
98 | return torch.mean(
99 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
100 | )
101 | elif self.loss_type == "lpips":
102 | loss = self.lpips(model_output, target).reshape(-1)
103 | return loss
104 | else:
105 | raise NotImplementedError(f"Unknown loss type {self.loss_type}")
106 |
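A minimal instantiation sketch, assuming the vendored sgm package is importable and the usual {"target": ..., "params": ...} config format understood by instantiate_from_config; the sampler and weighting classes referenced here are the ones defined in sigma_sampling.py and loss_weighting.py below.

    from sgm.modules.diffusionmodules.loss import StandardDiffusionLoss

    loss_fn = StandardDiffusionLoss(
        sigma_sampler_config={
            "target": "sgm.modules.diffusionmodules.sigma_sampling.EDMSampling",
            "params": {"p_mean": -1.2, "p_std": 1.2},
        },
        loss_weighting_config={
            "target": "sgm.modules.diffusionmodules.loss_weighting.EDMWeighting",
            "params": {"sigma_data": 0.5},
        },
        loss_type="l2",
    )
    # loss_fn(network, denoiser, conditioner, latents, batch) yields a per-sample
    # loss of shape (B,); callers typically take .mean() before backprop.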
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/diffusionmodules/loss_weighting.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 |
5 |
6 | class DiffusionLossWeighting(ABC):
7 | @abstractmethod
8 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
9 | pass
10 |
11 |
12 | class UnitWeighting(DiffusionLossWeighting):
13 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
14 | return torch.ones_like(sigma, device=sigma.device)
15 |
16 |
17 | class EDMWeighting(DiffusionLossWeighting):
18 | def __init__(self, sigma_data: float = 0.5):
19 | self.sigma_data = sigma_data
20 |
21 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
22 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
23 |
24 |
25 | class VWeighting(EDMWeighting):
26 | def __init__(self):
27 | super().__init__(sigma_data=1.0)
28 |
29 |
30 | class EpsWeighting(DiffusionLossWeighting):
31 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
32 | return sigma**-2.0
33 |
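A quick numeric check (standalone sketch): EDMWeighting implements w(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, the reciprocal of c_out(sigma)^2 in the EDM preconditioning, while EpsWeighting is plain sigma^-2.

    import torch
    from sgm.modules.diffusionmodules.loss_weighting import EDMWeighting, EpsWeighting

    sigma = torch.tensor([0.5, 1.0, 2.0])
    print(EDMWeighting(sigma_data=0.5)(sigma))  # tensor([8.0000, 5.0000, 4.2500])
    print(EpsWeighting()(sigma))                # tensor([4.0000, 1.0000, 0.2500])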
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/diffusionmodules/sampling_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scipy import integrate
3 |
4 | from ...util import append_dims
5 |
6 |
7 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
8 | if order - 1 > i:
9 | raise ValueError(f"Order {order} too high for step {i}")
10 |
11 | def fn(tau):
12 | prod = 1.0
13 | for k in range(order):
14 | if j == k:
15 | continue
16 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
17 | return prod
18 |
19 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
20 |
21 |
22 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
23 | if not eta:
24 | return sigma_to, 0.0
25 | sigma_up = torch.minimum(
26 | sigma_to,
27 | eta
28 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
29 | )
30 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
31 | return sigma_down, sigma_up
32 |
33 |
34 | def to_d(x, sigma, denoised):
35 | return (x - denoised) / append_dims(sigma, x.ndim)
36 |
37 |
38 | def to_neg_log_sigma(sigma):
39 | return sigma.log().neg()
40 |
41 |
42 | def to_sigma(neg_log_sigma):
43 | return neg_log_sigma.neg().exp()
44 |
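These helpers are the building blocks of the ancestral samplers elsewhere in sgm. A hypothetical single Euler-ancestral step (a sketch, not a sampler from this repo) would combine them as follows; sigma_from and sigma_to are assumed to be 0-dim tensors.

    import torch
    from sgm.modules.diffusionmodules.sampling_utils import get_ancestral_step, to_d

    def euler_ancestral_step(x, denoised, sigma_from, sigma_to, eta=1.0):
        # split the transition into a deterministic part (sigma_down) and fresh noise (sigma_up)
        sigma_down, sigma_up = get_ancestral_step(sigma_from, sigma_to, eta=eta)
        d = to_d(x, sigma_from, denoised)           # current ODE direction
        x = x + d * (sigma_down - sigma_from)       # Euler step down to sigma_down
        if sigma_up > 0:
            x = x + torch.randn_like(x) * sigma_up  # re-inject noise up to sigma_to
        return x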
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/diffusionmodules/sigma_sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ...util import default, instantiate_from_config
4 |
5 |
6 | class EDMSampling:
7 | def __init__(self, p_mean=-1.2, p_std=1.2):
8 | self.p_mean = p_mean
9 | self.p_std = p_std
10 |
11 | def __call__(self, n_samples, rand=None):
12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
13 | return log_sigma.exp()
14 |
15 |
16 | class DiscreteSampling:
17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
18 | self.num_idx = num_idx
19 | self.sigmas = instantiate_from_config(discretization_config)(
20 | num_idx, do_append_zero=do_append_zero, flip=flip
21 | )
22 |
23 | def idx_to_sigma(self, idx):
24 | return self.sigmas[idx]
25 |
26 | def __call__(self, n_samples, rand=None):
27 | idx = default(
28 | rand,
29 | torch.randint(0, self.num_idx, (n_samples,)),
30 | )
31 | return self.idx_to_sigma(idx)
32 |
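A usage sketch; it assumes the LegacyDDPMDiscretization class from sgm/modules/diffusionmodules/discretizer.py (present in the repo but not reproduced in this listing).

    from sgm.modules.diffusionmodules.sigma_sampling import DiscreteSampling, EDMSampling

    sigma_sampler = DiscreteSampling(
        discretization_config={
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        },
        num_idx=1000,
    )
    sigmas = sigma_sampler(8)       # shape (8,): sigmas picked from the 1000-step grid
    cont_sigmas = EDMSampling()(8)  # shape (8,): log-normal sigmas, EDM-style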
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/diffusionmodules/wrappers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from packaging import version
4 |
5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
6 |
7 |
8 | class IdentityWrapper(nn.Module):
9 | def __init__(self, diffusion_model, compile_model: bool = False):
10 | super().__init__()
11 | compile = (
12 | torch.compile
13 | if (version.parse(torch.__version__) >= version.parse("2.0.0"))
14 | and compile_model
15 | else lambda x: x
16 | )
17 | self.diffusion_model = compile(diffusion_model)
18 |
19 | def forward(self, *args, **kwargs):
20 | return self.diffusion_model(*args, **kwargs)
21 |
22 |
23 | class OpenAIWrapper(IdentityWrapper):
24 | def forward(
25 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
26 | ) -> torch.Tensor:
27 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
28 | return self.diffusion_model(
29 | x,
30 | timesteps=t,
31 | context=c.get("crossattn", None),
32 | y=c.get("vector", None),
33 | **kwargs,
34 | )
35 |
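A tiny standalone check of the wrapper contract: IdentityWrapper only (optionally) torch.compile()s the inner module and forwards every argument, while OpenAIWrapper additionally unpacks the conditioning dict ("concat" is stacked onto the input channels, "crossattn" becomes context, "vector" becomes y).

    import torch
    import torch.nn as nn
    from sgm.modules.diffusionmodules.wrappers import IdentityWrapper

    wrapped = IdentityWrapper(nn.Linear(4, 4), compile_model=False)
    out = wrapped(torch.randn(2, 4))  # identical to calling the inner module directly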
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(
34 | device=self.parameters.device
35 | )
36 |
37 | def sample(self):
38 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
39 | device=self.parameters.device
40 | )
41 | return x
42 |
43 | def kl(self, other=None):
44 | if self.deterministic:
45 | return torch.Tensor([0.0])
46 | else:
47 | if other is None:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50 | dim=[1, 2, 3],
51 | )
52 | else:
53 | return 0.5 * torch.sum(
54 | torch.pow(self.mean - other.mean, 2) / other.var
55 | + self.var / other.var
56 | - 1.0
57 | - self.logvar
58 | + other.logvar,
59 | dim=[1, 2, 3],
60 | )
61 |
62 | def nll(self, sample, dims=[1, 2, 3]):
63 | if self.deterministic:
64 | return torch.Tensor([0.0])
65 | logtwopi = np.log(2.0 * np.pi)
66 | return 0.5 * torch.sum(
67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68 | dim=dims,
69 | )
70 |
71 | def mode(self):
72 | return self.mean
73 |
74 |
75 | def normal_kl(mean1, logvar1, mean2, logvar2):
76 | """
77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78 | Compute the KL divergence between two gaussians.
79 | Shapes are automatically broadcasted, so batches can be compared to
80 | scalars, among other use cases.
81 | """
82 | tensor = None
83 | for obj in (mean1, logvar1, mean2, logvar2):
84 | if isinstance(obj, torch.Tensor):
85 | tensor = obj
86 | break
87 | assert tensor is not None, "at least one argument must be a Tensor"
88 |
89 | # Force variances to be Tensors. Broadcasting helps convert scalars to
90 | # Tensors, but it does not work for torch.exp().
91 | logvar1, logvar2 = [
92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93 | for x in (logvar1, logvar2)
94 | ]
95 |
96 | return 0.5 * (
97 | -1.0
98 | + logvar2
99 | - logvar1
100 | + torch.exp(logvar1 - logvar2)
101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102 | )
103 |
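A usage sketch for DiagonalGaussianDistribution, which expects the mean and log-variance stacked along dim=1 (as an autoencoder head with 2 * z_channels outputs typically provides); shapes below are illustrative.

    import torch
    from sgm.modules.distributions.distributions import DiagonalGaussianDistribution

    params = torch.randn(2, 8, 16, 16)              # 4 mean channels + 4 logvar channels
    posterior = DiagonalGaussianDistribution(params)
    z = posterior.sample()                          # (2, 4, 16, 16) reparameterized sample
    kl = posterior.kl()                             # (2,) KL divergence to N(0, I)
    mode = posterior.mode()                         # (2, 4, 16, 16), simply the mean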
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 |     def __init__(self, model, decay=0.9999, use_num_updates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError("Decay must be between 0 and 1")
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | "num_updates",
15 | torch.tensor(0, dtype=torch.int)
16 |             if use_num_updates
17 | else torch.tensor(-1, dtype=torch.int),
18 | )
19 |
20 | for name, p in model.named_parameters():
21 | if p.requires_grad:
22 | # remove as '.'-character is not allowed in buffers
23 | s_name = name.replace(".", "")
24 | self.m_name2s_name.update({name: s_name})
25 | self.register_buffer(s_name, p.clone().detach().data)
26 |
27 | self.collected_params = []
28 |
29 | def reset_num_updates(self):
30 | del self.num_updates
31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
32 |
33 | def forward(self, model):
34 | decay = self.decay
35 |
36 | if self.num_updates >= 0:
37 | self.num_updates += 1
38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
39 |
40 | one_minus_decay = 1.0 - decay
41 |
42 | with torch.no_grad():
43 | m_param = dict(model.named_parameters())
44 | shadow_params = dict(self.named_buffers())
45 |
46 | for key in m_param:
47 | if m_param[key].requires_grad:
48 | sname = self.m_name2s_name[key]
49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
50 | shadow_params[sname].sub_(
51 | one_minus_decay * (shadow_params[sname] - m_param[key])
52 | )
53 | else:
54 |                     assert key not in self.m_name2s_name
55 |
56 | def copy_to(self, model):
57 | m_param = dict(model.named_parameters())
58 | shadow_params = dict(self.named_buffers())
59 | for key in m_param:
60 | if m_param[key].requires_grad:
61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
62 | else:
63 |                 assert key not in self.m_name2s_name
64 |
65 | def store(self, parameters):
66 | """
67 | Save the current parameters for restoring later.
68 | Args:
69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
70 | temporarily stored.
71 | """
72 | self.collected_params = [param.clone() for param in parameters]
73 |
74 | def restore(self, parameters):
75 | """
76 | Restore the parameters stored with the `store` method.
77 | Useful to validate the model with EMA parameters without affecting the
78 | original optimization process. Store the parameters before the
79 | `copy_to` method. After validation (or model saving), use this to
80 | restore the former parameters.
81 | Args:
82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
83 | updated with the stored parameters.
84 | """
85 | for c_param, param in zip(self.collected_params, parameters):
86 | param.data.copy_(c_param.data)
87 |
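A sketch of the usual store/copy_to/restore workflow around LitEma (the model here is a stand-in; any nn.Module with trainable parameters works the same way).

    import torch.nn as nn
    from sgm.modules.ema import LitEma

    model = nn.Linear(4, 4)
    ema = LitEma(model, decay=0.999)

    # after every optimizer step:
    ema(model)                         # update the shadow (EMA) weights

    # evaluate or checkpoint with EMA weights:
    ema.store(model.parameters())      # stash the live weights
    ema.copy_to(model)                 # load the EMA weights into the model
    # ... run validation ...
    ema.restore(model.parameters())    # put the live weights back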
--------------------------------------------------------------------------------
/third_party/image_generator/sgm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/third_party/image_generator/tests/inference/test_inference.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | from PIL import Image
3 | import pytest
4 | from pytest import fixture
5 | import torch
6 | from typing import Tuple
7 |
8 | from sgm.inference.api import (
9 | model_specs,
10 | SamplingParams,
11 | SamplingPipeline,
12 | Sampler,
13 | ModelArchitecture,
14 | )
15 | import sgm.inference.helpers as helpers
16 |
17 |
18 | @pytest.mark.inference
19 | class TestInference:
20 | @fixture(scope="class", params=model_specs.keys())
21 | def pipeline(self, request) -> SamplingPipeline:
22 | pipeline = SamplingPipeline(request.param)
23 | yield pipeline
24 | del pipeline
25 | torch.cuda.empty_cache()
26 |
27 | @fixture(
28 | scope="class",
29 | params=[
30 | [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
31 | [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
32 | ],
33 | ids=["SDXL_V1", "SDXL_V0_9"],
34 | )
35 | def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]:
36 | base_pipeline = SamplingPipeline(request.param[0])
37 | refiner_pipeline = SamplingPipeline(request.param[1])
38 | yield base_pipeline, refiner_pipeline
39 | del base_pipeline
40 | del refiner_pipeline
41 | torch.cuda.empty_cache()
42 |
43 | def create_init_image(self, h, w):
44 | image_array = numpy.random.rand(h, w, 3) * 255
45 | image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
46 | return helpers.get_input_image_tensor(image)
47 |
48 | @pytest.mark.parametrize("sampler_enum", Sampler)
49 | def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
50 | output = pipeline.text_to_image(
51 | params=SamplingParams(sampler=sampler_enum.value, steps=10),
52 | prompt="A professional photograph of an astronaut riding a pig",
53 | negative_prompt="",
54 | samples=1,
55 | )
56 |
57 | assert output is not None
58 |
59 | @pytest.mark.parametrize("sampler_enum", Sampler)
60 | def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
61 | output = pipeline.image_to_image(
62 | params=SamplingParams(sampler=sampler_enum.value, steps=10),
63 | image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
64 | prompt="A professional photograph of an astronaut riding a pig",
65 | negative_prompt="",
66 | samples=1,
67 | )
68 | assert output is not None
69 |
70 | @pytest.mark.parametrize("sampler_enum", Sampler)
71 | @pytest.mark.parametrize(
72 | "use_init_image", [True, False], ids=["img2img", "txt2img"]
73 | )
74 | def test_sdxl_with_refiner(
75 | self,
76 | sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
77 | sampler_enum,
78 | use_init_image,
79 | ):
80 | base_pipeline, refiner_pipeline = sdxl_pipelines
81 | if use_init_image:
82 | output = base_pipeline.image_to_image(
83 | params=SamplingParams(sampler=sampler_enum.value, steps=10),
84 | image=self.create_init_image(
85 | base_pipeline.specs.height, base_pipeline.specs.width
86 | ),
87 | prompt="A professional photograph of an astronaut riding a pig",
88 | negative_prompt="",
89 | samples=1,
90 | return_latents=True,
91 | )
92 | else:
93 | output = base_pipeline.text_to_image(
94 | params=SamplingParams(sampler=sampler_enum.value, steps=10),
95 | prompt="A professional photograph of an astronaut riding a pig",
96 | negative_prompt="",
97 | samples=1,
98 | return_latents=True,
99 | )
100 |
101 | assert isinstance(output, (tuple, list))
102 | samples, samples_z = output
103 | assert samples is not None
104 | assert samples_z is not None
105 | refiner_pipeline.refiner(
106 | params=SamplingParams(sampler=sampler_enum.value, steps=10),
107 | image=samples_z,
108 | prompt="A professional photograph of an astronaut riding a pig",
109 | negative_prompt="",
110 | samples=1,
111 | )
112 |
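The same API the tests exercise can be driven directly outside pytest; a sketch, assuming the checkpoints referenced by sgm.inference.api are available locally (the test suite itself is gated behind the `inference` marker).

    from sgm.inference.api import (
        ModelArchitecture,
        Sampler,
        SamplingParams,
        SamplingPipeline,
    )

    pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
    sampler = next(iter(Sampler))  # any member; the tests parametrize over all of them
    images = pipeline.text_to_image(
        params=SamplingParams(sampler=sampler.value, steps=30),
        prompt="A professional photograph of an astronaut riding a pig",
        negative_prompt="",
        samples=1,
    )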
--------------------------------------------------------------------------------
/tools/camera.py:
--------------------------------------------------------------------------------
1 | import math, torch
2 | from dataLoader.utils import build_rays, fov_to_ixt
3 |
4 | def getProjectionMatrix(znear, zfar, fovX, fovY):
5 |
6 | tanHalfFovY = math.tan((fovY / 2))
7 | tanHalfFovX = math.tan((fovX / 2))
8 |
9 | P = torch.zeros(4, 4)
10 |
11 | z_sign = 1.0
12 |
13 | P[0, 0] = 1 / tanHalfFovX
14 | P[1, 1] = 1 / tanHalfFovY
15 | P[3, 2] = z_sign
16 | P[2, 2] = z_sign * zfar / (zfar - znear)
17 | P[2, 3] = -(zfar * znear) / (zfar - znear)
18 | return P
19 |
20 |
21 | class MiniCam:
22 | def __init__(self, c2w, width, height, fovy, fovx, znear, zfar):
23 | # c2w (pose) should be in NeRF convention.
24 |
25 | self.image_width = width
26 | self.image_height = height
27 | self.FoVy = fovy
28 | self.FoVx = fovx
29 | self.znear = znear
30 | self.zfar = zfar
31 |
32 |
33 | w2c = torch.inverse(c2w)
34 |
35 | # rectify...
36 | # w2c[1:3, :3] *= -1
37 | # w2c[:3, 3] *= -1
38 |
39 | self.view_world_transform = c2w
40 | self.world_view_transform = w2c.transpose(0, 1)
41 | self.projection_matrix = getProjectionMatrix(
42 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
43 | ).transpose(0, 1)
44 |
45 | self.full_proj_transform = (self.world_view_transform @ self.projection_matrix).to(torch.float32)
46 | self.camera_center = -c2w[:3, 3]
47 |
48 | def to_device(self, device):
49 | self.world_view_transform = self.world_view_transform.to(device)
50 | self.projection_matrix = self.projection_matrix.to(device)
51 | self.camera_center = self.camera_center.to(device)
52 | self.full_proj_transform = self.full_proj_transform.to(device)
53 |
54 | def get_rays(self):
55 | ixt = fov_to_ixt(torch.tensor((self.FoVx,self.FoVy)), torch.tensor((self.image_width,self.image_height)))
56 | rays = build_rays(self.view_world_transform.cpu().numpy()[None], ixt[None], self.image_height, self.image_width)
57 | return torch.from_numpy(rays)
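A usage sketch (assumes the repository root is on PYTHONPATH so that tools and dataLoader are importable); the pose is a NeRF-convention camera-to-world matrix placed 2.5 units along +z, looking toward the origin, and all values are illustrative.

    import math
    import torch
    from tools.camera import MiniCam

    c2w = torch.eye(4)
    c2w[2, 3] = 2.5                      # camera at (0, 0, 2.5)
    cam = MiniCam(c2w, width=512, height=512,
                  fovy=math.radians(50), fovx=math.radians(50),
                  znear=0.1, zfar=100.0)
    rays = cam.get_rays()                # per-pixel ray origins/directions from build_rays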
--------------------------------------------------------------------------------
/tools/depth.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def abs_error(depth_pred, depth_gt, mask):
4 | depth_pred, depth_gt = depth_pred[mask], depth_gt[mask]
5 |
6 | err = depth_pred - depth_gt
7 | return np.abs(err) if type(depth_pred) is np.ndarray else err.abs()
8 |
9 | def acc_threshold(depth_pred, depth_gt, mask, threshold):
10 | """
11 | computes the percentage of pixels whose depth error is less than @threshold
12 | """
13 | errors = abs_error(depth_pred, depth_gt, mask)
14 | acc_mask = errors < threshold
15 | return acc_mask.astype('float') if type(depth_pred) is np.ndarray else acc_mask.float()
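A toy example of the two metrics (standalone sketch, assuming tools is importable from the repository root).

    import numpy as np
    from tools.depth import abs_error, acc_threshold

    depth_pred = np.array([[1.0, 2.0], [3.0, 4.0]])
    depth_gt   = np.array([[1.1, 2.0], [2.5, 4.0]])
    mask       = depth_gt > 0                             # evaluate only valid pixels

    print(abs_error(depth_pred, depth_gt, mask))                  # ~[0.1, 0.0, 0.5, 0.0]
    print(acc_threshold(depth_pred, depth_gt, mask, 0.2).mean())  # 0.75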
--------------------------------------------------------------------------------
/tools/download_dataset.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import hf_hub_download
2 | import os
3 | import shutil
4 | import argparse
5 | from concurrent.futures import ThreadPoolExecutor
6 |
7 | def download_folder(repo_id, folder, local_dir, files, repo_type="dataset"):# model, dataset, or space.
8 |
9 | def download_file(file):
10 | cache_file_path = hf_hub_download(
11 | repo_id=repo_id,
12 | filename=file,
13 | subfolder=folder,
14 | # repo_type=repo_type,
15 | cache_dir=f'{local_dir}/{folder}/_temp',
16 | )
17 |
18 | target_path = f'{local_dir}/{folder}/{file}'
19 | os.makedirs(os.path.dirname(target_path), exist_ok=True)
20 | os.system(f'mv {os.path.realpath(cache_file_path)} {target_path}')
21 |
22 | with ThreadPoolExecutor() as executor:
23 | futures = []
24 | for file in files:
25 | futures.append(executor.submit(download_file, file))
26 | for future in futures:
27 | future.result()
28 |
29 |
30 | # Example usage
31 | repo_id = "apchen/LaRa" # Replace with your repository ID
32 | folder_path = "dataset" # Replace with the path to the folder in the repository
33 | local_dir = "." # Replace with your local destination directory
34 |
35 | gso_list = ['GSO.zip']
36 | co3d_list = ['Co3D/co3d_hydrant.h5','Co3D/co3d_teddybear.h5']
37 | gobjaverse_list = [f'gobjaverse/gobjaverse_part_{i+1:02d}.h5' for i in range(32)] + ['gobjaverse/gobjaverse.h5']
38 |
39 | if __name__ == "__main__":
40 | parser = argparse.ArgumentParser(description="download files.")
41 | parser.add_argument("dtype", type=str, default="gso", help="one of [gso,co3d,objaverse,all]")
42 |
43 | args = parser.parse_args()
44 |
45 | if "gso" == args.dtype:
46 |         download_folder(repo_id, folder_path, local_dir, gso_list)
47 | os.system(f'unzip {local_dir}/{folder_path}/{gso_list[0]} -d {local_dir}/{folder_path}')
48 | os.system(f'rm {local_dir}/{folder_path}/{gso_list[0]}')
49 | elif "co3d" == args.dtype:
50 | download_folder(repo_id, folder_path, local_dir,co3d_list)
51 | elif "objaverse" == args.dtype:
52 | download_folder(repo_id, folder_path, local_dir, gobjaverse_list)
53 | elif "all" == args.dtype:
54 | download_folder(repo_id, folder_path, local_dir,gso_list+co3d_list+gobjaverse_list)
55 | os.system(f'unzip {local_dir}/{folder_path}/{gso_list[0]} -d {local_dir}/{folder_path}')
56 | os.system(f'rm {local_dir}/{folder_path}/{gso_list[0]}')
57 |
58 | # shutil.rmtree(f'{local_dir}/{folder_path}/_temp')
59 |
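Typical usage from the repository root is `python tools/download_dataset.py gso` (or `co3d`, `objaverse`, `all`); files are fetched from the `apchen/LaRa` Hugging Face repository and placed under `./dataset/`, with the GSO archive unzipped in place.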
--------------------------------------------------------------------------------
/tools/download_objaverse.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 |
3 | import os, sys, json
4 | from multiprocessing import Pool
5 |
6 | def download_url(item):
7 | global save_dir
8 | oss_full_dir = os.path.join("https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/objaverse_tar", item+".tar")
9 | os.system("wget -P {} {}".format(os.path.join(save_dir, item.split("/")[0]), oss_full_dir))
10 |
11 | def get_all_folders(root):
12 | all_folders = []
13 |     categories = os.listdir(root)
14 |     for item in categories:
15 | if not os.path.isdir(f'{root}/{item}'):
16 | continue
17 | folders = os.listdir(f'{root}/{item}')
18 | all_folders += [f'{root}/{item}/{folder}' for folder in folders]
19 | return all_folders
20 |
21 | def folder_to_json(exist_files):
22 | files = []
23 | for item in exist_files:
24 | split = item.split('/')[-2:]
25 | files.append(f'{split[0]}/{split[1][:-4]}')
26 | return files
27 |
28 | def filterout_existing(json, exist_files):
29 | for item in exist_files:
30 | json.remove(item)
31 | return json
32 |
33 | if __name__=="__main__":
34 | # download_gobjaverse_280k index file
35 | # wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/gobjaverse_280k.json
36 | assert len(sys.argv) == 4, "eg: python download_objaverse.py ./data /path/to/json_file 10"
37 | save_dir = str(sys.argv[1])
38 | json_file = str(sys.argv[2])
39 | n_threads = int(sys.argv[3])
40 |
41 |     data = json.load(open(json_file))[:100]  # NOTE: only the first 100 index entries are kept; drop the slice to download the full list
42 |
43 | exist_files = get_all_folders(save_dir)
44 | exist_files = folder_to_json(exist_files)
45 |
46 | print(len(data))
47 | data = filterout_existing(data, exist_files)
48 | print(len(data))
49 |
50 | p = Pool(n_threads)
51 | p.map(download_url, data)
52 |
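Per the assert message above, usage is `python tools/download_objaverse.py ./data /path/to/gobjaverse_280k.json 10` (save directory, index JSON, number of download processes), after fetching the index file with the `wget` command shown in the comment.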
--------------------------------------------------------------------------------
/tools/hdf5_split_merge.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import argparse
3 | import os
4 | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
5 |
6 | def split_hdf5_file(input_file, output_prefix, num_splits):
7 | with h5py.File(input_file, 'r') as f:
8 | keys = sorted(list(f.keys()))
9 | chunk_size = len(keys) // num_splits
10 |
11 | def write_chunk(i, keys_chunk):
12 | output_file = f"{output_prefix}_part_{i+1}.h5"
13 | with h5py.File(output_file, 'w') as out_f:
14 | for key in keys_chunk:
15 | f.copy(key, out_f)
16 |
17 | with ThreadPoolExecutor() as executor:
18 | futures = []
19 |         for i in range(num_splits):
20 |             keys_chunk = keys[i*chunk_size: (i+1)*chunk_size] if i < num_splits - 1 else keys[i*chunk_size:]  # last chunk keeps any remainder
21 | futures.append(executor.submit(write_chunk, i, keys_chunk))
22 | for future in futures:
23 | future.result()
24 |
25 | print(f"Split into {num_splits} files with prefix '{output_prefix}'.")
26 |
27 | def merge_hdf5_files(output_file, input_files):
28 | with h5py.File(output_file, 'w') as out_f:
29 | def copy_data(input_file):
30 | with h5py.File(input_file, 'r') as in_f:
31 | for key in in_f.keys():
32 | in_f.copy(key, out_f)
33 |
34 | with ThreadPoolExecutor() as executor:
35 | futures = [executor.submit(copy_data, input_file) for input_file in input_files]
36 | for future in futures:
37 | future.result()
38 |
39 | print(f"Merged files into '{output_file}'.")
40 |
41 | def get_absolute_paths(directory, prefix):
42 | return [os.path.join(directory, f) for f in os.listdir(directory) if f.startswith(prefix)]
43 |
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser(description="Split and merge HDF5 files.")
47 |
48 | subparsers = parser.add_subparsers(dest="command", required=True)
49 |
50 | split_parser = subparsers.add_parser("split", help="Split an HDF5 file into multiple files.")
51 | split_parser.add_argument("input_file", type=str, help="Input HDF5 file to split.")
52 | split_parser.add_argument("output_prefix", type=str, help="Output prefix for split files.")
53 | split_parser.add_argument("num_splits", type=int, help="Number of splits.")
54 |
55 | merge_parser = subparsers.add_parser("merge", help="Merge multiple HDF5 files into one file.")
56 | merge_parser.add_argument("output_file", type=str, help="Output HDF5 file to create.")
57 | merge_parser.add_argument("file_prefix", type=str, help="Input HDF5 files to merge.")
58 |
59 | args = parser.parse_args()
60 |
61 | if args.command == "split":
62 | split_hdf5_file(args.input_file, args.output_prefix, args.num_splits)
63 | elif args.command == "merge":
64 |         input_files = get_absolute_paths(os.path.dirname(args.file_prefix) or ".", os.path.basename(args.file_prefix))
65 | merge_hdf5_files(args.output_file, input_files)
66 |
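With the merge arguments as defined above, invocations look like `python tools/hdf5_split_merge.py split gobjaverse.h5 gobjaverse 32` (writing `gobjaverse_part_1.h5` through `gobjaverse_part_32.h5`) and `python tools/hdf5_split_merge.py merge merged.h5 dataset/gobjaverse/gobjaverse_part`, where `file_prefix` may include a directory component; all file names here are illustrative.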
--------------------------------------------------------------------------------
/tools/meshRender.py:
--------------------------------------------------------------------------------
1 | import os,sys
2 | import mitsuba as mi
3 | from tqdm import tqdm
4 | mi.set_variant('cuda_ad_rgb', 'llvm_ad_rgb')
5 | sys.path.append(os.path.join(os.path.dirname(__file__), "lib"))
6 |
7 | import numpy as np
8 |
9 | def render_mesh(cams, mesh_path, spp = 512, white_bg=True):
10 |
11 | image_width = cams[0].image_width
12 | image_height = cams[0].image_height
13 |
14 | mesh_type = os.path.splitext(mesh_path)[1][1:]
15 | sdf_scene = mi.load_file("configs/render/scene.xml", resx=image_width, resy=image_height, mesh_path=mesh_path, mesh_type=mesh_type,
16 | integrator_file="configs/render/integrator_path.xml", update_scene=False, spp=spp, max_depth=8)
17 |
18 | imgs = []
19 | pbar = tqdm(total=len(cams), desc='Files', position=0)
20 | b2c = np.array([[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)
21 | for cam in cams:
22 | c2w, fov = cam.view_world_transform.numpy(), cam.FoVx
23 | fov = np.degrees(fov)
24 | image_width = cam.image_width
25 | image_height = cam.image_height
26 |
27 | to_world = c2w @ b2c
28 | to_world_transform = mi.ScalarTransform4f(to_world.tolist())
29 |
30 | sensor = mi.load_dict({
31 | 'type': 'perspective',
32 | 'fov': fov, 'sampler': {'type': 'independent'},
33 | 'film': {'type': 'hdrfilm', 'width': image_width, 'height': image_height,
34 | 'pixel_filter': {'type': 'gaussian'}, 'pixel_format': 'rgba'},
35 | 'to_world': to_world_transform
36 | })
37 |
38 | img = mi.render(sdf_scene, sensor=sensor, spp=spp)
39 | img = mi.Bitmap(img).convert(mi.Bitmap.PixelFormat.RGBA, mi.Struct.Type.UInt8, srgb_gamma=True)
40 | # img.write(f'123.png')
41 | img = np.array(img, copy=False)
42 | if white_bg:
43 | img = img.astype(np.float32)/255
44 | img = img[...,:3]*img[...,3:] + (1.0-img[...,3:])*np.array([0.722,0.376,0.161])
45 | img = np.round(img*255).astype('uint8')
46 |
47 | imgs.append(img)
48 | pbar.update(1)
49 | pbar.set_description("Mesh extraction Done. Rendering *_mesh.mp4: ")
50 |
51 | return np.stack(imgs)
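A hypothetical glue snippet tying render_mesh to the MiniCam class from tools/camera.py; the mesh path and camera values are illustrative, and it assumes a working Mitsuba cuda_ad_rgb or llvm_ad_rgb variant, the repository root on PYTHONPATH, and execution from the repository root (the scene XML paths are relative).

    import math
    import torch
    from tools.camera import MiniCam
    from tools.meshRender import render_mesh

    c2w = torch.eye(4)
    c2w[2, 3] = 2.5
    cam = MiniCam(c2w, 512, 512, math.radians(50), math.radians(50), 0.1, 100.0)

    frames = render_mesh([cam], "outputs/mesh.ply", spp=128)  # (N, H, W, 3) uint8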
--------------------------------------------------------------------------------
/train_lightning.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | n_thread = 2
4 | os.environ["MKL_NUM_THREADS"] = f"{n_thread}"
5 | os.environ["NUMEXPR_NUM_THREADS"] = f"{n_thread}"
6 | os.environ["OMP_NUM_THREADS"] = "4"
7 | os.environ["VECLIB_MAXIMUM_THREADS"] = f"{n_thread}"
8 | os.environ["OPENBLAS_NUM_THREADS"] = f"{n_thread}"
9 |
10 |
11 | import torch
12 | from dataLoader import dataset_dict
13 | from omegaconf import OmegaConf
14 |
15 | from lightning.system import system
16 | from torch.utils.data import DataLoader
17 | import pytorch_lightning as L
18 |
19 | from datetime import datetime
20 |
21 |
22 | from pytorch_lightning.loggers import TensorBoardLogger
23 | from pytorch_lightning.loggers import WandbLogger
24 | from pytorch_lightning.callbacks import ModelCheckpoint
25 | from pytorch_lightning.strategies import DDPStrategy
26 |
27 | def main(cfg):
28 |
29 | torch.set_float32_matmul_precision('medium')
30 | torch.autograd.set_detect_anomaly(True)
31 | print("Using PyTorch {} and Lightning {}".format(torch.__version__, L.__version__))
32 |
33 | # data loader
34 | train_dataset = dataset_dict[cfg.train_dataset.dataset_name]
35 | train_loader = DataLoader(train_dataset(cfg.train_dataset),
36 | batch_size= cfg.train.batch_size,
37 | num_workers= 8,
38 | shuffle=True,
39 | pin_memory=False)
40 | val_dataset = dataset_dict[cfg.test_dataset.dataset_name]
41 | val_loader = DataLoader(val_dataset(cfg.test_dataset),
42 | batch_size=cfg.test.batch_size,
43 | num_workers=2,
44 | shuffle=True,
45 | pin_memory=False)
46 |
47 | # build logger
48 | project_name = cfg.exp_name.split("/")[0]
49 | exp_name = cfg.exp_name.split("/")[1]
50 |
51 | if cfg.logger.name == "tensorboard":
52 | logger = TensorBoardLogger(save_dir=cfg.logger.dir, name=exp_name)
53 | elif cfg.logger.name == "wandb":
54 | os.environ["WANDB__SERVICE_WAIT"] = "300"
55 | logger = WandbLogger(name=exp_name,project=project_name, save_dir=cfg.logger.dir, entity="large-reconstruction-model")
56 |
57 | # Set up ModelCheckpoint callback
58 | checkpoint_callback = ModelCheckpoint(
59 | dirpath=cfg.logger.dir, # Path where checkpoints will be saved
60 | filename='{epoch}', # Filename for the checkpoints
61 | # save_top_k=1, # Set to -1 to save all checkpoints
62 | every_n_epochs=5, # Save a checkpoint every K epochs
63 | save_on_train_epoch_end=True, # Ensure it saves at the end of an epoch, not the beginning
64 | )
65 |
66 | my_system = system(cfg)
67 |
68 | trainer = L.Trainer(devices=cfg.gpu_id,
69 | num_nodes=1,
70 | max_epochs=cfg.train.n_epoch,
71 | accelerator='gpu',
72 | strategy=DDPStrategy(find_unused_parameters=True),
73 | accumulate_grad_batches=2,
74 | logger=logger,
75 | gradient_clip_val=0.5,
76 | precision="bf16-mixed",
77 | callbacks=[checkpoint_callback],
78 | check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
79 | limit_val_batches=cfg.train.limit_val_batches, # Run on only 10% of the validation data
80 | limit_train_batches=cfg.train.limit_train_batches,
81 | )
82 |
83 |
84 | t0 = datetime.now()
85 | trainer.fit(
86 | my_system,
87 | train_dataloaders=train_loader,
88 | val_dataloaders=val_loader,
89 | ckpt_path=cfg.model.ckpt_path
90 | )
91 |
92 | dt = datetime.now() - t0
93 | print('Training took {}'.format(dt))
94 |
95 |
96 | if __name__ == '__main__':
97 |
98 | base_conf = OmegaConf.load('configs/base.yaml')
99 |
100 | cli_conf = OmegaConf.from_cli()
101 | cfg = OmegaConf.merge(base_conf, cli_conf)
102 |
103 | main(cfg)
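For reference, the override mechanism is plain OmegaConf: anything passed on the command line as `key=value` pairs is parsed by OmegaConf.from_cli() and merged over configs/base.yaml, with CLI values taking precedence. A sketch of the equivalent in Python (the override keys are used by this script; the values are illustrative):

    from omegaconf import OmegaConf

    base = OmegaConf.load("configs/base.yaml")
    cli = OmegaConf.from_dotlist(["train.batch_size=4", "exp_name=LaRa/my_run"])
    cfg = OmegaConf.merge(base, cli)  # later configs win, so CLI overrides base.yaml

    # equivalent shell invocation:
    #   python train_lightning.py train.batch_size=4 exp_name=LaRa/my_run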
--------------------------------------------------------------------------------