├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets └── demo.gif ├── configs ├── base.yaml ├── infer.yaml └── render │ ├── cathedral.hdr │ ├── cathedral.xml │ ├── common.xml │ ├── integrator_path.xml │ ├── scene.xml │ └── sensors.xml ├── dataLoader ├── __init__.py ├── gobjverse.py ├── google_scanned_objects.py ├── instant3d.py ├── mipnerf.py ├── mvgen.py └── utils.py ├── environment.yml ├── eval_all.py ├── evaluation.py ├── lightning ├── loss.py ├── network.py ├── renderer.py ├── renderer_2dgs.py ├── system.py ├── utils.py └── vis.py ├── third_party └── image_generator │ ├── .github │ └── workflows │ │ ├── black.yml │ │ ├── test-build.yaml │ │ └── test-inference.yml │ ├── .gitignore │ ├── CODEOWNERS │ ├── LICENSE-CODE │ ├── README.md │ ├── assets │ ├── 000.jpg │ ├── sv3d.gif │ └── tile.gif │ ├── configs │ ├── example_training │ │ ├── autoencoder │ │ │ └── kl-f4 │ │ │ │ ├── imagenet-attnfree-logvar.yaml │ │ │ │ └── imagenet-kl_f8_8chn.yaml │ │ ├── imagenet-f8_cond.yaml │ │ ├── toy │ │ │ ├── cifar10_cond.yaml │ │ │ ├── mnist.yaml │ │ │ ├── mnist_cond.yaml │ │ │ ├── mnist_cond_discrete_eps.yaml │ │ │ ├── mnist_cond_l1_loss.yaml │ │ │ └── mnist_cond_with_ema.yaml │ │ ├── txt2img-clipl-legacy-ucg-training.yaml │ │ └── txt2img-clipl.yaml │ ├── inference │ │ ├── sd_2_1.yaml │ │ ├── sd_2_1_768.yaml │ │ ├── sd_xl_base.yaml │ │ ├── sd_xl_refiner.yaml │ │ ├── sv3d_p.yaml │ │ ├── sv3d_u.yaml │ │ ├── svd.yaml │ │ └── svd_image_decoder.yaml │ └── sd_xl_base.yaml │ ├── data │ └── DejaVuSans.ttf │ ├── generator.py │ ├── main.py │ ├── model_licenses │ ├── LICENCE-SD-Turbo │ ├── LICENSE-SDXL-Turbo │ ├── LICENSE-SDXL0.9 │ ├── LICENSE-SDXL1.0 │ ├── LICENSE-SV3D │ └── LICENSE-SVD │ ├── pyproject.toml │ ├── pytest.ini │ ├── scripts │ ├── __init__.py │ ├── demo │ │ ├── __init__.py │ │ ├── detect.py │ │ ├── discretization.py │ │ ├── gradio_app.py │ │ ├── sampling.py │ │ ├── streamlit_helpers.py │ │ ├── sv3d_helpers.py │ │ ├── turbo.py │ │ └── video_sampling.py │ ├── sampling │ │ ├── configs │ │ │ ├── sv3d_p.yaml │ │ │ ├── sv3d_u.yaml │ │ │ ├── svd.yaml │ │ │ ├── svd_image_decoder.yaml │ │ │ ├── svd_xt.yaml │ │ │ ├── svd_xt_1_1.yaml │ │ │ └── svd_xt_image_decoder.yaml │ │ └── simple_video_sample.py │ ├── tests │ │ └── attention.py │ └── util │ │ ├── __init__.py │ │ └── detection │ │ ├── __init__.py │ │ ├── nsfw_and_watermark_dectection.py │ │ ├── p_head_v1.npz │ │ └── w_head_v1.npz │ ├── sgm │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── cifar10.py │ │ ├── dataset.py │ │ └── mnist.py │ ├── inference │ │ ├── api.py │ │ └── helpers.py │ ├── lr_scheduler.py │ ├── models │ │ ├── __init__.py │ │ ├── autoencoder.py │ │ └── diffusion.py │ ├── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── autoencoding │ │ │ ├── __init__.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── discriminator_loss.py │ │ │ │ └── lpips.py │ │ │ ├── lpips │ │ │ │ ├── __init__.py │ │ │ │ ├── loss │ │ │ │ │ ├── .gitignore │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── lpips.py │ │ │ │ ├── model │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── model.py │ │ │ │ ├── util.py │ │ │ │ └── vqperceptual.py │ │ │ ├── regularizers │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ └── quantize.py │ │ │ └── temporal_ae.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── denoiser.py │ │ │ ├── denoiser_scaling.py │ │ │ ├── denoiser_weighting.py │ │ │ ├── discretizer.py │ │ │ ├── guiders.py │ │ │ ├── loss.py │ │ │ ├── loss_weighting.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ 
├── sampling.py │ │ │ ├── sampling_utils.py │ │ │ ├── sigma_sampling.py │ │ │ ├── util.py │ │ │ ├── video_model.py │ │ │ └── wrappers.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ └── video_attention.py │ └── util.py │ └── tests │ └── inference │ └── test_inference.py ├── tools ├── camera.py ├── camera_utils.py ├── depth.py ├── download_dataset.py ├── download_objaverse.py ├── gen_video_path.py ├── hdf5_split_merge.py ├── img_utils.py ├── meshExtractor.py ├── meshRender.py ├── prepare_dataset_co3d.py ├── prepare_dataset_objaverse.py └── rsh.py └── train_lightning.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | .ipynb_checkpoints/ 4 | *.py[cod] 5 | *.so 6 | *.orig 7 | *.o 8 | *.json 9 | *.pth 10 | *.npy 11 | *.ipynb 12 | *.png 13 | logs/* 14 | outputs/* 15 | ckpts/* -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/diff-surfel-rasterization"] 2 | path = third_party/diff-surfel-rasterization 3 | url = git@github.com:hbb1/diff-surfel-rasterization.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Florent Bartoccioni and Eloi Zablocki and Andrei Bursuc and Patrick Perez and Matthieu Cord and Karteek Alahari 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LaRa: Efficient Large-Baseline Radiance Fields 2 | 3 | [Project page](https://apchenstu.github.io/LaRa/) | [Paper](https://arxiv.org/abs/2407.04699) | [Data](https://huggingface.co/apchen/LaRa/tree/main/dataset) | [Checkpoint](https://huggingface.co/apchen/LaRa/tree/main/ckpts) |
4 | 5 | ![Teaser image](assets/demo.gif) 6 | 7 | ## ⭐ New Features 8 | - 2024/04/05: Important update - 9 | our method now supports half-precision training, achieving over **100% faster** convergence and about a **1.5dB** PSNR gain with fewer iterations! 10 | 11 | | Model | PSNR ↑ | SSIM ↑ | Abs err (Geo) ↓ | Epochs | Time (days) | ckpt | 12 | | ------ | ------ | ------ | ------ | ------ | ------ | ------ | 13 | | Paper | 27.65 | 0.951 | 0.0654 | 50 | 3.5 | ------ | 14 | | bf16 | 29.15 | 0.956 | 0.0574 | 30 | 1.5 | [Download](https://huggingface.co/apchen/LaRa/tree/main/ckpts/) | 15 | 16 | Please download the pre-trained checkpoint from the provided link and place it in the `ckpts` folder. 17 | 18 | # Installation 19 | 20 | ``` 21 | git clone https://github.com/autonomousvision/LaRa.git --recursive 22 | conda env create --file environment.yml 23 | conda activate lara 24 | ``` 25 | 26 | 27 | # Dataset 28 | We use the processed [gobjaverse dataset](https://aigc3d.github.io/gobjaverse/) for training. A download script, `tools/download_dataset.py`, is provided to download the datasets automatically. 29 | 30 | ``` 31 | python tools/download_dataset.py all 32 | ``` 33 | Note: The GObjaverse dataset requires approximately 1.4 TB of storage. You can also download a subset of the dataset; please refer to the script for details. Please manually delete the `_temp` folder after the download completes. 34 | 35 | If you would like to process the data yourself, we provide preprocessing scripts for the GObjaverse and Co3D datasets; please check `tools/prepare_dataset_*`. 36 | You can also download our preprocessed data and put it in the `dataset` folder: 37 | * [gobjaverse](#gobjaverse) 38 | * [Google Scanned Objects](#GSO) 39 | * [Co3D](#Co3D) 40 | * Instant3D - Please contact the authors of Instant3D if you wish to obtain the data for comparison. 41 | # Training 42 | ``` 43 | python train_lightning.py 44 | ``` 45 | **note:** You can configure the GPU ids and other parameters in `configs/base.yaml`. 46 | 47 | # Evaluation 48 | Our method supports the reconstruction of radiance fields from **multi-view**, **text**, and **single-view** inputs. We provide a pre-trained checkpoint at [ckpt](https://huggingface.co/apchen/LaRa/resolve/main/ckpts/epoch%3D29.ckpt). 49 | 50 | ## multi-view to 3D 51 | To reproduce the table results, you can simply use: 52 | ``` 53 | python eval_all.py 54 | ``` 55 | **note:** 56 | - Please double-check that the paths inside the script are correct for your specific setup. 57 | - Please set the `video_frames` and `save_mesh` [flags](https://github.com/autonomousvision/LaRa/blob/main/eval_all.py#L11) if you would like to output meshes or videos during the evaluation (see the example override command below). 58 | 59 | ## text to 3D 60 | ``` 61 | python evaluation.py configs/infer.yaml \ 62 | infer.ckpt_path=ckpts/epoch=29.ckpt \ 63 | infer.save_folder=outputs/prompts/ \ 64 | infer.dataset.generator_type=xxx \ 65 | infer.dataset.prompts=["a car made out of sushi","a beautiful rainbow fish"] 66 | ``` 67 | **note:** This part is currently unavailable due to a permissions issue. I will look for an alternative text-to-multi-view generator later next week.
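For any of the evaluation modes in this section, `evaluation.py` simply loads `configs/infer.yaml` and applies the `key=value` overrides given on the command line; this is also how `eval_all.py` assembles its commands. As a rough sketch of enabling video and mesh export during the multi-view evaluation (the override keys come from `configs/infer.yaml` and `eval_all.py`, while the checkpoint name, dataset path, and output folder below are placeholders to adapt to your setup):

```
python evaluation.py configs/infer.yaml \
    infer.ckpt_path=ckpts/epoch=29.ckpt \
    infer.dataset.dataset_name=GSO \
    infer.dataset.data_root=dataset/google_scanned_objects \
    infer.video_frames=120 infer.save_mesh=True \
    infer.save_folder=outputs/image_vis/GSO_4_views
```

Any other key in `configs/infer.yaml`, such as `infer.dataset.n_group` or `infer.finetuning.with_ft`, can be overridden in the same way.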
68 | 69 | 70 | ## single view to 3D 71 | ``` 72 | python evaluation.py configs/infer.yaml \ 73 | infer.ckpt_path=ckpts/epoch=29.ckpt \ 74 | infer.save_folder=outputs/single-view/ \ 75 | infer.dataset.generator_type="zero123plus-v1" \ 76 | infer.dataset.image_pathes=\["assets/examples/13_realfusion_cherry_1.png"\] 77 | ``` 78 | **note:** The supported generator types are `zero123plus-v1.1` and `zero123plus-v1`. 79 | 80 | 81 | 82 | ## Acknowledgements 83 | Our renderer is built upon [2DGS](https://github.com/hbb1/2d-gaussian-splatting). The data preprocessing code for the Co3D dataset is partially borrowed from [Splatter-Image](https://github.com/szymanowiczs/splatter-image/blob/main/data_preprocessing/preprocess_co3d.py). Additionally, the script for generating multi-view images from text and single-view images is sourced from [GRM](https://github.com/justimyhxu/grm). We thank all the authors for their great repos. 84 | 85 | ## Citation 86 | If you find our code or paper helpful, please consider citing: 87 | ```bibtex 88 | @inproceedings{LaRa, 89 | author = {Anpei Chen and Haofei Xu and Stefano Esposito and Siyu Tang and Andreas Geiger}, 90 | title = {LaRa: Efficient Large-Baseline Radiance Fields}, 91 | booktitle = {European Conference on Computer Vision (ECCV)}, 92 | year = {2024} 93 | } 94 | ``` 95 | 96 | 97 | -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/assets/demo.gif -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | gpu_id: [4,5,6,7] 2 | 3 | exp_name: LaRa/release-test 4 | n_views: 4 5 | 6 | model: 7 | 8 | encoder_backbone: 'vit_base_patch16_224.dino' # ['vit_small_patch16_224.dino','vit_base_patch16_224.dino'] 9 | 10 | n_groups: [16] # n_groups for local attention 11 | n_offset_groups: 32 # offset radius of 1/n_offset_groups of the scene size 12 | 13 | K: 2 # primitives per voxel 14 | sh_degree: 1 # view-dependent color 15 | 16 | num_layers: 12 17 | num_heads: 16 18 | 19 | view_embed_dim: 32 20 | embedding_dim: 256 21 | 22 | vol_feat_reso: 16 23 | vol_embedding_reso: 32 24 | 25 | vol_embedding_out_dim: 80 26 | 27 | ckpt_path: null # specify a ckpt path if you want to continue training 28 | 29 | train_dataset: 30 | dataset_name: gobjeverse 31 | data_root: dataset/gobjaverse/gobjaverse.h5 32 | 33 | split: train 34 | img_size: [512,512] # image resolution 35 | n_group: ${n_views} # number of input views 36 | n_scenes: 3000000 37 | load_normal: True 38 | 39 | 40 | 41 | test_dataset: 42 | dataset_name: gobjeverse 43 | data_root: dataset/gobjaverse/gobjaverse.h5 44 | 45 | split: test 46 | img_size: [512,512] 47 | n_group: ${n_views} 48 | n_scenes: 3000000 49 | load_normal: True 50 | 51 | train: 52 | batch_size: 3 53 | lr: 4e-4 54 | beta1: 0.9 55 | beta2: 0.95 56 | weight_decay: 0.05 57 | # betas: [0.9, 0.95] 58 | warmup_iters: 1000 59 | n_epoch: 30 60 | limit_train_batches: 0.2 61 | limit_val_batches: 0.02 62 | check_val_every_n_epoch: 1 63 | start_fine: 5000 64 | use_rand_views: False 65 | test: 66 | batch_size: 3 67 | 68 | logger: 69 | name: tensorboard 70 | dir: logs/${exp_name} 71 | -------------------------------------------------------------------------------- /configs/infer.yaml:
-------------------------------------------------------------------------------- 1 | n_views: 4 2 | 3 | infer: 4 | dataset: 5 | # dataset_name: gobjeverse 6 | # data_root: dataset/gobjaverse_280k/gobjaverse_280k.hdf5 7 | # data_root: dataset/Co3D/co3d_teddybear.hdf5 8 | # data_root: dataset/Co3D/co3d_hydrant.hdf5 9 | 10 | # dataset_name: GSO 11 | # data_root: dataset/google_scanned_objects 12 | 13 | # dataset_name: instant3d 14 | # data_root: dataset/instant3D 15 | 16 | # text to 3D 17 | dataset_name: mvgen 18 | generator_type: instant3d 19 | prompts: ["a car made out of sushi"] 20 | image_pathes: [] 21 | 22 | ## single view to 3D 23 | # dataset_name: mvgen 24 | # generator_type: zero123plus-v1.1 # zero123plus-v1.1,zero123plus-v1.2,sv3d 25 | # prompts: [] 26 | # image_pathes: ['examples/19_dalle3_stump1.png'] 27 | 28 | # # unposed inputs 29 | # dataset_name: unposed 30 | # image_pathes: examples/unposed/*.png 31 | 32 | split: test 33 | img_size: [512,512] 34 | n_group: 4 35 | n_scenes: 30000 36 | num_workers: 0 37 | batch_size: 1 38 | 39 | load_normal: False 40 | 41 | ckpt_path: ckpts/lara.ckpt 42 | 43 | eval_novel_view_only: True 44 | eval_depth: [] 45 | metric_path: None 46 | 47 | save_folder: outputs/video_vis/mvgen 48 | video_frames: 120 49 | mesh_video_frames: 0 50 | 51 | save_mesh: True 52 | aabb: [-0.5,-0.5,-0.5,0.5,0.5,0.5] 53 | 54 | finetuning: 55 | with_ft: False 56 | steps: 500 57 | 58 | # lr 59 | position_lr: 0.000016 60 | feature_lr: 0.0025 61 | opacity_lr: 0.05 62 | scaling_lr: 0.005 63 | rotation_lr: 0.001 64 | 65 | 66 | -------------------------------------------------------------------------------- /configs/render/cathedral.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/configs/render/cathedral.hdr -------------------------------------------------------------------------------- /configs/render/cathedral.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /configs/render/common.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/render/integrator_path.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /configs/render/scene.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /configs/render/sensors.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 24 | 25 | 26 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 
| 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /dataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | from dataLoader.gobjverse import gobjverse 2 | from dataLoader.google_scanned_objects import GoogleObjsDataset 3 | from dataLoader.instant3d import Instant3DObjsDataset 4 | from dataLoader.mipnerf import MipNeRF360Dataset 5 | from dataLoader.mvgen import MVGenDataset 6 | 7 | dataset_dict = {'gobjeverse': gobjverse, 8 | 'GSO': GoogleObjsDataset, 9 | 'instant3d': Instant3DObjsDataset, 10 | 'mipnerf360': MipNeRF360Dataset, 11 | 'mvgen': MVGenDataset, 12 | } -------------------------------------------------------------------------------- /dataLoader/instant3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from glob import glob 4 | import imageio 5 | import tqdm 6 | from multiprocessing import Pool 7 | import copy 8 | import cv2 9 | import random 10 | from PIL import Image 11 | import torch 12 | import json 13 | from dataLoader.utils import build_rays 14 | from scipy.spatial.transform import Rotation as R 15 | from dataLoader.utils import intrinsic_to_fov, KMean 16 | 17 | class Instant3DObjsDataset(torch.utils.data.Dataset): 18 | def __init__(self, cfg): 19 | super(Instant3DObjsDataset, self).__init__() 20 | self.data_root = cfg.data_root 21 | 22 | self.img_size = np.array(cfg.img_size) 23 | 24 | scenes_name = np.array([f for f in sorted(os.listdir(self.data_root)) if f.endswith('png')]) 25 | self.scenes_name = scenes_name 26 | print(len(self.scenes_name)) 27 | 28 | self.build_camera() 29 | self.bg_color = 1.0 30 | 31 | def build_camera(self): 32 | scene_info = {'c2ws':[],'w2cs':[],'ixts':[]} 33 | json_info = json.load(open(os.path.join(self.data_root, f'opencv_cameras.json'))) 34 | 35 | for i in range(4): 36 | frame = json_info['frames'][i] 37 | w2c = np.array(frame['w2c']) 38 | c2w = np.linalg.inv(w2c) 39 | c2w[:3,3] /= 1.7 40 | w2c = np.linalg.inv(c2w) 41 | scene_info['c2ws'].append(c2w) 42 | scene_info['w2cs'].append(w2c) 43 | 44 | ixt = np.eye(3) 45 | ixt[[0,1],[0,1]] = np.array([frame['fx'],frame['fy']]) 46 | ixt[[0,1],[2,2]] = np.array([frame['cx'],frame['cy']]) 47 | scene_info['ixts'].append(ixt) 48 | 49 | scene_info['c2ws'] = np.stack(scene_info['c2ws']).astype(np.float32) 50 | scene_info['w2cs'] = np.stack(scene_info['w2cs']).astype(np.float32) 51 | scene_info['ixts'] = np.stack(scene_info['ixts']).astype(np.float32) 52 | 53 | self.scene_info = scene_info 54 | 55 | def __getitem__(self, index): 56 | 57 | 58 | scenes_name = self.scenes_name[index] 59 | # src_view_id = list(range(4)) 60 | # tar_views = src_view_id + list(range(4)) 61 | 62 | #np.random.rand(3) 63 | tar_img = self.read_image(scenes_name) 64 | tar_c2ws = self.scene_info['c2ws'] 65 | tar_w2cs = self.scene_info['w2cs'] 66 | tar_ixts = self.scene_info['ixts'] 67 | 68 | # align cameras using first view 69 | # no inver operation 70 | r = np.linalg.norm(tar_c2ws[0,:3,3]) 71 | ref_c2w = np.eye(4, dtype=np.float32).reshape(1,4,4) 72 | ref_w2c = np.eye(4, dtype=np.float32).reshape(1,4,4) 73 | 
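# place a canonical reference camera at (0, 0, -r) on the z-axis; the transform below maps the first input view onto it and re-expresses all poses in that canonical frame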
ref_c2w[:,2,3], ref_w2c[:,2,3] = -r, r 74 | transform_mats = ref_c2w @ tar_w2cs[:1] 75 | tar_w2cs = tar_w2cs.copy() @ tar_c2ws[:1] @ ref_w2c 76 | tar_c2ws = transform_mats @ tar_c2ws.copy() 77 | 78 | fov_x, fov_y = intrinsic_to_fov(tar_ixts[0],w=512,h=512) 79 | 80 | ret = {'fovx':fov_x, 81 | 'fovy':fov_y, 82 | } 83 | H, W = self.img_size 84 | 85 | ret.update({'tar_c2w': tar_c2ws, 86 | 'tar_w2c': tar_w2cs, 87 | 'tar_ixt': tar_ixts, 88 | 'tar_rgb': tar_img.transpose(1,0,2,3).reshape(H,4*W,3), 89 | 'transform_mats': transform_mats 90 | }) 91 | near_far = np.array([r-1.0, r+1.0]).astype(np.float32) 92 | ret.update({'near_far': np.array(near_far).astype(np.float32)}) 93 | ret.update({'meta': {'scene':scenes_name,f'tar_h': int(H), f'tar_w': int(W)}}) 94 | 95 | rays = build_rays(tar_c2ws, tar_ixts.copy(), H, W, 1.0) 96 | ret.update({f'tar_rays': rays}) 97 | rays_down = build_rays(tar_c2ws, tar_ixts.copy(), H, W, 1.0/16) 98 | ret.update({f'tar_rays_down': rays_down}) 99 | return ret 100 | 101 | 102 | def read_image(self, scenes_name): 103 | 104 | img = imageio.imread(f'{self.data_root}/{scenes_name}') 105 | img = img.astype(np.float32) / 255. 106 | if img.shape[-1] == 4: 107 | img = (img[..., :3] * img[..., -1:] + self.bg_color*(1 - img[..., -1:])).astype(np.float32) 108 | 109 | # split images 110 | row_chunks = np.array_split(img, 2) 111 | imgs = np.stack([np.array_split(chunk, 2, axis=1) for chunk in row_chunks]).reshape(4,512,512,-1) 112 | return imgs 113 | 114 | 115 | def __len__(self): 116 | return len(self.scenes_name) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: lara 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - pip=24.0 8 | - python=3.9.19 9 | - zlib=1.2.13 10 | - pytorch=2.1.2 11 | - torchaudio=2.1.2 12 | - torchvision=0.16.2 13 | - pip: 14 | - open3d==0.18.0 15 | - diffusers==0.27.2 16 | - einops==0.7.0 17 | - fire==0.6.0 18 | - h5py==3.10.0 19 | - imageio==2.34.0 20 | - kornia==0.7.2 21 | - pytorch-lightning==2.1.3 22 | - lpips==0.1.4 23 | - matplotlib==3.8.4 24 | - numpy==1.26.3 25 | - omegaconf==2.3.0 26 | - opencv-python==4.9.0.80 27 | - pillow==10.2.0 28 | - pytorch-msssim==1.0.0 29 | - pyyaml==6.0.1 30 | - rembg==2.0.56 31 | - scikit-image==0.21.0 32 | - scikit-learn==1.4.0 33 | - scipy==1.13.0 34 | - tensorboardx==2.6.2.2 35 | - timm==0.9.12 36 | - torchmetrics==1.4.0 37 | - torchtyping==0.1.4 38 | - tqdm==4.66.2 39 | - transformers==4.25.1 40 | - trimesh==4.3.2 41 | - wandb==0.16.2 42 | - xformers==0.0.23.post1 43 | - tensorboard==2.16.2 44 | - open_clip_torch==2.24.0 45 | - streamlit==1.35.0 46 | - invisible-watermark==0.2.0 47 | - git+https://github.com/openai/CLIP.git 48 | - third_party/diff-surfel-rasterization -------------------------------------------------------------------------------- /eval_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | gpu_id = 0 4 | name = 'release' 5 | ckpt_path = f'ckpts/epoch=29.ckpt' 6 | 7 | for n_views in [4]: 8 | cmd = f'CUDA_VISIBLE_DEVICES={gpu_id} python evaluation.py configs/infer.yaml n_views={n_views} infer.eval_novel_view_only=True ' \ 9 | f'infer.ckpt_path={ckpt_path} infer.metric_path=outputs/metrics/{name}_GSO_{n_views}_views.json ' \ 10 | f'infer.dataset.dataset_name=GSO infer.dataset.data_root=dataset/google_scanned_objects infer.eval_depth=[0.005,0.01,0.02] ' \ 11 | f'infer.video_frames=0 
infer.save_mesh=False ' \ 12 | f'infer.save_folder=outputs/image_vis/{name}_GSO_{n_views}_views infer.dataset.n_group={n_views} ' 13 | os.system(cmd) 14 | 15 | cmd = f'CUDA_VISIBLE_DEVICES={gpu_id} python evaluation.py configs/infer.yaml n_views={n_views} infer.eval_novel_view_only=True ' \ 16 | f'infer.ckpt_path={ckpt_path} infer.metric_path=outputs/metrics/{name}_gobjeverse_{n_views}_views.json ' \ 17 | f'infer.dataset.dataset_name=gobjeverse infer.dataset.data_root=dataset/gobjaverse/gobjaverse.h5 ' \ 18 | f'infer.video_frames=0 infer.save_mesh=False ' \ 19 | f'infer.save_folder=outputs/image_vis/{name}_gobjaverse_{n_views}_views infer.dataset.n_group={n_views} ' 20 | os.system(cmd) 21 | 22 | cmd = f'CUDA_VISIBLE_DEVICES={gpu_id} python evaluation.py configs/infer.yaml n_views={n_views} infer.eval_novel_view_only=True ' \ 23 | f'infer.ckpt_path={ckpt_path} infer.metric_path=outputs/metrics/{name}_co3d_teddybear_{n_views}_views.json ' \ 24 | f'infer.dataset.dataset_name=gobjeverse infer.dataset.data_root=dataset/Co3D/co3d_teddybear.h5 ' \ 25 | f'infer.video_frames=0 infer.save_mesh=False ' \ 26 | f'infer.save_folder=outputs/image_vis/{name}_co3d_teddybear infer.dataset.n_group={n_views} ' 27 | os.system(cmd) 28 | 29 | cmd = f'CUDA_VISIBLE_DEVICES={gpu_id} python evaluation.py configs/infer.yaml n_views={n_views} infer.eval_novel_view_only=True ' \ 30 | f'infer.ckpt_path={ckpt_path} infer.metric_path=outputs/metrics/{name}_co3d_hydrant_{n_views}_views.json ' \ 31 | f'infer.dataset.dataset_name=gobjeverse infer.dataset.data_root=dataset/Co3D/co3d_hydrant.h5 ' \ 32 | f'infer.video_frames=0 infer.save_mesh=False ' \ 33 | f'infer.save_folder=outputs/image_vis/{name}_co3d_hydrant infer.dataset.n_group={n_views} ' 34 | os.system(cmd) 35 | 36 | 37 | -------------------------------------------------------------------------------- /lightning/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_msssim import MS_SSIM 4 | from torch.nn import functional as F 5 | 6 | from torch.cuda.amp import autocast 7 | 8 | class Losses(nn.Module): 9 | def __init__(self): 10 | super(Losses, self).__init__() 11 | 12 | self.color_crit = nn.MSELoss(reduction='mean') 13 | self.mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 14 | 15 | self.ssim = MS_SSIM(data_range=1.0, size_average=True, channel=3) 16 | 17 | def forward(self, batch, output, iter): 18 | 19 | scalar_stats = {} 20 | loss = 0 21 | 22 | B,V,H,W = batch['tar_rgb'].shape[:-1] 23 | 24 | tar_rgb = batch['tar_rgb'].permute(0,2,1,3,4).reshape(B,H,V*W,3) 25 | 26 | 27 | if 'image' in output: 28 | 29 | for prex in ['','_fine']: 30 | 31 | 32 | if prex=='_fine' and f'acc_map{prex}' not in output: 33 | continue 34 | 35 | color_loss_all = (output[f'image{prex}']-tar_rgb)**2 36 | loss += color_loss_all.mean() 37 | 38 | psnr = -10. 
* torch.log(color_loss_all.detach().mean()) / \ 39 | torch.log(torch.Tensor([10.]).to(color_loss_all.device)) 40 | scalar_stats.update({f'mse{prex}': color_loss_all.mean().detach()}) 41 | scalar_stats.update({f'psnr{prex}': psnr}) 42 | 43 | 44 | with autocast(enabled=False): 45 | ssim_val = self.ssim(output[f'image{prex}'].permute(0,3,1,2), tar_rgb.permute(0,3,1,2)) 46 | scalar_stats.update({f'ssim{prex}': ssim_val.detach()}) 47 | loss += 0.5 * (1-ssim_val) 48 | 49 | if f'rend_dist{prex}' in output and iter>1000 and prex!='_fine': 50 | distortion = output[f"rend_dist{prex}"].mean() 51 | scalar_stats.update({f'distortion{prex}': distortion.detach()}) 52 | loss += distortion*1000 53 | 54 | rend_normal = output[f'rend_normal{prex}'] 55 | depth_normal = output[f'depth_normal{prex}'] 56 | acc_map = output[f'acc_map{prex}'].detach() 57 | 58 | normal_error = ((1 - (rend_normal * depth_normal).sum(dim=-1))*acc_map).mean() 59 | scalar_stats.update({f'normal{prex}': normal_error.detach()}) 60 | loss += normal_error*0.2 61 | 62 | return loss, scalar_stats 63 | 64 | -------------------------------------------------------------------------------- /lightning/system.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from lightning.loss import Losses 5 | import pytorch_lightning as L 6 | 7 | import torch.nn as nn 8 | from lightning.vis import vis_images 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | from lightning.utils import CosineWarmupScheduler 11 | 12 | from lightning.network import Network 13 | 14 | class system(L.LightningModule): 15 | def __init__(self, cfg): 16 | super().__init__() 17 | 18 | self.cfg = cfg 19 | self.loss = Losses() 20 | self.net = Network(cfg) 21 | 22 | self.validation_step_outputs = [] 23 | 24 | def training_step(self, batch, batch_idx): 25 | 26 | output = self.net(batch, with_fine=self.global_step>self.cfg.train.start_fine) 27 | loss, scalar_stats = self.loss(batch, output, self.global_step) 28 | for key, value in scalar_stats.items(): 29 | prog_bar = True if key in ['psnr','mask','depth'] else False 30 | self.log(f'train/{key}', value, prog_bar=prog_bar) 31 | self.log('lr',self.trainer.optimizers[0].param_groups[0]['lr']) 32 | 33 | if 0 == self.trainer.global_step % 3000 and (self.trainer.local_rank == 0): 34 | self.vis_results(output, batch, prex='train') 35 | 36 | return loss 37 | 38 | def validation_step(self, batch, batch_idx): 39 | self.net.eval() 40 | output = self.net(batch, with_fine=self.global_step>self.cfg.train.start_fine) 41 | loss, scalar_stats = self.loss(batch, output, self.global_step) 42 | if batch_idx == 0 and (self.trainer.local_rank == 0): 43 | self.vis_results(output, batch, prex='val') 44 | self.validation_step_outputs.append(scalar_stats) 45 | return loss 46 | 47 | def on_validation_epoch_end(self): 48 | keys = self.validation_step_outputs[0] 49 | for key in keys: 50 | prog_bar = True if key in ['psnr','mask','depth'] else False 51 | metric_mean = torch.stack([x[key] for x in self.validation_step_outputs]).mean() 52 | self.log(f'val/{key}', metric_mean, prog_bar=prog_bar, sync_dist=True) 53 | 54 | self.validation_step_outputs.clear() # free memory 55 | torch.cuda.empty_cache() 56 | 57 | def vis_results(self, output, batch, prex): 58 | output_vis = vis_images(output, batch) 59 | for key, value in output_vis.items(): 60 | if isinstance(self.logger, TensorBoardLogger): 61 | B,h,w = value.shape[:3] 62 | value = value.reshape(1,B*h,w,3).transpose(0,3,1,2) 63 | 
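# `value` is now a single (1, 3, B*h, w) grid (the batch stacked along the height axis), i.e. the NCHW layout expected by add_images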
self.logger.experiment.add_images(f'{prex}/{key}', value, self.global_step) 64 | else: 65 | imgs = [np.concatenate([img for img in value],axis=0)] 66 | self.logger.log_image(f'{prex}/{key}', imgs, step=self.global_step) 67 | self.net.train() 68 | 69 | def num_steps(self) -> int: 70 | """Get number of steps""" 71 | # Accessing _data_source is flaky and might break 72 | dataset = self.trainer.fit_loop._data_source.dataloader() 73 | dataset_size = len(dataset) 74 | num_devices = max(1, self.trainer.num_devices) 75 | num_steps = dataset_size * self.trainer.max_epochs * self.cfg.train.limit_train_batches // (self.trainer.accumulate_grad_batches * num_devices) 76 | return int(num_steps) 77 | 78 | def configure_optimizers(self): 79 | decay_params, no_decay_params = [], [] 80 | 81 | # add all bias and LayerNorm params to no_decay_params 82 | for name, module in self.named_modules(): 83 | if isinstance(module, nn.LayerNorm): 84 | no_decay_params.extend([p for p in module.parameters()]) 85 | elif hasattr(module, 'bias') and module.bias is not None: 86 | no_decay_params.append(module.bias) 87 | 88 | # add remaining parameters to decay_params 89 | _no_decay_ids = set(map(id, no_decay_params)) 90 | decay_params = [p for p in self.parameters() if id(p) not in _no_decay_ids] 91 | 92 | # filter out parameters with no grad 93 | decay_params = list(filter(lambda p: p.requires_grad, decay_params)) 94 | no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params)) 95 | 96 | # Optimizer 97 | opt_groups = [ 98 | {'params': decay_params, 'weight_decay': self.cfg.train.weight_decay}, 99 | {'params': no_decay_params, 'weight_decay': 0.0}, 100 | ] 101 | optimizer = torch.optim.AdamW( 102 | opt_groups, 103 | lr=self.cfg.train.lr, 104 | betas=(self.cfg.train.beta1, self.cfg.train.beta2), 105 | ) 106 | 107 | total_global_batches = self.num_steps() 108 | scheduler = CosineWarmupScheduler( 109 | optimizer=optimizer, 110 | warmup_iters=self.cfg.train.warmup_iters, 111 | max_iters=total_global_batches, 112 | ) 113 | 114 | return {"optimizer": optimizer, 115 | "lr_scheduler": { 116 | 'scheduler': scheduler, 117 | 'interval': 'step' # or 'epoch' for epoch-level updates 118 | }} -------------------------------------------------------------------------------- /lightning/utils.py: -------------------------------------------------------------------------------- 1 | import torch, os, json, math 2 | import numpy as np 3 | from torch.optim.lr_scheduler import LRScheduler 4 | 5 | def getProjectionMatrix(znear, zfar, fovX, fovY): 6 | 7 | tanHalfFovY = torch.tan((fovY / 2)) 8 | tanHalfFovX = torch.tan((fovX / 2)) 9 | 10 | P = torch.zeros(4, 4) 11 | 12 | z_sign = 1.0 13 | 14 | P[0, 0] = 1 / tanHalfFovX 15 | P[1, 1] = 1 / tanHalfFovY 16 | P[3, 2] = z_sign 17 | P[2, 2] = z_sign * zfar / (zfar - znear) 18 | P[2, 3] = -(zfar * znear) / (zfar - znear) 19 | return P 20 | 21 | 22 | class MiniCam: 23 | def __init__(self, c2w, width, height, fovy, fovx, znear, zfar, device): 24 | # c2w (pose) should be in NeRF convention. 25 | 26 | self.image_width = width 27 | self.image_height = height 28 | self.FoVy = fovy 29 | self.FoVx = fovx 30 | self.znear = znear 31 | self.zfar = zfar 32 | 33 | w2c = torch.inverse(c2w) 34 | 35 | # rectify... 
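# note: the commented-out sign flips below are a camera-convention adjustment (e.g. between OpenGL/NeRF-style and OpenCV-style poses); they are left disabled here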
36 | # w2c[1:3, :3] *= -1 37 | # w2c[:3, 3] *= -1 38 | 39 | self.world_view_transform = w2c.transpose(0, 1).to(device) 40 | self.projection_matrix = ( 41 | getProjectionMatrix( 42 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy 43 | ) 44 | .transpose(0, 1) 45 | .to(device) 46 | ) 47 | self.full_proj_transform = (self.world_view_transform @ self.projection_matrix).float() 48 | self.camera_center = -c2w[:3, 3].to(device) 49 | 50 | 51 | def rotation_matrix_to_quaternion(R): 52 | tr = R[0, 0] + R[1, 1] + R[2, 2] 53 | if tr > 0: 54 | S = torch.sqrt(tr + 1.0) * 2.0 55 | qw = 0.25 * S 56 | qx = (R[2, 1] - R[1, 2]) / S 57 | qy = (R[0, 2] - R[2, 0]) / S 58 | qz = (R[1, 0] - R[0, 1]) / S 59 | elif (R[0, 0] > R[1, 1]) and (R[0, 0] > R[2, 2]): 60 | S = torch.sqrt(1.0 + R[0, 0] - R[1, 1] - R[2, 2]) * 2.0 61 | qw = (R[2, 1] - R[1, 2]) / S 62 | qx = 0.25 * S 63 | qy = (R[0, 1] + R[1, 0]) / S 64 | qz = (R[0, 2] + R[2, 0]) / S 65 | elif R[1, 1] > R[2, 2]: 66 | S = torch.sqrt(1.0 + R[1, 1] - R[0, 0] - R[2, 2]) * 2.0 67 | qw = (R[0, 2] - R[2, 0]) / S 68 | qx = (R[0, 1] + R[1, 0]) / S 69 | qy = 0.25 * S 70 | qz = (R[1, 2] + R[2, 1]) / S 71 | else: 72 | S = torch.sqrt(1.0 + R[2, 2] - R[0, 0] - R[1, 1]) * 2.0 73 | qw = (R[1, 0] - R[0, 1]) / S 74 | qx = (R[0, 2] + R[2, 0]) / S 75 | qy = (R[1, 2] + R[2, 1]) / S 76 | qz = 0.25 * S 77 | return torch.stack([qw, qx, qy, qz], dim=1) 78 | 79 | def rotate_quaternions(q, R): 80 | # Convert quaternions to rotation matrices 81 | q = torch.cat([q[:, :1], -q[:, 1:]], dim=1) 82 | q = torch.cat([q[:, :3], q[:, 3:] * -1], dim=1) 83 | rotated_R = torch.matmul(torch.matmul(q, R), q.inverse()) 84 | 85 | # Convert the rotated rotation matrices back to quaternions 86 | return rotation_matrix_to_quaternion(rotated_R) 87 | 88 | # this function is borrowed from OpenLRM 89 | class CosineWarmupScheduler(LRScheduler): 90 | def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1): 91 | self.warmup_iters = warmup_iters 92 | self.max_iters = max_iters 93 | self.initial_lr = initial_lr 94 | super().__init__(optimizer, last_iter) 95 | 96 | def get_lr(self): 97 | 98 | if self._step_count <= self.warmup_iters: 99 | return [ 100 | self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters 101 | for base_lr in self.base_lrs] 102 | else: 103 | cos_iter = self._step_count - self.warmup_iters 104 | cos_max_iter = self.max_iters - self.warmup_iters 105 | cos_theta = cos_iter / cos_max_iter * math.pi 106 | cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs] 107 | return cos_lr -------------------------------------------------------------------------------- /lightning/vis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tools.img_utils import visualize_depth_numpy 4 | 5 | 6 | 7 | def vis_appearance_depth(output, batch): 8 | outputs = {} 9 | B, V, H, W = batch['tar_rgb'].shape[:-1] 10 | 11 | pred_rgb = output[f'image'].detach().cpu().numpy() 12 | pred_depth = output[f'depth'].detach().cpu().numpy() 13 | gt_rgb = batch[f'tar_rgb'].permute(0,2,1,3,4).reshape(B, H, V*W, 3).detach().cpu().numpy() 14 | 15 | near_far = batch['near_far'][0].tolist() 16 | pred_depth_colorlized = np.stack([visualize_depth_numpy(_depth, near_far) for _depth in pred_depth]).astype('float32')/255 17 | outputs.update({f"gt_rgb":gt_rgb, f"pred_rgb":pred_rgb, f"pred_depth":pred_depth_colorlized}) 18 | 19 | 20 | if 
'rend_normal' in output: 21 | rend_normal = torch.nn.functional.normalize(output[f'rend_normal'].detach(),dim=-1) 22 | rend_normal = rend_normal.cpu().numpy() 23 | outputs.update({f"rend_normal":(rend_normal+1)/2}) 24 | 25 | depth_normal = output[f'depth_normal'].detach().cpu().numpy() 26 | outputs.update({f"depth_normal":(depth_normal+1)/2}) 27 | 28 | if 'tar_nrm' in batch: 29 | normal_gt = batch['tar_nrm'].cpu().numpy() 30 | outputs.update({f"normal_gt":(normal_gt+1)/2}) 31 | 32 | 33 | if 'img_tri' in output: 34 | img_tri = output['img_tri'].detach().cpu().permute(0,2,3,1).numpy() 35 | outputs.update({f"img_tri": img_tri}) 36 | if 'feats_tri' in output: 37 | feats_tri = output['feats_tri'].detach().cpu().permute(0,2,3,1).numpy() 38 | outputs.update({f"feats_tri": feats_tri}) 39 | 40 | if 'image_fine' in output: 41 | rgb_fine = output[f'image_fine'].detach().cpu().numpy() 42 | outputs.update({f"rgb_fine":rgb_fine}) 43 | 44 | pred_depth_fine = output[f'depth_fine'].detach().cpu().numpy() 45 | pred_depth_fine_colorlized = np.stack([visualize_depth_numpy(_depth, near_far) for _depth in pred_depth_fine]).astype('float32')/255 46 | outputs.update({f"pred_depth_fine":pred_depth_fine_colorlized}) 47 | 48 | if 'rend_normal_fine' in output: 49 | rend_normal_fine = torch.nn.functional.normalize(output[f'rend_normal_fine'].detach(),dim=-1) 50 | rend_normal_fine = rend_normal_fine.cpu().numpy() 51 | outputs.update({f"rend_normal_fine":(rend_normal_fine+1)/2}) 52 | 53 | if 'depth_normal_fine' in output: 54 | depth_normal_fine = output[f'depth_normal_fine'].detach().cpu().numpy() 55 | outputs.update({f"depth_normal_fine":(depth_normal_fine+1)/2}) 56 | 57 | return outputs 58 | 59 | def vis_depth(output, batch): 60 | 61 | outputs = {} 62 | B, S, _, H, W = batch['src_inps'].shape 63 | h, w = batch['src_deps'].shape[-2:] 64 | 65 | near_far = batch['near_far'][0].tolist() 66 | gt_src_depth = batch['src_deps'].reshape(B,-1, h, w).cpu().permute(0,2,1,3).numpy().reshape(B,h,-1) 67 | mask = gt_src_depth > 0 68 | pred_src_depth = output['pred_src_depth'].reshape(B,-1, h, w).detach().cpu().permute(0,2,1,3).numpy().reshape(B,h,-1) 69 | pred_src_depth[~mask] = 0.0 70 | depth_err = np.abs(gt_src_depth-pred_src_depth)*2 71 | gt_src_depth_colorlized = np.stack([visualize_depth_numpy(_depth, near_far) for _depth in gt_src_depth]).astype('float32')/255 72 | pred_src_depth_colorlized = np.stack([visualize_depth_numpy(_depth, near_far) for _depth in pred_src_depth]).astype('float32')/255 73 | depth_err_colorlized = np.stack([visualize_depth_numpy(_err, near_far) for _err in depth_err]).astype('float32')/255 74 | rgb_source = batch['src_inps'].reshape(B,S, 3, H, W).detach().cpu().permute(0,3,1,4,2).numpy().reshape(B,H,-1,3) 75 | 76 | outputs.update({f"rgb_source": rgb_source, "gt_src_depth": gt_src_depth_colorlized, 77 | "pred_src_depth":pred_src_depth_colorlized, "depth_err":depth_err_colorlized}) 78 | 79 | return outputs 80 | 81 | def vis_images(output, batch): 82 | if 'image' in output: 83 | return vis_appearance_depth(output, batch) 84 | else: 85 | return vis_depth(output, batch) 86 | -------------------------------------------------------------------------------- /third_party/image_generator/.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Run black 2 | on: [pull_request] 3 | 4 | jobs: 5 | lint: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v3 9 | - name: Install venv 10 | run: | 11 | sudo apt-get -y install python3.10-venv 12 
| - uses: psf/black@stable 13 | with: 14 | options: "--check --verbose -l88" 15 | src: "./sgm ./scripts ./main.py" 16 | -------------------------------------------------------------------------------- /third_party/image_generator/.github/workflows/test-build.yaml: -------------------------------------------------------------------------------- 1 | name: Build package 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | 8 | jobs: 9 | build: 10 | name: Build 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.8", "3.10"] 16 | requirements-file: ["pt2", "pt13"] 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -r requirements/${{ matrix.requirements-file }}.txt 27 | pip install . -------------------------------------------------------------------------------- /third_party/image_generator/.github/workflows/test-inference.yml: -------------------------------------------------------------------------------- 1 | name: Test inference 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | test: 11 | name: "Test inference" 12 | # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment 13 | if: github.repository == 'stability-ai/generative-models' 14 | runs-on: [self-hosted, slurm, g40] 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: "Symlink checkpoints" 18 | run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints 19 | - name: "Setup python" 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.10" 23 | - name: "Install Hatch" 24 | run: pip install hatch 25 | - name: "Run inference tests" 26 | run: hatch run ci:test-inference --junit-xml test-results.xml 27 | - name: Surface failing tests 28 | if: always() 29 | uses: pmeier/pytest-results-action@main 30 | with: 31 | path: test-results.xml 32 | summary: true 33 | display-options: fEX 34 | fail-on-empty: true 35 | -------------------------------------------------------------------------------- /third_party/image_generator/.gitignore: -------------------------------------------------------------------------------- 1 | # extensions 2 | *.egg-info 3 | *.py[cod] 4 | 5 | # envs 6 | .pt13 7 | .pt2 8 | 9 | # directories 10 | /checkpoints 11 | /dist 12 | /outputs 13 | /build 14 | /src -------------------------------------------------------------------------------- /third_party/image_generator/CODEOWNERS: -------------------------------------------------------------------------------- 1 | .github @Stability-AI/infrastructure -------------------------------------------------------------------------------- /third_party/image_generator/LICENSE-CODE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the 
following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /third_party/image_generator/assets/000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/assets/000.jpg -------------------------------------------------------------------------------- /third_party/image_generator/assets/sv3d.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/assets/sv3d.gif -------------------------------------------------------------------------------- /third_party/image_generator/assets/tile.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/assets/tile.gif -------------------------------------------------------------------------------- /third_party/image_generator/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: sgm.models.autoencoder.AutoencodingEngine 4 | params: 5 | input_key: jpg 6 | monitor: val/rec_loss 7 | 8 | loss_config: 9 | target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator 10 | params: 11 | perceptual_weight: 0.25 12 | disc_start: 20001 13 | disc_weight: 0.5 14 | learn_logvar: True 15 | 16 | regularization_weights: 17 | kl_loss: 1.0 18 | 19 | regularizer_config: 20 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 21 | 22 | encoder_config: 23 | target: sgm.modules.diffusionmodules.model.Encoder 24 | params: 25 | attn_type: none 26 | double_z: True 27 | z_channels: 4 28 | resolution: 256 29 | in_channels: 3 30 | out_ch: 3 31 | ch: 128 32 | ch_mult: [1, 2, 4] 33 | num_res_blocks: 4 34 | attn_resolutions: [] 35 | dropout: 0.0 36 | 37 | decoder_config: 38 | target: sgm.modules.diffusionmodules.model.Decoder 39 | params: ${model.params.encoder_config.params} 40 | 41 | data: 42 | target: sgm.data.dataset.StableDataModuleFromConfig 43 | params: 44 | train: 45 | datapipeline: 46 | urls: 47 | - DATA-PATH 48 | pipeline_config: 49 | shardshuffle: 10000 50 | sample_shuffle: 10000 51 | 52 | decoders: 53 | - pil 54 | 55 | postprocessors: 56 | - target: sdata.mappers.TorchVisionImageTransforms 57 | params: 58 | key: jpg 59 | transforms: 60 | - target: torchvision.transforms.Resize 61 | params: 62 | size: 256 63 | interpolation: 3 64 | - target: torchvision.transforms.ToTensor 65 | - target: sdata.mappers.Rescaler 66 | - target: 
sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare 67 | params: 68 | h_key: height 69 | w_key: width 70 | 71 | loader: 72 | batch_size: 8 73 | num_workers: 4 74 | 75 | 76 | lightning: 77 | strategy: 78 | target: pytorch_lightning.strategies.DDPStrategy 79 | params: 80 | find_unused_parameters: True 81 | 82 | modelcheckpoint: 83 | params: 84 | every_n_train_steps: 5000 85 | 86 | callbacks: 87 | metrics_over_trainsteps_checkpoint: 88 | params: 89 | every_n_train_steps: 50000 90 | 91 | image_logger: 92 | target: main.ImageLogger 93 | params: 94 | enable_autocast: False 95 | batch_frequency: 1000 96 | max_images: 8 97 | increase_log_steps: True 98 | 99 | trainer: 100 | devices: 0, 101 | limit_val_batches: 50 102 | benchmark: True 103 | accumulate_grad_batches: 1 104 | val_check_interval: 10000 -------------------------------------------------------------------------------- /third_party/image_generator/configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: sgm.models.autoencoder.AutoencodingEngine 4 | params: 5 | input_key: jpg 6 | monitor: val/loss/rec 7 | disc_start_iter: 0 8 | 9 | encoder_config: 10 | target: sgm.modules.diffusionmodules.model.Encoder 11 | params: 12 | attn_type: vanilla-xformers 13 | double_z: true 14 | z_channels: 8 15 | resolution: 256 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | 24 | decoder_config: 25 | target: sgm.modules.diffusionmodules.model.Decoder 26 | params: ${model.params.encoder_config.params} 27 | 28 | regularizer_config: 29 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 30 | 31 | loss_config: 32 | target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator 33 | params: 34 | perceptual_weight: 0.25 35 | disc_start: 20001 36 | disc_weight: 0.5 37 | learn_logvar: True 38 | 39 | regularization_weights: 40 | kl_loss: 1.0 41 | 42 | data: 43 | target: sgm.data.dataset.StableDataModuleFromConfig 44 | params: 45 | train: 46 | datapipeline: 47 | urls: 48 | - DATA-PATH 49 | pipeline_config: 50 | shardshuffle: 10000 51 | sample_shuffle: 10000 52 | 53 | decoders: 54 | - pil 55 | 56 | postprocessors: 57 | - target: sdata.mappers.TorchVisionImageTransforms 58 | params: 59 | key: jpg 60 | transforms: 61 | - target: torchvision.transforms.Resize 62 | params: 63 | size: 256 64 | interpolation: 3 65 | - target: torchvision.transforms.ToTensor 66 | - target: sdata.mappers.Rescaler 67 | - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare 68 | params: 69 | h_key: height 70 | w_key: width 71 | 72 | loader: 73 | batch_size: 8 74 | num_workers: 4 75 | 76 | 77 | lightning: 78 | strategy: 79 | target: pytorch_lightning.strategies.DDPStrategy 80 | params: 81 | find_unused_parameters: True 82 | 83 | modelcheckpoint: 84 | params: 85 | every_n_train_steps: 5000 86 | 87 | callbacks: 88 | metrics_over_trainsteps_checkpoint: 89 | params: 90 | every_n_train_steps: 50000 91 | 92 | image_logger: 93 | target: main.ImageLogger 94 | params: 95 | enable_autocast: False 96 | batch_frequency: 1000 97 | max_images: 8 98 | increase_log_steps: True 99 | 100 | trainer: 101 | devices: 0, 102 | limit_val_batches: 50 103 | benchmark: True 104 | accumulate_grad_batches: 1 105 | val_check_interval: 10000 106 | -------------------------------------------------------------------------------- 
/third_party/image_generator/configs/example_training/toy/cifar10_cond.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 7 | params: 8 | scaling_config: 9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 10 | params: 11 | sigma_data: 1.0 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 15 | params: 16 | in_channels: 3 17 | out_channels: 3 18 | model_channels: 32 19 | attention_resolutions: [] 20 | num_res_blocks: 4 21 | channel_mult: [1, 2, 2] 22 | num_head_channels: 32 23 | num_classes: sequential 24 | adm_in_channels: 128 25 | 26 | conditioner_config: 27 | target: sgm.modules.GeneralConditioner 28 | params: 29 | emb_models: 30 | - is_trainable: True 31 | input_key: cls 32 | ucg_rate: 0.2 33 | target: sgm.modules.encoders.modules.ClassEmbedder 34 | params: 35 | embed_dim: 128 36 | n_classes: 10 37 | 38 | first_stage_config: 39 | target: sgm.models.autoencoder.IdentityFirstStage 40 | 41 | loss_fn_config: 42 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 43 | params: 44 | loss_weighting_config: 45 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 46 | params: 47 | sigma_data: 1.0 48 | sigma_sampler_config: 49 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 50 | 51 | sampler_config: 52 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 53 | params: 54 | num_steps: 50 55 | 56 | discretization_config: 57 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 58 | 59 | guider_config: 60 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 61 | params: 62 | scale: 3.0 63 | 64 | data: 65 | target: sgm.data.cifar10.CIFAR10Loader 66 | params: 67 | batch_size: 512 68 | num_workers: 1 69 | 70 | lightning: 71 | modelcheckpoint: 72 | params: 73 | every_n_train_steps: 5000 74 | 75 | callbacks: 76 | metrics_over_trainsteps_checkpoint: 77 | params: 78 | every_n_train_steps: 25000 79 | 80 | image_logger: 81 | target: main.ImageLogger 82 | params: 83 | disabled: False 84 | batch_frequency: 1000 85 | max_images: 64 86 | increase_log_steps: True 87 | log_first_step: False 88 | log_images_kwargs: 89 | use_ema_scope: False 90 | N: 64 91 | n_rows: 8 92 | 93 | trainer: 94 | devices: 0, 95 | benchmark: True 96 | num_sanity_val_steps: 0 97 | accumulate_grad_batches: 1 98 | max_epochs: 20 -------------------------------------------------------------------------------- /third_party/image_generator/configs/example_training/toy/mnist.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 7 | params: 8 | scaling_config: 9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 10 | params: 11 | sigma_data: 1.0 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 15 | params: 16 | in_channels: 1 17 | out_channels: 1 18 | model_channels: 32 19 | attention_resolutions: [] 20 | num_res_blocks: 4 21 | channel_mult: [1, 2, 2] 22 | num_head_channels: 32 23 | 24 | first_stage_config: 25 | target: sgm.models.autoencoder.IdentityFirstStage 26 | 27 | loss_fn_config: 28 | target: 
sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 29 | params: 30 | loss_weighting_config: 31 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 32 | params: 33 | sigma_data: 1.0 34 | sigma_sampler_config: 35 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 36 | 37 | sampler_config: 38 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 39 | params: 40 | num_steps: 50 41 | 42 | discretization_config: 43 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 44 | 45 | data: 46 | target: sgm.data.mnist.MNISTLoader 47 | params: 48 | batch_size: 512 49 | num_workers: 1 50 | 51 | lightning: 52 | modelcheckpoint: 53 | params: 54 | every_n_train_steps: 5000 55 | 56 | callbacks: 57 | metrics_over_trainsteps_checkpoint: 58 | params: 59 | every_n_train_steps: 25000 60 | 61 | image_logger: 62 | target: main.ImageLogger 63 | params: 64 | disabled: False 65 | batch_frequency: 1000 66 | max_images: 64 67 | increase_log_steps: False 68 | log_first_step: False 69 | log_images_kwargs: 70 | use_ema_scope: False 71 | N: 64 72 | n_rows: 8 73 | 74 | trainer: 75 | devices: 0, 76 | benchmark: True 77 | num_sanity_val_steps: 0 78 | accumulate_grad_batches: 1 79 | max_epochs: 10 -------------------------------------------------------------------------------- /third_party/image_generator/configs/example_training/toy/mnist_cond.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 7 | params: 8 | scaling_config: 9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 10 | params: 11 | sigma_data: 1.0 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 15 | params: 16 | in_channels: 1 17 | out_channels: 1 18 | model_channels: 32 19 | attention_resolutions: [] 20 | num_res_blocks: 4 21 | channel_mult: [1, 2, 2] 22 | num_head_channels: 32 23 | num_classes: sequential 24 | adm_in_channels: 128 25 | 26 | conditioner_config: 27 | target: sgm.modules.GeneralConditioner 28 | params: 29 | emb_models: 30 | - is_trainable: True 31 | input_key: cls 32 | ucg_rate: 0.2 33 | target: sgm.modules.encoders.modules.ClassEmbedder 34 | params: 35 | embed_dim: 128 36 | n_classes: 10 37 | 38 | first_stage_config: 39 | target: sgm.models.autoencoder.IdentityFirstStage 40 | 41 | loss_fn_config: 42 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 43 | params: 44 | loss_weighting_config: 45 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 46 | params: 47 | sigma_data: 1.0 48 | sigma_sampler_config: 49 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 50 | 51 | sampler_config: 52 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 53 | params: 54 | num_steps: 50 55 | 56 | discretization_config: 57 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 58 | 59 | guider_config: 60 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 61 | params: 62 | scale: 3.0 63 | 64 | data: 65 | target: sgm.data.mnist.MNISTLoader 66 | params: 67 | batch_size: 512 68 | num_workers: 1 69 | 70 | lightning: 71 | modelcheckpoint: 72 | params: 73 | every_n_train_steps: 5000 74 | 75 | callbacks: 76 | metrics_over_trainsteps_checkpoint: 77 | params: 78 | every_n_train_steps: 25000 79 | 80 | image_logger: 81 | target: main.ImageLogger 82 | params: 83 | disabled: 
False 84 | batch_frequency: 1000 85 | max_images: 16 86 | increase_log_steps: True 87 | log_first_step: False 88 | log_images_kwargs: 89 | use_ema_scope: False 90 | N: 16 91 | n_rows: 4 92 | 93 | trainer: 94 | devices: 0, 95 | benchmark: True 96 | num_sanity_val_steps: 0 97 | accumulate_grad_batches: 1 98 | max_epochs: 20 -------------------------------------------------------------------------------- /third_party/image_generator/configs/example_training/toy/mnist_cond_discrete_eps.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 7 | params: 8 | num_idx: 1000 9 | 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 12 | discretization_config: 13 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 14 | 15 | network_config: 16 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | in_channels: 1 19 | out_channels: 1 20 | model_channels: 32 21 | attention_resolutions: [] 22 | num_res_blocks: 4 23 | channel_mult: [1, 2, 2] 24 | num_head_channels: 32 25 | num_classes: sequential 26 | adm_in_channels: 128 27 | 28 | conditioner_config: 29 | target: sgm.modules.GeneralConditioner 30 | params: 31 | emb_models: 32 | - is_trainable: True 33 | input_key: cls 34 | ucg_rate: 0.2 35 | target: sgm.modules.encoders.modules.ClassEmbedder 36 | params: 37 | embed_dim: 128 38 | n_classes: 10 39 | 40 | first_stage_config: 41 | target: sgm.models.autoencoder.IdentityFirstStage 42 | 43 | loss_fn_config: 44 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 45 | params: 46 | loss_weighting_config: 47 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 48 | sigma_sampler_config: 49 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 50 | params: 51 | num_idx: 1000 52 | 53 | discretization_config: 54 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 55 | 56 | sampler_config: 57 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 58 | params: 59 | num_steps: 50 60 | 61 | discretization_config: 62 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 63 | 64 | guider_config: 65 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 66 | params: 67 | scale: 5.0 68 | 69 | data: 70 | target: sgm.data.mnist.MNISTLoader 71 | params: 72 | batch_size: 512 73 | num_workers: 1 74 | 75 | lightning: 76 | modelcheckpoint: 77 | params: 78 | every_n_train_steps: 5000 79 | 80 | callbacks: 81 | metrics_over_trainsteps_checkpoint: 82 | params: 83 | every_n_train_steps: 25000 84 | 85 | image_logger: 86 | target: main.ImageLogger 87 | params: 88 | disabled: False 89 | batch_frequency: 1000 90 | max_images: 16 91 | increase_log_steps: True 92 | log_first_step: False 93 | log_images_kwargs: 94 | use_ema_scope: False 95 | N: 16 96 | n_rows: 4 97 | 98 | trainer: 99 | devices: 0, 100 | benchmark: True 101 | num_sanity_val_steps: 0 102 | accumulate_grad_batches: 1 103 | max_epochs: 20 -------------------------------------------------------------------------------- /third_party/image_generator/configs/example_training/toy/mnist_cond_l1_loss.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | 
params: 5 | denoiser_config: 6 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 7 | params: 8 | scaling_config: 9 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 10 | params: 11 | sigma_data: 1.0 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 15 | params: 16 | in_channels: 1 17 | out_channels: 1 18 | model_channels: 32 19 | attention_resolutions: [] 20 | num_res_blocks: 4 21 | channel_mult: [1, 2, 2] 22 | num_head_channels: 32 23 | num_classes: sequential 24 | adm_in_channels: 128 25 | 26 | conditioner_config: 27 | target: sgm.modules.GeneralConditioner 28 | params: 29 | emb_models: 30 | - is_trainable: True 31 | input_key: cls 32 | ucg_rate: 0.2 33 | target: sgm.modules.encoders.modules.ClassEmbedder 34 | params: 35 | embed_dim: 128 36 | n_classes: 10 37 | 38 | first_stage_config: 39 | target: sgm.models.autoencoder.IdentityFirstStage 40 | 41 | loss_fn_config: 42 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 43 | params: 44 | loss_type: l1 45 | loss_weighting_config: 46 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 47 | params: 48 | sigma_data: 1.0 49 | sigma_sampler_config: 50 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 51 | 52 | sampler_config: 53 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 54 | params: 55 | num_steps: 50 56 | 57 | discretization_config: 58 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 59 | 60 | guider_config: 61 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 62 | params: 63 | scale: 3.0 64 | 65 | data: 66 | target: sgm.data.mnist.MNISTLoader 67 | params: 68 | batch_size: 512 69 | num_workers: 1 70 | 71 | lightning: 72 | modelcheckpoint: 73 | params: 74 | every_n_train_steps: 5000 75 | 76 | callbacks: 77 | metrics_over_trainsteps_checkpoint: 78 | params: 79 | every_n_train_steps: 25000 80 | 81 | image_logger: 82 | target: main.ImageLogger 83 | params: 84 | disabled: False 85 | batch_frequency: 1000 86 | max_images: 64 87 | increase_log_steps: True 88 | log_first_step: False 89 | log_images_kwargs: 90 | use_ema_scope: False 91 | N: 64 92 | n_rows: 8 93 | 94 | trainer: 95 | devices: 0, 96 | benchmark: True 97 | num_sanity_val_steps: 0 98 | accumulate_grad_batches: 1 99 | max_epochs: 20 -------------------------------------------------------------------------------- /third_party/image_generator/configs/example_training/toy/mnist_cond_with_ema.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | use_ema: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling 12 | params: 13 | sigma_data: 1.0 14 | 15 | network_config: 16 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | in_channels: 1 19 | out_channels: 1 20 | model_channels: 32 21 | attention_resolutions: [] 22 | num_res_blocks: 4 23 | channel_mult: [1, 2, 2] 24 | num_head_channels: 32 25 | num_classes: sequential 26 | adm_in_channels: 128 27 | 28 | conditioner_config: 29 | target: sgm.modules.GeneralConditioner 30 | params: 31 | emb_models: 32 | - is_trainable: True 33 | input_key: cls 34 | ucg_rate: 0.2 35 | target: sgm.modules.encoders.modules.ClassEmbedder 36 | params: 37 | embed_dim: 128 38 | n_classes: 10 39 | 40 | 
first_stage_config: 41 | target: sgm.models.autoencoder.IdentityFirstStage 42 | 43 | loss_fn_config: 44 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 45 | params: 46 | loss_weighting_config: 47 | target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 48 | params: 49 | sigma_data: 1.0 50 | sigma_sampler_config: 51 | target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 52 | 53 | sampler_config: 54 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 55 | params: 56 | num_steps: 50 57 | 58 | discretization_config: 59 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 60 | 61 | guider_config: 62 | target: sgm.modules.diffusionmodules.guiders.VanillaCFG 63 | params: 64 | scale: 3.0 65 | 66 | data: 67 | target: sgm.data.mnist.MNISTLoader 68 | params: 69 | batch_size: 512 70 | num_workers: 1 71 | 72 | lightning: 73 | modelcheckpoint: 74 | params: 75 | every_n_train_steps: 5000 76 | 77 | callbacks: 78 | metrics_over_trainsteps_checkpoint: 79 | params: 80 | every_n_train_steps: 25000 81 | 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | disabled: False 86 | batch_frequency: 1000 87 | max_images: 64 88 | increase_log_steps: True 89 | log_first_step: False 90 | log_images_kwargs: 91 | use_ema_scope: False 92 | N: 64 93 | n_rows: 8 94 | 95 | trainer: 96 | devices: 0, 97 | benchmark: True 98 | num_sanity_val_steps: 0 99 | accumulate_grad_batches: 1 100 | max_epochs: 20 -------------------------------------------------------------------------------- /third_party/image_generator/configs/inference/sd_2_1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | use_checkpoint: True 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 320 24 | attention_resolutions: [4, 2, 1] 25 | num_res_blocks: 2 26 | channel_mult: [1, 2, 4, 4] 27 | num_head_channels: 64 28 | use_linear_in_transformer: True 29 | transformer_depth: 1 30 | context_dim: 1024 31 | 32 | conditioner_config: 33 | target: sgm.modules.GeneralConditioner 34 | params: 35 | emb_models: 36 | - is_trainable: False 37 | input_key: txt 38 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder 39 | params: 40 | freeze: true 41 | layer: penultimate 42 | 43 | first_stage_config: 44 | target: sgm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: val/rec_loss 48 | ddconfig: 49 | double_z: true 50 | z_channels: 4 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: [1, 2, 4, 4] 56 | num_res_blocks: 2 57 | attn_resolutions: [] 58 | dropout: 0.0 59 | lossconfig: 60 | target: torch.nn.Identity -------------------------------------------------------------------------------- /third_party/image_generator/configs/inference/sd_2_1_768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | 
disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | use_checkpoint: True 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 320 24 | attention_resolutions: [4, 2, 1] 25 | num_res_blocks: 2 26 | channel_mult: [1, 2, 4, 4] 27 | num_head_channels: 64 28 | use_linear_in_transformer: True 29 | transformer_depth: 1 30 | context_dim: 1024 31 | 32 | conditioner_config: 33 | target: sgm.modules.GeneralConditioner 34 | params: 35 | emb_models: 36 | - is_trainable: False 37 | input_key: txt 38 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder 39 | params: 40 | freeze: true 41 | layer: penultimate 42 | 43 | first_stage_config: 44 | target: sgm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: val/rec_loss 48 | ddconfig: 49 | double_z: true 50 | z_channels: 4 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: [1, 2, 4, 4] 56 | num_res_blocks: 2 57 | attn_resolutions: [] 58 | dropout: 0.0 59 | lossconfig: 60 | target: torch.nn.Identity -------------------------------------------------------------------------------- /third_party/image_generator/configs/inference/sd_xl_base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.13025 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | adm_in_channels: 2816 21 | num_classes: sequential 22 | use_checkpoint: True 23 | in_channels: 4 24 | out_channels: 4 25 | model_channels: 320 26 | attention_resolutions: [4, 2] 27 | num_res_blocks: 2 28 | channel_mult: [1, 2, 4] 29 | num_head_channels: 64 30 | use_linear_in_transformer: True 31 | transformer_depth: [1, 2, 10] 32 | context_dim: 2048 33 | spatial_transformer_attn_type: softmax-xformers 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: txt 41 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 42 | params: 43 | layer: hidden 44 | layer_idx: 11 45 | 46 | - is_trainable: False 47 | input_key: txt 48 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 49 | params: 50 | arch: ViT-bigG-14 51 | version: laion2b_s39b_b160k 52 | freeze: True 53 | layer: penultimate 54 | always_return_pooled: True 55 | legacy: False 56 | 57 | - is_trainable: False 58 | input_key: original_size_as_tuple 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - is_trainable: False 64 | input_key: crop_coords_top_left 65 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 66 | params: 67 | outdim: 256 68 | 69 | - is_trainable: False 70 | 
input_key: target_size_as_tuple 71 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 72 | params: 73 | outdim: 256 74 | 75 | first_stage_config: 76 | target: sgm.models.autoencoder.AutoencoderKL 77 | params: 78 | embed_dim: 4 79 | monitor: val/rec_loss 80 | ddconfig: 81 | attn_type: vanilla-xformers 82 | double_z: true 83 | z_channels: 4 84 | resolution: 256 85 | in_channels: 3 86 | out_ch: 3 87 | ch: 128 88 | ch_mult: [1, 2, 4, 4] 89 | num_res_blocks: 2 90 | attn_resolutions: [] 91 | dropout: 0.0 92 | lossconfig: 93 | target: torch.nn.Identity 94 | -------------------------------------------------------------------------------- /third_party/image_generator/configs/inference/sd_xl_refiner.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.13025 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | adm_in_channels: 2560 21 | num_classes: sequential 22 | use_checkpoint: True 23 | in_channels: 4 24 | out_channels: 4 25 | model_channels: 384 26 | attention_resolutions: [4, 2] 27 | num_res_blocks: 2 28 | channel_mult: [1, 2, 4, 4] 29 | num_head_channels: 64 30 | use_linear_in_transformer: True 31 | transformer_depth: 4 32 | context_dim: [1280, 1280, 1280, 1280] 33 | spatial_transformer_attn_type: softmax-xformers 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: txt 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 42 | params: 43 | arch: ViT-bigG-14 44 | version: laion2b_s39b_b160k 45 | legacy: False 46 | freeze: True 47 | layer: penultimate 48 | always_return_pooled: True 49 | 50 | - is_trainable: False 51 | input_key: original_size_as_tuple 52 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 53 | params: 54 | outdim: 256 55 | 56 | - is_trainable: False 57 | input_key: crop_coords_top_left 58 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 59 | params: 60 | outdim: 256 61 | 62 | - is_trainable: False 63 | input_key: aesthetic_score 64 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 65 | params: 66 | outdim: 256 67 | 68 | first_stage_config: 69 | target: sgm.models.autoencoder.AutoencoderKL 70 | params: 71 | embed_dim: 4 72 | monitor: val/rec_loss 73 | ddconfig: 74 | attn_type: vanilla-xformers 75 | double_z: true 76 | z_channels: 4 77 | resolution: 256 78 | in_channels: 3 79 | out_ch: 3 80 | ch: 128 81 | ch_mult: [1, 2, 4, 4] 82 | num_res_blocks: 2 83 | attn_resolutions: [] 84 | dropout: 0.0 85 | lossconfig: 86 | target: torch.nn.Identity 87 | -------------------------------------------------------------------------------- /third_party/image_generator/configs/inference/sv3d_p.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | 
params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 15 | params: 16 | adm_in_channels: 1280 17 | num_classes: sequential 18 | use_checkpoint: True 19 | in_channels: 8 20 | out_channels: 4 21 | model_channels: 320 22 | attention_resolutions: [4, 2, 1] 23 | num_res_blocks: 2 24 | channel_mult: [1, 2, 4, 4] 25 | num_head_channels: 64 26 | use_linear_in_transformer: True 27 | transformer_depth: 1 28 | context_dim: 1024 29 | spatial_transformer_attn_type: softmax-xformers 30 | extra_ff_mix_layer: True 31 | use_spatial_context: True 32 | merge_strategy: learned_with_images 33 | video_kernel_size: [3, 1, 1] 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - input_key: cond_frames_without_noise 40 | is_trainable: False 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 42 | params: 43 | n_cond_frames: 1 44 | n_copies: 1 45 | open_clip_embedding_config: 46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 47 | params: 48 | freeze: True 49 | 50 | - input_key: cond_frames 51 | is_trainable: False 52 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 53 | params: 54 | disable_encoder_autocast: True 55 | n_cond_frames: 1 56 | n_copies: 1 57 | is_ae: True 58 | encoder_config: 59 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 60 | params: 61 | embed_dim: 4 62 | monitor: val/rec_loss 63 | ddconfig: 64 | attn_type: vanilla-xformers 65 | double_z: True 66 | z_channels: 4 67 | resolution: 256 68 | in_channels: 3 69 | out_ch: 3 70 | ch: 128 71 | ch_mult: [1, 2, 4, 4] 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | - input_key: cond_aug 79 | is_trainable: False 80 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 81 | params: 82 | outdim: 256 83 | 84 | - input_key: polars_rad 85 | is_trainable: False 86 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 87 | params: 88 | outdim: 512 89 | 90 | - input_key: azimuths_rad 91 | is_trainable: False 92 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 93 | params: 94 | outdim: 512 95 | 96 | first_stage_config: 97 | target: sgm.models.autoencoder.AutoencodingEngine 98 | params: 99 | loss_config: 100 | target: torch.nn.Identity 101 | regularizer_config: 102 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 103 | encoder_config: 104 | target: torch.nn.Identity 105 | decoder_config: 106 | target: sgm.modules.diffusionmodules.model.Decoder 107 | params: 108 | attn_type: vanilla-xformers 109 | double_z: True 110 | z_channels: 4 111 | resolution: 256 112 | in_channels: 3 113 | out_ch: 3 114 | ch: 128 115 | ch_mult: [ 1, 2, 4, 4 ] 116 | num_res_blocks: 2 117 | attn_resolutions: [ ] 118 | dropout: 0.0 -------------------------------------------------------------------------------- /third_party/image_generator/configs/inference/sv3d_u.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 12 | 13 
| network_config: 14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 15 | params: 16 | adm_in_channels: 256 17 | num_classes: sequential 18 | use_checkpoint: True 19 | in_channels: 8 20 | out_channels: 4 21 | model_channels: 320 22 | attention_resolutions: [4, 2, 1] 23 | num_res_blocks: 2 24 | channel_mult: [1, 2, 4, 4] 25 | num_head_channels: 64 26 | use_linear_in_transformer: True 27 | transformer_depth: 1 28 | context_dim: 1024 29 | spatial_transformer_attn_type: softmax-xformers 30 | extra_ff_mix_layer: True 31 | use_spatial_context: True 32 | merge_strategy: learned_with_images 33 | video_kernel_size: [3, 1, 1] 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - input_key: cond_frames_without_noise 40 | is_trainable: False 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 42 | params: 43 | n_cond_frames: 1 44 | n_copies: 1 45 | open_clip_embedding_config: 46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 47 | params: 48 | freeze: True 49 | 50 | - input_key: cond_frames 51 | is_trainable: False 52 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 53 | params: 54 | disable_encoder_autocast: True 55 | n_cond_frames: 1 56 | n_copies: 1 57 | is_ae: True 58 | encoder_config: 59 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 60 | params: 61 | embed_dim: 4 62 | monitor: val/rec_loss 63 | ddconfig: 64 | attn_type: vanilla-xformers 65 | double_z: True 66 | z_channels: 4 67 | resolution: 256 68 | in_channels: 3 69 | out_ch: 3 70 | ch: 128 71 | ch_mult: [1, 2, 4, 4] 72 | num_res_blocks: 2 73 | attn_resolutions: [] 74 | dropout: 0.0 75 | lossconfig: 76 | target: torch.nn.Identity 77 | 78 | - input_key: cond_aug 79 | is_trainable: False 80 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 81 | params: 82 | outdim: 256 83 | 84 | first_stage_config: 85 | target: sgm.models.autoencoder.AutoencodingEngine 86 | params: 87 | loss_config: 88 | target: torch.nn.Identity 89 | regularizer_config: 90 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 91 | encoder_config: 92 | target: torch.nn.Identity 93 | decoder_config: 94 | target: sgm.modules.diffusionmodules.model.Decoder 95 | params: 96 | attn_type: vanilla-xformers 97 | double_z: True 98 | z_channels: 4 99 | resolution: 256 100 | in_channels: 3 101 | out_ch: 3 102 | ch: 128 103 | ch_mult: [ 1, 2, 4, 4 ] 104 | num_res_blocks: 2 105 | attn_resolutions: [ ] 106 | dropout: 0.0 -------------------------------------------------------------------------------- /third_party/image_generator/configs/inference/svd.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 15 | params: 16 | adm_in_channels: 768 17 | num_classes: sequential 18 | use_checkpoint: True 19 | in_channels: 8 20 | out_channels: 4 21 | model_channels: 320 22 | attention_resolutions: [4, 2, 1] 23 | num_res_blocks: 2 24 | channel_mult: [1, 2, 4, 4] 25 | num_head_channels: 64 26 | use_linear_in_transformer: True 27 | transformer_depth: 1 28 | context_dim: 1024 
29 | spatial_transformer_attn_type: softmax-xformers 30 | extra_ff_mix_layer: True 31 | use_spatial_context: True 32 | merge_strategy: learned_with_images 33 | video_kernel_size: [3, 1, 1] 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: cond_frames_without_noise 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 42 | params: 43 | n_cond_frames: 1 44 | n_copies: 1 45 | open_clip_embedding_config: 46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 47 | params: 48 | freeze: True 49 | 50 | - input_key: fps_id 51 | is_trainable: False 52 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 53 | params: 54 | outdim: 256 55 | 56 | - input_key: motion_bucket_id 57 | is_trainable: False 58 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 59 | params: 60 | outdim: 256 61 | 62 | - input_key: cond_frames 63 | is_trainable: False 64 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 65 | params: 66 | disable_encoder_autocast: True 67 | n_cond_frames: 1 68 | n_copies: 1 69 | is_ae: True 70 | encoder_config: 71 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 72 | params: 73 | embed_dim: 4 74 | monitor: val/rec_loss 75 | ddconfig: 76 | attn_type: vanilla-xformers 77 | double_z: True 78 | z_channels: 4 79 | resolution: 256 80 | in_channels: 3 81 | out_ch: 3 82 | ch: 128 83 | ch_mult: [1, 2, 4, 4] 84 | num_res_blocks: 2 85 | attn_resolutions: [] 86 | dropout: 0.0 87 | lossconfig: 88 | target: torch.nn.Identity 89 | 90 | - input_key: cond_aug 91 | is_trainable: False 92 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 93 | params: 94 | outdim: 256 95 | 96 | first_stage_config: 97 | target: sgm.models.autoencoder.AutoencodingEngine 98 | params: 99 | loss_config: 100 | target: torch.nn.Identity 101 | regularizer_config: 102 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 103 | encoder_config: 104 | target: sgm.modules.diffusionmodules.model.Encoder 105 | params: 106 | attn_type: vanilla 107 | double_z: True 108 | z_channels: 4 109 | resolution: 256 110 | in_channels: 3 111 | out_ch: 3 112 | ch: 128 113 | ch_mult: [1, 2, 4, 4] 114 | num_res_blocks: 2 115 | attn_resolutions: [] 116 | dropout: 0.0 117 | decoder_config: 118 | target: sgm.modules.autoencoding.temporal_ae.VideoDecoder 119 | params: 120 | attn_type: vanilla 121 | double_z: True 122 | z_channels: 4 123 | resolution: 256 124 | in_channels: 3 125 | out_ch: 3 126 | ch: 128 127 | ch_mult: [1, 2, 4, 4] 128 | num_res_blocks: 2 129 | attn_resolutions: [] 130 | dropout: 0.0 131 | video_kernel_size: [3, 1, 1] -------------------------------------------------------------------------------- /third_party/image_generator/configs/inference/svd_image_decoder.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 9 | params: 10 | scaling_config: 11 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 12 | 13 | network_config: 14 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 15 | params: 16 | adm_in_channels: 768 17 | num_classes: sequential 18 | use_checkpoint: True 19 | in_channels: 8 20 | out_channels: 4 21 | model_channels: 320 22 | 
attention_resolutions: [4, 2, 1] 23 | num_res_blocks: 2 24 | channel_mult: [1, 2, 4, 4] 25 | num_head_channels: 64 26 | use_linear_in_transformer: True 27 | transformer_depth: 1 28 | context_dim: 1024 29 | spatial_transformer_attn_type: softmax-xformers 30 | extra_ff_mix_layer: True 31 | use_spatial_context: True 32 | merge_strategy: learned_with_images 33 | video_kernel_size: [3, 1, 1] 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: cond_frames_without_noise 41 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 42 | params: 43 | n_cond_frames: 1 44 | n_copies: 1 45 | open_clip_embedding_config: 46 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 47 | params: 48 | freeze: True 49 | 50 | - input_key: fps_id 51 | is_trainable: False 52 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 53 | params: 54 | outdim: 256 55 | 56 | - input_key: motion_bucket_id 57 | is_trainable: False 58 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 59 | params: 60 | outdim: 256 61 | 62 | - input_key: cond_frames 63 | is_trainable: False 64 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 65 | params: 66 | disable_encoder_autocast: True 67 | n_cond_frames: 1 68 | n_copies: 1 69 | is_ae: True 70 | encoder_config: 71 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 72 | params: 73 | embed_dim: 4 74 | monitor: val/rec_loss 75 | ddconfig: 76 | attn_type: vanilla-xformers 77 | double_z: True 78 | z_channels: 4 79 | resolution: 256 80 | in_channels: 3 81 | out_ch: 3 82 | ch: 128 83 | ch_mult: [1, 2, 4, 4] 84 | num_res_blocks: 2 85 | attn_resolutions: [] 86 | dropout: 0.0 87 | lossconfig: 88 | target: torch.nn.Identity 89 | 90 | - input_key: cond_aug 91 | is_trainable: False 92 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 93 | params: 94 | outdim: 256 95 | 96 | first_stage_config: 97 | target: sgm.models.autoencoder.AutoencoderKL 98 | params: 99 | embed_dim: 4 100 | monitor: val/rec_loss 101 | ddconfig: 102 | attn_type: vanilla-xformers 103 | double_z: True 104 | z_channels: 4 105 | resolution: 256 106 | in_channels: 3 107 | out_ch: 3 108 | ch: 128 109 | ch_mult: [1, 2, 4, 4] 110 | num_res_blocks: 2 111 | attn_resolutions: [] 112 | dropout: 0.0 113 | lossconfig: 114 | target: torch.nn.Identity -------------------------------------------------------------------------------- /third_party/image_generator/configs/sd_xl_base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.13025 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | adm_in_channels: 2816 21 | num_classes: sequential 22 | use_checkpoint: True 23 | in_channels: 4 24 | out_channels: 4 25 | model_channels: 320 26 | attention_resolutions: [4, 2] 27 | num_res_blocks: 2 28 | channel_mult: [1, 2, 4] 29 | num_head_channels: 64 30 | use_linear_in_transformer: True 31 | transformer_depth: [1, 2, 10] 32 | 
context_dim: 2048 33 | spatial_transformer_attn_type: softmax-xformers 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: txt 41 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 42 | params: 43 | layer: hidden 44 | layer_idx: 11 45 | 46 | - is_trainable: False 47 | input_key: txt 48 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 49 | params: 50 | arch: ViT-bigG-14 51 | version: laion2b_s39b_b160k 52 | freeze: True 53 | layer: penultimate 54 | always_return_pooled: True 55 | legacy: False 56 | 57 | - is_trainable: False 58 | input_key: original_size_as_tuple 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - is_trainable: False 64 | input_key: crop_coords_top_left 65 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 66 | params: 67 | outdim: 256 68 | 69 | - is_trainable: False 70 | input_key: target_size_as_tuple 71 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 72 | params: 73 | outdim: 256 74 | 75 | first_stage_config: 76 | target: sgm.models.autoencoder.AutoencoderKL 77 | params: 78 | embed_dim: 4 79 | monitor: val/rec_loss 80 | ddconfig: 81 | attn_type: vanilla-xformers 82 | double_z: true 83 | z_channels: 4 84 | resolution: 256 85 | in_channels: 3 86 | out_ch: 3 87 | ch: 128 88 | ch_mult: [1, 2, 4, 4] 89 | num_res_blocks: 2 90 | attn_resolutions: [] 91 | dropout: 0.0 92 | lossconfig: 93 | target: torch.nn.Identity 94 | -------------------------------------------------------------------------------- /third_party/image_generator/data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/data/DejaVuSans.ttf -------------------------------------------------------------------------------- /third_party/image_generator/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "sgm" 7 | dynamic = ["version"] 8 | description = "Stability Generative Models" 9 | readme = "README.md" 10 | license-files = { paths = ["LICENSE-CODE"] } 11 | requires-python = ">=3.8" 12 | 13 | [project.urls] 14 | Homepage = "https://github.com/Stability-AI/generative-models" 15 | 16 | [tool.hatch.version] 17 | path = "sgm/__init__.py" 18 | 19 | [tool.hatch.build] 20 | # This needs to be explicitly set so the configuration files 21 | # grafted into the `sgm` directory get included in the wheel's 22 | # RECORD file. 23 | include = [ 24 | "sgm", 25 | ] 26 | # The force-include configurations below make Hatch copy 27 | # the configs/ directory (containing the various YAML files required 28 | # to generatively model) into the source distribution and the wheel. 
29 | 30 | [tool.hatch.build.targets.sdist.force-include] 31 | "./configs" = "sgm/configs" 32 | 33 | [tool.hatch.build.targets.wheel.force-include] 34 | "./configs" = "sgm/configs" 35 | 36 | [tool.hatch.envs.ci] 37 | skip-install = false 38 | 39 | dependencies = [ 40 | "pytest" 41 | ] 42 | 43 | [tool.hatch.envs.ci.scripts] 44 | test-inference = [ 45 | "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", 46 | "pip install -r requirements/pt2.txt", 47 | "pytest -v tests/inference/test_inference.py {args}", 48 | ] 49 | -------------------------------------------------------------------------------- /third_party/image_generator/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | inference: mark as inference test (deselect with '-m "not inference"') -------------------------------------------------------------------------------- /third_party/image_generator/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/scripts/demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/demo/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/scripts/demo/discretization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from third_party.image_generator.sgm.modules.diffusionmodules.discretizer import Discretization 4 | 5 | 6 | class Img2ImgDiscretizationWrapper: 7 | """ 8 | wraps a discretizer, and prunes the sigmas 9 | params: 10 | strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) 11 | """ 12 | 13 | def __init__(self, discretization: Discretization, strength: float = 1.0): 14 | self.discretization = discretization 15 | self.strength = strength 16 | assert 0.0 <= self.strength <= 1.0 17 | 18 | def __call__(self, *args, **kwargs): 19 | # sigmas start large first, and decrease then 20 | sigmas = self.discretization(*args, **kwargs) 21 | print(f"sigmas after discretization, before pruning img2img: ", sigmas) 22 | sigmas = torch.flip(sigmas, (0,)) 23 | sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] 24 | print("prune index:", max(int(self.strength * len(sigmas)), 1)) 25 | sigmas = torch.flip(sigmas, (0,)) 26 | print(f"sigmas after pruning: ", sigmas) 27 | return sigmas 28 | 29 | 30 | class Txt2NoisyDiscretizationWrapper: 31 | """ 32 | wraps a discretizer, and prunes the sigmas 33 | params: 34 | strength: float between 0.0 and 1.0. 
0.0 means full sampling (all sigmas are returned) 35 | """ 36 | 37 | def __init__( 38 | self, discretization: Discretization, strength: float = 0.0, original_steps=None 39 | ): 40 | self.discretization = discretization 41 | self.strength = strength 42 | self.original_steps = original_steps 43 | assert 0.0 <= self.strength <= 1.0 44 | 45 | def __call__(self, *args, **kwargs): 46 | # sigmas start large first, and decrease then 47 | sigmas = self.discretization(*args, **kwargs) 48 | print(f"sigmas after discretization, before pruning img2img: ", sigmas) 49 | sigmas = torch.flip(sigmas, (0,)) 50 | if self.original_steps is None: 51 | steps = len(sigmas) 52 | else: 53 | steps = self.original_steps + 1 54 | prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0) 55 | sigmas = sigmas[prune_index:] 56 | print("prune index:", prune_index) 57 | sigmas = torch.flip(sigmas, (0,)) 58 | print(f"sigmas after pruning: ", sigmas) 59 | return sigmas 60 | -------------------------------------------------------------------------------- /third_party/image_generator/scripts/demo/sv3d_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def generate_dynamic_cycle_xy_values( 8 | length=21, 9 | init_elev=0, 10 | num_components=84, 11 | frequency_range=(1, 5), 12 | amplitude_range=(0.5, 10), 13 | step_range=(0, 2), 14 | ): 15 | # Y values generation 16 | y_sequence = np.ones(length) * init_elev 17 | for _ in range(num_components): 18 | # Choose a frequency that will complete whole cycles in the sequence 19 | frequency = np.random.randint(*frequency_range) * (2 * np.pi / length) 20 | amplitude = np.random.uniform(*amplitude_range) 21 | phase_shift = np.random.choice([0, np.pi]) # np.random.uniform(0, 2 * np.pi) 22 | angles = ( 23 | np.linspace(0, frequency * length, length, endpoint=False) + phase_shift 24 | ) 25 | y_sequence += np.sin(angles) * amplitude 26 | # X values generation 27 | # Generate length - 1 steps since the last step is back to start 28 | steps = np.random.uniform(*step_range, length - 1) 29 | total_step_sum = np.sum(steps) 30 | # Calculate the scale factor to scale total steps to just under 360 31 | scale_factor = ( 32 | 360 - ((360 / length) * np.random.uniform(*step_range)) 33 | ) / total_step_sum 34 | # Apply the scale factor and generate the sequence of X values 35 | x_values = np.cumsum(steps * scale_factor) 36 | # Ensure the sequence starts at 0 and add the final step to complete the loop 37 | x_values = np.insert(x_values, 0, 0) 38 | return x_values, y_sequence 39 | 40 | 41 | def smooth_data(data, window_size): 42 | # Extend data at both ends by wrapping around to create a continuous loop 43 | pad_size = window_size 44 | padded_data = np.concatenate((data[-pad_size:], data, data[:pad_size])) 45 | 46 | # Apply smoothing 47 | kernel = np.ones(window_size) / window_size 48 | smoothed_data = np.convolve(padded_data, kernel, mode="same") 49 | 50 | # Extract the smoothed data corresponding to the original sequence 51 | # Adjust the indices to account for the larger padding 52 | start_index = pad_size 53 | end_index = -pad_size if pad_size != 0 else None 54 | smoothed_original_data = smoothed_data[start_index:end_index] 55 | return smoothed_original_data 56 | 57 | 58 | # Function to generate and process the data 59 | def gen_dynamic_loop(length=21, elev_deg=0): 60 | while True: 61 | # Generate the combined X and Y values using the new function 62 | 
azim_values, elev_values = generate_dynamic_cycle_xy_values( 63 | length=84, init_elev=elev_deg 64 | ) 65 | # Smooth the Y values directly 66 | smoothed_elev_values = smooth_data(elev_values, 5) 67 | max_magnitude = np.max(np.abs(smoothed_elev_values)) 68 | if max_magnitude < 90: 69 | break 70 | subsample = 84 // length 71 | azim_rad = np.deg2rad(azim_values[::subsample]) 72 | elev_rad = np.deg2rad(smoothed_elev_values[::subsample]) 73 | # Make cond frame the last one 74 | return np.roll(azim_rad, -1), np.roll(elev_rad, -1) 75 | 76 | 77 | def plot_3D(azim, polar, save_path, dynamic=True): 78 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 79 | elev = np.deg2rad(90) - polar 80 | fig = plt.figure(figsize=(5, 5)) 81 | ax = fig.add_subplot(projection="3d") 82 | cm = plt.get_cmap("Greys") 83 | col_line = [cm(i) for i in np.linspace(0.3, 1, len(azim) + 1)] 84 | cm = plt.get_cmap("cool") 85 | col = [cm(float(i) / (len(azim))) for i in np.arange(len(azim))] 86 | xs = np.cos(elev) * np.cos(azim) 87 | ys = np.cos(elev) * np.sin(azim) 88 | zs = np.sin(elev) 89 | ax.scatter(xs[0], ys[0], zs[0], s=100, color=col[0]) 90 | xs_d, ys_d, zs_d = (xs[1:] - xs[:-1]), (ys[1:] - ys[:-1]), (zs[1:] - zs[:-1]) 91 | for i in range(len(xs) - 1): 92 | if dynamic: 93 | ax.quiver( 94 | xs[i], ys[i], zs[i], xs_d[i], ys_d[i], zs_d[i], lw=2, color=col_line[i] 95 | ) 96 | else: 97 | ax.plot(xs[i : i + 2], ys[i : i + 2], zs[i : i + 2], lw=2, c=col_line[i]) 98 | ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1]) 99 | ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k") 100 | ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k") 101 | ax.view_init(elev=30, azim=-20, roll=0) 102 | plt.savefig(save_path, bbox_inches="tight") 103 | plt.clf() 104 | plt.close() 105 | -------------------------------------------------------------------------------- /third_party/image_generator/scripts/sampling/configs/sv3d_p.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/sv3d_p.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 1280 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - input_key: cond_frames_without_noise 41 | is_trainable: False 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: 
cond_frames 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 54 | params: 55 | disable_encoder_autocast: True 56 | n_cond_frames: 1 57 | n_copies: 1 58 | is_ae: True 59 | encoder_config: 60 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 61 | params: 62 | embed_dim: 4 63 | monitor: val/rec_loss 64 | ddconfig: 65 | attn_type: vanilla-xformers 66 | double_z: True 67 | z_channels: 4 68 | resolution: 256 69 | in_channels: 3 70 | out_ch: 3 71 | ch: 128 72 | ch_mult: [1, 2, 4, 4] 73 | num_res_blocks: 2 74 | attn_resolutions: [] 75 | dropout: 0.0 76 | lossconfig: 77 | target: torch.nn.Identity 78 | 79 | - input_key: cond_aug 80 | is_trainable: False 81 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 82 | params: 83 | outdim: 256 84 | 85 | - input_key: polars_rad 86 | is_trainable: False 87 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 88 | params: 89 | outdim: 512 90 | 91 | - input_key: azimuths_rad 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 512 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencodingEngine 99 | params: 100 | loss_config: 101 | target: torch.nn.Identity 102 | regularizer_config: 103 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 104 | encoder_config: 105 | target: torch.nn.Identity 106 | decoder_config: 107 | target: sgm.modules.diffusionmodules.model.Decoder 108 | params: 109 | attn_type: vanilla-xformers 110 | double_z: True 111 | z_channels: 4 112 | resolution: 256 113 | in_channels: 3 114 | out_ch: 3 115 | ch: 128 116 | ch_mult: [ 1, 2, 4, 4 ] 117 | num_res_blocks: 2 118 | attn_resolutions: [ ] 119 | dropout: 0.0 120 | 121 | sampler_config: 122 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 123 | params: 124 | discretization_config: 125 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 126 | params: 127 | sigma_max: 700.0 128 | 129 | guider_config: 130 | target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider 131 | params: 132 | max_scale: 2.5 133 | -------------------------------------------------------------------------------- /third_party/image_generator/scripts/sampling/configs/sv3d_u.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/sv3d_u.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 256 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - is_trainable: False 
41 | input_key: cond_frames_without_noise 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: cond_frames 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 54 | params: 55 | disable_encoder_autocast: True 56 | n_cond_frames: 1 57 | n_copies: 1 58 | is_ae: True 59 | encoder_config: 60 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 61 | params: 62 | embed_dim: 4 63 | monitor: val/rec_loss 64 | ddconfig: 65 | attn_type: vanilla-xformers 66 | double_z: True 67 | z_channels: 4 68 | resolution: 256 69 | in_channels: 3 70 | out_ch: 3 71 | ch: 128 72 | ch_mult: [1, 2, 4, 4] 73 | num_res_blocks: 2 74 | attn_resolutions: [] 75 | dropout: 0.0 76 | lossconfig: 77 | target: torch.nn.Identity 78 | 79 | - input_key: cond_aug 80 | is_trainable: False 81 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 82 | params: 83 | outdim: 256 84 | 85 | first_stage_config: 86 | target: sgm.models.autoencoder.AutoencodingEngine 87 | params: 88 | loss_config: 89 | target: torch.nn.Identity 90 | regularizer_config: 91 | target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 92 | encoder_config: 93 | target: torch.nn.Identity 94 | decoder_config: 95 | target: sgm.modules.diffusionmodules.model.Decoder 96 | params: 97 | attn_type: vanilla-xformers 98 | double_z: True 99 | z_channels: 4 100 | resolution: 256 101 | in_channels: 3 102 | out_ch: 3 103 | ch: 128 104 | ch_mult: [ 1, 2, 4, 4 ] 105 | num_res_blocks: 2 106 | attn_resolutions: [ ] 107 | dropout: 0.0 108 | 109 | sampler_config: 110 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 111 | params: 112 | discretization_config: 113 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 114 | params: 115 | sigma_max: 700.0 116 | 117 | guider_config: 118 | target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider 119 | params: 120 | max_scale: 2.5 121 | -------------------------------------------------------------------------------- /third_party/image_generator/scripts/sampling/configs/svd_image_decoder.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/svd_image_decoder.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 768 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 
40 | - is_trainable: False 41 | input_key: cond_frames_without_noise 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: fps_id 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 54 | params: 55 | outdim: 256 56 | 57 | - input_key: motion_bucket_id 58 | is_trainable: False 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - input_key: cond_frames 64 | is_trainable: False 65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 66 | params: 67 | disable_encoder_autocast: True 68 | n_cond_frames: 1 69 | n_copies: 1 70 | is_ae: True 71 | encoder_config: 72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 73 | params: 74 | embed_dim: 4 75 | monitor: val/rec_loss 76 | ddconfig: 77 | attn_type: vanilla-xformers 78 | double_z: True 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: [1, 2, 4, 4] 85 | num_res_blocks: 2 86 | attn_resolutions: [] 87 | dropout: 0.0 88 | lossconfig: 89 | target: torch.nn.Identity 90 | 91 | - input_key: cond_aug 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 256 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencoderKL 99 | params: 100 | embed_dim: 4 101 | monitor: val/rec_loss 102 | ddconfig: 103 | attn_type: vanilla-xformers 104 | double_z: True 105 | z_channels: 4 106 | resolution: 256 107 | in_channels: 3 108 | out_ch: 3 109 | ch: 128 110 | ch_mult: [1, 2, 4, 4] 111 | num_res_blocks: 2 112 | attn_resolutions: [] 113 | dropout: 0.0 114 | lossconfig: 115 | target: torch.nn.Identity 116 | 117 | sampler_config: 118 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 119 | params: 120 | discretization_config: 121 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 122 | params: 123 | sigma_max: 700.0 124 | 125 | guider_config: 126 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider 127 | params: 128 | max_scale: 2.5 129 | min_scale: 1.0 -------------------------------------------------------------------------------- /third_party/image_generator/scripts/sampling/configs/svd_xt_image_decoder.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.18215 5 | disable_first_stage_autocast: True 6 | ckpt_path: checkpoints/svd_xt_image_decoder.safetensors 7 | 8 | denoiser_config: 9 | target: sgm.modules.diffusionmodules.denoiser.Denoiser 10 | params: 11 | scaling_config: 12 | target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 13 | 14 | network_config: 15 | target: sgm.modules.diffusionmodules.video_model.VideoUNet 16 | params: 17 | adm_in_channels: 768 18 | num_classes: sequential 19 | use_checkpoint: True 20 | in_channels: 8 21 | out_channels: 4 22 | model_channels: 320 23 | attention_resolutions: [4, 2, 1] 24 | num_res_blocks: 2 25 | channel_mult: [1, 2, 4, 4] 26 | num_head_channels: 64 27 | use_linear_in_transformer: True 28 | transformer_depth: 1 29 | context_dim: 1024 30 | spatial_transformer_attn_type: softmax-xformers 31 | extra_ff_mix_layer: True 32 | use_spatial_context: True 33 | 
merge_strategy: learned_with_images 34 | video_kernel_size: [3, 1, 1] 35 | 36 | conditioner_config: 37 | target: sgm.modules.GeneralConditioner 38 | params: 39 | emb_models: 40 | - is_trainable: False 41 | input_key: cond_frames_without_noise 42 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 43 | params: 44 | n_cond_frames: 1 45 | n_copies: 1 46 | open_clip_embedding_config: 47 | target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 48 | params: 49 | freeze: True 50 | 51 | - input_key: fps_id 52 | is_trainable: False 53 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 54 | params: 55 | outdim: 256 56 | 57 | - input_key: motion_bucket_id 58 | is_trainable: False 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - input_key: cond_frames 64 | is_trainable: False 65 | target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 66 | params: 67 | disable_encoder_autocast: True 68 | n_cond_frames: 1 69 | n_copies: 1 70 | is_ae: True 71 | encoder_config: 72 | target: sgm.models.autoencoder.AutoencoderKLModeOnly 73 | params: 74 | embed_dim: 4 75 | monitor: val/rec_loss 76 | ddconfig: 77 | attn_type: vanilla-xformers 78 | double_z: True 79 | z_channels: 4 80 | resolution: 256 81 | in_channels: 3 82 | out_ch: 3 83 | ch: 128 84 | ch_mult: [1, 2, 4, 4] 85 | num_res_blocks: 2 86 | attn_resolutions: [] 87 | dropout: 0.0 88 | lossconfig: 89 | target: torch.nn.Identity 90 | 91 | - input_key: cond_aug 92 | is_trainable: False 93 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 94 | params: 95 | outdim: 256 96 | 97 | first_stage_config: 98 | target: sgm.models.autoencoder.AutoencoderKL 99 | params: 100 | embed_dim: 4 101 | monitor: val/rec_loss 102 | ddconfig: 103 | attn_type: vanilla-xformers 104 | double_z: True 105 | z_channels: 4 106 | resolution: 256 107 | in_channels: 3 108 | out_ch: 3 109 | ch: 128 110 | ch_mult: [1, 2, 4, 4] 111 | num_res_blocks: 2 112 | attn_resolutions: [] 113 | dropout: 0.0 114 | lossconfig: 115 | target: torch.nn.Identity 116 | 117 | sampler_config: 118 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 119 | params: 120 | discretization_config: 121 | target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization 122 | params: 123 | sigma_max: 700.0 124 | 125 | guider_config: 126 | target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider 127 | params: 128 | max_scale: 3.0 129 | min_scale: 1.5 -------------------------------------------------------------------------------- /third_party/image_generator/scripts/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/util/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/scripts/util/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/util/detection/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/scripts/util/detection/nsfw_and_watermark_dectection.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import clip 4 | 
import numpy as np 5 | import torch 6 | import torchvision.transforms as T 7 | from PIL import Image 8 | 9 | RESOURCES_ROOT = "dataLoader/image_generator/scripts/util/detection/" 10 | 11 | 12 | def predict_proba(X, weights, biases): 13 | logits = X @ weights.T + biases 14 | proba = np.where( 15 | logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits)) 16 | ) 17 | return proba.T 18 | 19 | 20 | def load_model_weights(path: str): 21 | model_weights = np.load(path) 22 | return model_weights["weights"], model_weights["biases"] 23 | 24 | 25 | def clip_process_images(images: torch.Tensor) -> torch.Tensor: 26 | min_size = min(images.shape[-2:]) 27 | return T.Compose( 28 | [ 29 | T.CenterCrop(min_size), # TODO: this might affect the watermark, check this 30 | T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True), 31 | T.Normalize( 32 | (0.48145466, 0.4578275, 0.40821073), 33 | (0.26862954, 0.26130258, 0.27577711), 34 | ), 35 | ] 36 | )(images) 37 | 38 | 39 | class DeepFloydDataFiltering(object): 40 | def __init__( 41 | self, verbose: bool = False, device: torch.device = torch.device("cpu") 42 | ): 43 | super().__init__() 44 | self.verbose = verbose 45 | self._device = None 46 | self.clip_model, _ = clip.load("ViT-L/14", device=device) 47 | self.clip_model.eval() 48 | 49 | self.cpu_w_weights, self.cpu_w_biases = load_model_weights( 50 | os.path.join(RESOURCES_ROOT, "w_head_v1.npz") 51 | ) 52 | self.cpu_p_weights, self.cpu_p_biases = load_model_weights( 53 | os.path.join(RESOURCES_ROOT, "p_head_v1.npz") 54 | ) 55 | self.w_threshold, self.p_threshold = 0.5, 0.5 56 | 57 | @torch.inference_mode() 58 | def __call__(self, images: torch.Tensor) -> torch.Tensor: 59 | imgs = clip_process_images(images) 60 | if self._device is None: 61 | self._device = next(p for p in self.clip_model.parameters()).device 62 | image_features = self.clip_model.encode_image(imgs.to(self._device)) 63 | image_features = image_features.detach().cpu().numpy().astype(np.float16) 64 | p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases) 65 | w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases) 66 | print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None 67 | query = p_pred > self.p_threshold 68 | if query.sum() > 0: 69 | print(f"Hit for p_threshold: {p_pred}") if self.verbose else None 70 | images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) 71 | query = w_pred > self.w_threshold 72 | if query.sum() > 0: 73 | print(f"Hit for w_threshold: {w_pred}") if self.verbose else None 74 | images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) 75 | return images 76 | 77 | 78 | def load_img(path: str) -> torch.Tensor: 79 | image = Image.open(path) 80 | if not image.mode == "RGB": 81 | image = image.convert("RGB") 82 | image_transforms = T.Compose( 83 | [ 84 | T.ToTensor(), 85 | ] 86 | ) 87 | return image_transforms(image)[None, ...] 
88 | 89 | 90 | def test(root): 91 | from einops import rearrange 92 | 93 | filter = DeepFloydDataFiltering(verbose=True) 94 | for p in os.listdir((root)): 95 | print(f"running on {p}...") 96 | img = load_img(os.path.join(root, p)) 97 | filtered_img = filter(img) 98 | filtered_img = rearrange( 99 | 255.0 * (filtered_img.numpy())[0], "c h w -> h w c" 100 | ).astype(np.uint8) 101 | Image.fromarray(filtered_img).save( 102 | os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg") 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | import fire 108 | 109 | fire.Fire(test) 110 | print("done.") 111 | -------------------------------------------------------------------------------- /third_party/image_generator/scripts/util/detection/p_head_v1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/util/detection/p_head_v1.npz -------------------------------------------------------------------------------- /third_party/image_generator/scripts/util/detection/w_head_v1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/scripts/util/detection/w_head_v1.npz -------------------------------------------------------------------------------- /third_party/image_generator/sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine, DiffusionEngine 2 | from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import StableDataModuleFromConfig 2 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/data/cifar10.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 | 6 | 7 | class CIFAR10DataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 17 | return len(self.dset) 18 | 19 | 20 | class CIFAR10Loader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.shuffle = shuffle 31 | self.train_dataset = CIFAR10DataDictWrapper( 32 | torchvision.datasets.CIFAR10( 33 | root=".data/", train=True, download=True, transform=transform 34 | ) 35 | ) 36 | self.test_dataset = CIFAR10DataDictWrapper( 37 | torchvision.datasets.CIFAR10( 38 | root=".data/", train=False, download=True, transform=transform 39 | ) 40 | ) 41 | 42 | def prepare_data(self): 43 | pass 44 | 45 | def train_dataloader(self): 46 | return DataLoader( 47 | self.train_dataset, 48 | batch_size=self.batch_size, 49 | 
shuffle=self.shuffle, 50 | num_workers=self.num_workers, 51 | ) 52 | 53 | def test_dataloader(self): 54 | return DataLoader( 55 | self.test_dataset, 56 | batch_size=self.batch_size, 57 | shuffle=self.shuffle, 58 | num_workers=self.num_workers, 59 | ) 60 | 61 | def val_dataloader(self): 62 | return DataLoader( 63 | self.test_dataset, 64 | batch_size=self.batch_size, 65 | shuffle=self.shuffle, 66 | num_workers=self.num_workers, 67 | ) 68 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchdata.datapipes.iter 4 | import webdataset as wds 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import LightningDataModule 7 | 8 | try: 9 | from sdata import create_dataset, create_dummy_dataset, create_loader 10 | except ImportError as e: 11 | print("#" * 100) 12 | print("Datasets not yet available") 13 | print("to enable, we need to add stable-datasets as a submodule") 14 | print("please use ``git submodule update --init --recursive``") 15 | print("and do ``pip install -e stable-datasets/`` from the root of this repo") 16 | print("#" * 100) 17 | exit(1) 18 | 19 | 20 | class StableDataModuleFromConfig(LightningDataModule): 21 | def __init__( 22 | self, 23 | train: DictConfig, 24 | validation: Optional[DictConfig] = None, 25 | test: Optional[DictConfig] = None, 26 | skip_val_loader: bool = False, 27 | dummy: bool = False, 28 | ): 29 | super().__init__() 30 | self.train_config = train 31 | assert ( 32 | "datapipeline" in self.train_config and "loader" in self.train_config 33 | ), "train config requires the fields `datapipeline` and `loader`" 34 | 35 | self.val_config = validation 36 | if not skip_val_loader: 37 | if self.val_config is not None: 38 | assert ( 39 | "datapipeline" in self.val_config and "loader" in self.val_config 40 | ), "validation config requires the fields `datapipeline` and `loader`" 41 | else: 42 | print( 43 | "Warning: No Validation datapipeline defined, using that one from training" 44 | ) 45 | self.val_config = train 46 | 47 | self.test_config = test 48 | if self.test_config is not None: 49 | assert ( 50 | "datapipeline" in self.test_config and "loader" in self.test_config 51 | ), "test config requires the fields `datapipeline` and `loader`" 52 | 53 | self.dummy = dummy 54 | if self.dummy: 55 | print("#" * 100) 56 | print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") 57 | print("#" * 100) 58 | 59 | def setup(self, stage: str) -> None: 60 | print("Preparing datasets") 61 | if self.dummy: 62 | data_fn = create_dummy_dataset 63 | else: 64 | data_fn = create_dataset 65 | 66 | self.train_datapipeline = data_fn(**self.train_config.datapipeline) 67 | if self.val_config: 68 | self.val_datapipeline = data_fn(**self.val_config.datapipeline) 69 | if self.test_config: 70 | self.test_datapipeline = data_fn(**self.test_config.datapipeline) 71 | 72 | def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe: 73 | loader = create_loader(self.train_datapipeline, **self.train_config.loader) 74 | return loader 75 | 76 | def val_dataloader(self) -> wds.DataPipeline: 77 | return create_loader(self.val_datapipeline, **self.val_config.loader) 78 | 79 | def test_dataloader(self) -> wds.DataPipeline: 80 | return create_loader(self.test_datapipeline, **self.test_config.loader) 81 | -------------------------------------------------------------------------------- 
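A quick note on the toy data modules in this package (cifar10.py above, mnist.py next): the *DataDictWrapper classes re-package the usual (image, label) tuples as dictionaries keyed "jpg" and "cls", which is the batch format the rest of sgm expects, and the transforms rescale pixel values from [0, 1] to [-1, 1]. A minimal usage sketch, assuming the sgm package is importable from the image_generator root (illustrative only, not part of the repository; instantiating the loader downloads CIFAR-10 into .data/ on first use):

    from sgm.data.cifar10 import CIFAR10Loader

    dm = CIFAR10Loader(batch_size=8, num_workers=0)   # example values
    batch = next(iter(dm.train_dataloader()))
    print(batch["jpg"].shape)   # torch.Size([8, 3, 32, 32]), values in [-1, 1]
    print(batch["cls"].shape)   # torch.Size([8]), integer class labels
    assert batch["jpg"].min() >= -1.0 and batch["jpg"].max() <= 1.0
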
/third_party/image_generator/sgm/data/mnist.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import transforms 5 | 6 | 7 | class MNISTDataDictWrapper(Dataset): 8 | def __init__(self, dset): 9 | super().__init__() 10 | self.dset = dset 11 | 12 | def __getitem__(self, i): 13 | x, y = self.dset[i] 14 | return {"jpg": x, "cls": y} 15 | 16 | def __len__(self): 17 | return len(self.dset) 18 | 19 | 20 | class MNISTLoader(pl.LightningDataModule): 21 | def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True): 22 | super().__init__() 23 | 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 26 | ) 27 | 28 | self.batch_size = batch_size 29 | self.num_workers = num_workers 30 | self.prefetch_factor = prefetch_factor if num_workers > 0 else 0 31 | self.shuffle = shuffle 32 | self.train_dataset = MNISTDataDictWrapper( 33 | torchvision.datasets.MNIST( 34 | root=".data/", train=True, download=True, transform=transform 35 | ) 36 | ) 37 | self.test_dataset = MNISTDataDictWrapper( 38 | torchvision.datasets.MNIST( 39 | root=".data/", train=False, download=True, transform=transform 40 | ) 41 | ) 42 | 43 | def prepare_data(self): 44 | pass 45 | 46 | def train_dataloader(self): 47 | return DataLoader( 48 | self.train_dataset, 49 | batch_size=self.batch_size, 50 | shuffle=self.shuffle, 51 | num_workers=self.num_workers, 52 | prefetch_factor=self.prefetch_factor, 53 | ) 54 | 55 | def test_dataloader(self): 56 | return DataLoader( 57 | self.test_dataset, 58 | batch_size=self.batch_size, 59 | shuffle=self.shuffle, 60 | num_workers=self.num_workers, 61 | prefetch_factor=self.prefetch_factor, 62 | ) 63 | 64 | def val_dataloader(self): 65 | return DataLoader( 66 | self.test_dataset, 67 | batch_size=self.batch_size, 68 | shuffle=self.shuffle, 69 | num_workers=self.num_workers, 70 | prefetch_factor=self.prefetch_factor, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | dset = MNISTDataDictWrapper( 76 | torchvision.datasets.MNIST( 77 | root=".data/", 78 | train=False, 79 | download=True, 80 | transform=transforms.Compose( 81 | [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] 82 | ), 83 | ) 84 | ) 85 | ex = dset[0] 86 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | 
self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n, **kwargs): 48 | return self.schedule(n, **kwargs) 49 | 50 | 51 | class LambdaWarmUpCosineScheduler2: 52 | """ 53 | supports repeated iterations, configurable via lists 54 | note: use with a base_lr of 1.0. 55 | """ 56 | 57 | def __init__( 58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 59 | ): 60 | assert ( 61 | len(warm_up_steps) 62 | == len(f_min) 63 | == len(f_max) 64 | == len(f_start) 65 | == len(cycle_lengths) 66 | ) 67 | self.lr_warm_up_steps = warm_up_steps 68 | self.f_start = f_start 69 | self.f_min = f_min 70 | self.f_max = f_max 71 | self.cycle_lengths = cycle_lengths 72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 73 | self.last_f = 0.0 74 | self.verbosity_interval = verbosity_interval 75 | 76 | def find_in_interval(self, n): 77 | interval = 0 78 | for cl in self.cum_cycles[1:]: 79 | if n <= cl: 80 | return interval 81 | interval += 1 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: 88 | print( 89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 90 | f"current cycle {cycle}" 91 | ) 92 | if n < self.lr_warm_up_steps[cycle]: 93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 94 | cycle 95 | ] * n + self.f_start[cycle] 96 | self.last_f = f 97 | return f 98 | else: 99 | t = (n - self.lr_warm_up_steps[cycle]) / ( 100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] 101 | ) 102 | t = min(t, 1.0) 103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 104 | 1 + np.cos(t * np.pi) 105 | ) 106 | self.last_f = f 107 | return f 108 | 109 | def __call__(self, n, **kwargs): 110 | return self.schedule(n, **kwargs) 111 | 112 | 113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 114 | def schedule(self, n, **kwargs): 115 | cycle = self.find_in_interval(n) 116 | n = n - self.cum_cycles[cycle] 117 | if self.verbosity_interval > 0: 118 | if n % self.verbosity_interval == 0: 119 | print( 120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 121 | f"current cycle {cycle}" 122 | ) 123 | 124 | if n < self.lr_warm_up_steps[cycle]: 125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 126 | cycle 127 | ] * n + self.f_start[cycle] 128 | self.last_f = f 129 | return f 130 | else: 131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 132 | self.cycle_lengths[cycle] - n 133 | ) / (self.cycle_lengths[cycle]) 134 | self.last_f = f 135 | return f 136 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | from .diffusion import DiffusionEngine 3 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": "sgm.modules.GeneralConditioner", 5 | "params": {"emb_models": []}, 6 | } 
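The schedulers in sgm/lr_scheduler.py above return learning-rate multipliers rather than learning rates, which is why their docstrings say to use a base_lr of 1.0: wrapped in torch.optim.lr_scheduler.LambdaLR, the effective learning rate at step n is base_lr * schedule(n). A minimal sketch with made-up hyperparameters (not taken from any config in this repository), assuming sgm is importable:

    import torch
    from sgm.lr_scheduler import LambdaWarmUpCosineScheduler

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)  # base_lr of 1.0, per the docstring

    # linear warm-up from lr_start to lr_max over 1_000 steps, then cosine decay towards lr_min
    schedule = LambdaWarmUpCosineScheduler(
        warm_up_steps=1_000,
        lr_min=1e-6,
        lr_max=1e-4,
        lr_start=1e-7,
        max_decay_steps=100_000,
    )
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

    for step in range(3):
        optimizer.step()
        scheduler.step()  # effective lr tracks 1.0 * schedule(step)
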
7 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/autoencoding/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/losses/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "GeneralLPIPSWithDiscriminator", 3 | "LatentLPIPS", 4 | ] 5 | 6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator 7 | from .lpips import LatentLPIPS 8 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/losses/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ....util import default, instantiate_from_config 5 | from ..lpips.loss.lpips import LPIPS 6 | 7 | 8 | class LatentLPIPS(nn.Module): 9 | def __init__( 10 | self, 11 | decoder_config, 12 | perceptual_weight=1.0, 13 | latent_weight=1.0, 14 | scale_input_to_tgt_size=False, 15 | scale_tgt_to_input_size=False, 16 | perceptual_weight_on_inputs=0.0, 17 | ): 18 | super().__init__() 19 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 20 | self.scale_tgt_to_input_size = scale_tgt_to_input_size 21 | self.init_decoder(decoder_config) 22 | self.perceptual_loss = LPIPS().eval() 23 | self.perceptual_weight = perceptual_weight 24 | self.latent_weight = latent_weight 25 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs 26 | 27 | def init_decoder(self, config): 28 | self.decoder = instantiate_from_config(config) 29 | if hasattr(self.decoder, "encoder"): 30 | del self.decoder.encoder 31 | 32 | def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): 33 | log = dict() 34 | loss = (latent_inputs - latent_predictions) ** 2 35 | log[f"{split}/latent_l2_loss"] = loss.mean().detach() 36 | image_reconstructions = None 37 | if self.perceptual_weight > 0.0: 38 | image_reconstructions = self.decoder.decode(latent_predictions) 39 | image_targets = self.decoder.decode(latent_inputs) 40 | perceptual_loss = self.perceptual_loss( 41 | image_targets.contiguous(), image_reconstructions.contiguous() 42 | ) 43 | loss = ( 44 | self.latent_weight * loss.mean() 45 | + self.perceptual_weight * perceptual_loss.mean() 46 | ) 47 | log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() 48 | 49 | if self.perceptual_weight_on_inputs > 0.0: 50 | image_reconstructions = default( 51 | image_reconstructions, self.decoder.decode(latent_predictions) 52 | ) 53 | if self.scale_input_to_tgt_size: 54 | image_inputs = torch.nn.functional.interpolate( 55 | image_inputs, 56 | image_reconstructions.shape[2:], 57 | mode="bicubic", 58 | antialias=True, 59 | ) 60 | elif self.scale_tgt_to_input_size: 61 | image_reconstructions = torch.nn.functional.interpolate( 62 | image_reconstructions, 63 | image_inputs.shape[2:], 64 | mode="bicubic", 65 | antialias=True, 66 | ) 67 | 68 | perceptual_loss2 = self.perceptual_loss( 69 | image_inputs.contiguous(), image_reconstructions.contiguous() 70 | ) 71 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() 72 | log[f"{split}/perceptual_loss_on_inputs"] 
= perceptual_loss2.mean().detach() 73 | return loss, log 74 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/autoencoding/lpips/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/lpips/loss/.gitignore: -------------------------------------------------------------------------------- 1 | vgg.pth -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/autoencoding/lpips/loss/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 
9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/autoencoding/lpips/model/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | 22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 23 | """Construct a PatchGAN discriminator 24 | Parameters: 25 | input_nc (int) -- the number of channels in input images 26 | ndf (int) -- the number of filters in the last conv layer 27 | n_layers (int) -- the number of conv layers in the discriminator 28 | norm_layer -- normalization layer 29 | """ 30 | super(NLayerDiscriminator, self).__init__() 31 | if not use_actnorm: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = ActNorm 35 | if ( 36 | type(norm_layer) == functools.partial 37 | ): # no need to use bias as BatchNorm2d has affine parameters 38 | use_bias = norm_layer.func != nn.BatchNorm2d 39 | else: 40 | use_bias = norm_layer != nn.BatchNorm2d 41 | 42 | kw = 4 43 | padw = 1 44 | sequence = [ 45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 46 | nn.LeakyReLU(0.2, True), 47 | ] 48 | nf_mult = 1 49 | nf_mult_prev = 1 50 | for n in range(1, n_layers): # gradually increase the number of filters 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2**n, 8) 53 | sequence += [ 54 | nn.Conv2d( 55 | ndf * nf_mult_prev, 56 | ndf * nf_mult, 57 | kernel_size=kw, 58 | stride=2, 59 | padding=padw, 60 | bias=use_bias, 61 | ), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True), 64 | ] 65 | 66 | nf_mult_prev = nf_mult 67 | nf_mult = min(2**n_layers, 8) 68 | sequence += [ 69 | nn.Conv2d( 70 | ndf * nf_mult_prev, 71 | ndf * nf_mult, 72 | kernel_size=kw, 73 | stride=1, 74 | padding=padw, 75 | bias=use_bias, 76 | ), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True), 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 83 | ] # output 1 channel prediction map 84 | self.main = nn.Sequential(*sequence) 85 | 86 | def forward(self, input): 87 | """Standard forward.""" 88 | return self.main(input) 89 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": 
"https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 | 11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 12 | 13 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__( 47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 48 | ): 49 | assert affine 50 | super().__init__() 51 | self.logdet = logdet 52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 54 | self.allow_reverse_init = allow_reverse_init 55 | 56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 57 | 58 | def initialize(self, input): 59 | with torch.no_grad(): 60 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 61 | mean = ( 62 | flatten.mean(1) 63 | .unsqueeze(1) 64 | .unsqueeze(2) 65 | .unsqueeze(3) 66 | .permute(1, 0, 2, 3) 67 | ) 68 | std = ( 69 | flatten.std(1) 70 | .unsqueeze(1) 71 | .unsqueeze(2) 72 | .unsqueeze(3) 73 | .permute(1, 0, 2, 3) 74 | ) 75 | 76 | self.loc.data.copy_(-mean) 77 | self.scale.data.copy_(1 / (std + 1e-6)) 78 | 79 | def forward(self, input, reverse=False): 80 | if reverse: 81 | return self.reverse(input) 82 | if len(input.shape) == 2: 83 | input = input[:, :, None, None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | _, _, height, width = input.shape 89 | 90 | if self.training and self.initialized.item() == 0: 91 | self.initialize(input) 92 | self.initialized.fill_(1) 93 | 94 | h = self.scale * (input + self.loc) 95 | 96 | if squeeze: 97 | h = h.squeeze(-1).squeeze(-1) 98 | 99 | if self.logdet: 100 | log_abs = torch.log(torch.abs(self.scale)) 101 | logdet = height * width * torch.sum(log_abs) 102 | logdet = logdet * torch.ones(input.shape[0]).to(input) 103 | return h, logdet 104 | 105 | return h 106 | 107 | def reverse(self, output): 108 | if self.training and self.initialized.item() == 0: 109 | if not self.allow_reverse_init: 110 | raise RuntimeError( 111 | "Initializing ActNorm in reverse direction is " 112 | "disabled by default. Use allow_reverse_init=True to enable." 
113 | ) 114 | else: 115 | self.initialize(output) 116 | self.initialized.fill_(1) 117 | 118 | if len(output.shape) == 2: 119 | output = output[:, :, None, None] 120 | squeeze = True 121 | else: 122 | squeeze = False 123 | 124 | h = output / self.scale - self.loc 125 | 126 | if squeeze: 127 | h = h.squeeze(-1).squeeze(-1) 128 | return h 129 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) 15 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 16 | ) 17 | return d_loss 18 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import \ 9 | DiagonalGaussianDistribution 10 | from .base import AbstractRegularizer 11 | 12 | 13 | class DiagonalGaussianRegularizer(AbstractRegularizer): 14 | def __init__(self, sample: bool = True): 15 | super().__init__() 16 | self.sample = sample 17 | 18 | def get_trainable_parameters(self) -> Any: 19 | yield from () 20 | 21 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 22 | log = dict() 23 | posterior = DiagonalGaussianDistribution(z) 24 | if self.sample: 25 | z = posterior.sample() 26 | else: 27 | z = posterior.mode() 28 | kl_loss = posterior.kl() 29 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 30 | log["kl_loss"] = kl_loss 31 | return z, log 32 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/autoencoding/regularizers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class AbstractRegularizer(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def get_trainable_parameters(self) -> Any: 18 | raise NotImplementedError() 19 | 20 | 21 | class IdentityRegularizer(AbstractRegularizer): 22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 23 | return z, dict() 24 | 25 | def get_trainable_parameters(self) -> Any: 26 | yield from () 27 | 28 | 29 | def measure_perplexity( 30 | predicted_indices: torch.Tensor, num_centroids: int 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 33 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 34 | encodings = ( 35 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 36 | ) 37 | avg_probs = encodings.mean(0) 38 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 39 | cluster_use = torch.sum(avg_probs > 0) 40 | return perplexity, cluster_use 41 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...util import append_dims, instantiate_from_config 7 | from .denoiser_scaling import DenoiserScaling 8 | from .discretizer import Discretization 9 | 10 | 11 | class Denoiser(nn.Module): 12 | def __init__(self, scaling_config: Dict): 13 | super().__init__() 14 | 15 | self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) 16 | 17 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 18 | return sigma 19 | 20 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 21 | return c_noise 22 | 23 | def forward( 24 | self, 25 | network: nn.Module, 26 | input: torch.Tensor, 27 | sigma: torch.Tensor, 28 | cond: Dict, 29 | **additional_model_inputs, 30 | ) -> torch.Tensor: 31 | sigma = self.possibly_quantize_sigma(sigma) 32 | sigma_shape = sigma.shape 33 | sigma = append_dims(sigma, input.ndim) 34 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 35 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 36 | return ( 37 | network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out 38 | + input * c_skip 39 | ) 40 | 41 | 42 | class DiscreteDenoiser(Denoiser): 43 | def __init__( 44 | self, 45 | scaling_config: Dict, 46 | num_idx: int, 47 | discretization_config: Dict, 48 | do_append_zero: bool = False, 49 | quantize_c_noise: bool = True, 50 | flip: bool = True, 51 | ): 52 | super().__init__(scaling_config) 53 | self.discretization: Discretization = instantiate_from_config( 54 | discretization_config 55 | ) 56 | sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) 57 | self.register_buffer("sigmas", sigmas) 58 | self.quantize_c_noise = quantize_c_noise 59 | self.num_idx = num_idx 60 | 61 | def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: 62 | dists = sigma - self.sigmas[:, None] 63 | return dists.abs().argmin(dim=0).view(sigma.shape) 64 | 65 | def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: 66 | return self.sigmas[idx] 67 | 68 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 69 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 70 | 71 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 72 | if self.quantize_c_noise: 73 | return self.sigma_to_idx(c_noise) 74 | else: 75 | return c_noise 76 | -------------------------------------------------------------------------------- 
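A small illustration of the measure_perplexity helper in regularizers/base.py above: perplexity is the exponential of the entropy of the average code usage, so it equals num_centroids when every centroid is hit equally often and collapses towards 1 when a single centroid absorbs everything. A sketch, assuming sgm is importable (not part of the repository):

    import torch
    from sgm.modules.autoencoding.regularizers.base import measure_perplexity

    balanced = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])   # all 4 centroids used equally
    collapsed = torch.zeros(8, dtype=torch.long)        # only centroid 0 used

    print(measure_perplexity(balanced, 4))    # perplexity ~4.0, 4 clusters used
    print(measure_perplexity(collapsed, 4))   # perplexity ~1.0, 1 cluster used
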
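The Denoiser wrapper in denoiser.py above implements the standard EDM-style preconditioning D(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma), cond), with the four coefficients supplied by the configured scaling object; DiscreteDenoiser additionally snaps sigma to the nearest entry of a precomputed table. A toy sketch of the same wiring, using a stand-in network and v-prediction-style coefficients (cf. VScaling in denoiser_scaling.py, listed next); illustrative only, not part of the repository:

    import torch

    def toy_scaling(sigma):
        # same interface as the DenoiserScaling classes: sigma -> (c_skip, c_out, c_in, c_noise)
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise

    def toy_network(x, c_noise, cond):
        return x  # stand-in for the UNet; a real network also consumes c_noise and cond

    x = torch.randn(2, 4, 8, 8)          # e.g. a batch of latents
    sigma = torch.tensor([0.5, 2.0])
    sigma_bc = sigma.view(-1, 1, 1, 1)   # what append_dims does inside Denoiser.forward

    c_skip, c_out, c_in, c_noise = toy_scaling(sigma_bc)
    denoised = toy_network(x * c_in, c_noise.reshape(sigma.shape), None) * c_out + x * c_skip
    print(denoised.shape)                # torch.Size([2, 4, 8, 8])
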
/third_party/image_generator/sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | 7 | class DenoiserScaling(ABC): 8 | @abstractmethod 9 | def __call__( 10 | self, sigma: torch.Tensor 11 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 12 | pass 13 | 14 | 15 | class EDMScaling: 16 | def __init__(self, sigma_data: float = 0.5): 17 | self.sigma_data = sigma_data 18 | 19 | def __call__( 20 | self, sigma: torch.Tensor 21 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 22 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 23 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 24 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 25 | c_noise = 0.25 * sigma.log() 26 | return c_skip, c_out, c_in, c_noise 27 | 28 | 29 | class EpsScaling: 30 | def __call__( 31 | self, sigma: torch.Tensor 32 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 33 | c_skip = torch.ones_like(sigma, device=sigma.device) 34 | c_out = -sigma 35 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 36 | c_noise = sigma.clone() 37 | return c_skip, c_out, c_in, c_noise 38 | 39 | 40 | class VScaling: 41 | def __call__( 42 | self, sigma: torch.Tensor 43 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 44 | c_skip = 1.0 / (sigma**2 + 1.0) 45 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 46 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 47 | c_noise = sigma.clone() 48 | return c_skip, c_out, c_in, c_noise 49 | 50 | 51 | class VScalingWithEDMcNoise(DenoiserScaling): 52 | def __call__( 53 | self, sigma: torch.Tensor 54 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 55 | c_skip = 1.0 / (sigma**2 + 1.0) 56 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 57 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 58 | c_noise = 0.25 * sigma.log() 59 | return c_skip, c_out, c_in, c_noise 60 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps( 12 | num_substeps: int, max_step: int 13 | ) -> np.ndarray: 14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 15 | 16 | 17 | class Discretization: 18 | 
def __call__(self, n, do_append_zero=True, device="cpu", flip=False): 19 | sigmas = self.get_sigmas(n, device=device) 20 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 21 | return sigmas if not flip else torch.flip(sigmas, (0,)) 22 | 23 | @abstractmethod 24 | def get_sigmas(self, n, device): 25 | pass 26 | 27 | 28 | class EDMDiscretization(Discretization): 29 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 30 | self.sigma_min = sigma_min 31 | self.sigma_max = sigma_max 32 | self.rho = rho 33 | 34 | def get_sigmas(self, n, device="cpu"): 35 | ramp = torch.linspace(0, 1, n, device=device) 36 | min_inv_rho = self.sigma_min ** (1 / self.rho) 37 | max_inv_rho = self.sigma_max ** (1 / self.rho) 38 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 39 | return sigmas 40 | 41 | 42 | class LegacyDDPMDiscretization(Discretization): 43 | def __init__( 44 | self, 45 | linear_start=0.00085, 46 | linear_end=0.0120, 47 | num_timesteps=1000, 48 | ): 49 | super().__init__() 50 | self.num_timesteps = num_timesteps 51 | betas = make_beta_schedule( 52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end 53 | ) 54 | alphas = 1.0 - betas 55 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 56 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 57 | 58 | def get_sigmas(self, n, device="cpu"): 59 | if n < self.num_timesteps: 60 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 61 | alphas_cumprod = self.alphas_cumprod[timesteps] 62 | elif n == self.num_timesteps: 63 | alphas_cumprod = self.alphas_cumprod 64 | else: 65 | raise ValueError 66 | 67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 69 | return torch.flip(sigmas, (0,)) 70 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from typing import Dict, List, Literal, Optional, Tuple, Union 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | from ...util import append_dims, default 9 | 10 | logpy = logging.getLogger(__name__) 11 | 12 | 13 | class Guider(ABC): 14 | @abstractmethod 15 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 16 | pass 17 | 18 | def prepare_inputs( 19 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 20 | ) -> Tuple[torch.Tensor, float, Dict]: 21 | pass 22 | 23 | 24 | class VanillaCFG(Guider): 25 | def __init__(self, scale: float): 26 | self.scale = scale 27 | 28 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 29 | x_u, x_c = x.chunk(2) 30 | x_pred = x_u + self.scale * (x_c - x_u) 31 | return x_pred 32 | 33 | def prepare_inputs(self, x, s, c, uc): 34 | c_out = dict() 35 | 36 | for k in c: 37 | if k in ["vector", "crossattn", "concat"]: 38 | c_out[k] = torch.cat((uc[k], c[k]), 0) 39 | else: 40 | assert c[k] == uc[k] 41 | c_out[k] = c[k] 42 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 43 | 44 | 45 | class IdentityGuider(Guider): 46 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 47 | return x 48 | 49 | def prepare_inputs( 50 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 51 | ) -> Tuple[torch.Tensor, float, Dict]: 52 | c_out = dict() 53 | 54 | for k in c: 55 | c_out[k] = c[k] 56 | 57 | return x, 
s, c_out 58 | 59 | 60 | class LinearPredictionGuider(Guider): 61 | def __init__( 62 | self, 63 | max_scale: float, 64 | num_frames: int, 65 | min_scale: float = 1.0, 66 | additional_cond_keys: Optional[Union[List[str], str]] = None, 67 | ): 68 | self.min_scale = min_scale 69 | self.max_scale = max_scale 70 | self.num_frames = num_frames 71 | self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) 72 | 73 | additional_cond_keys = default(additional_cond_keys, []) 74 | if isinstance(additional_cond_keys, str): 75 | additional_cond_keys = [additional_cond_keys] 76 | self.additional_cond_keys = additional_cond_keys 77 | 78 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 79 | x_u, x_c = x.chunk(2) 80 | 81 | x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) 82 | x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) 83 | scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) 84 | scale = append_dims(scale, x_u.ndim).to(x_u.device) 85 | 86 | return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") 87 | 88 | def prepare_inputs( 89 | self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict 90 | ) -> Tuple[torch.Tensor, torch.Tensor, dict]: 91 | c_out = dict() 92 | 93 | for k in c: 94 | if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: 95 | c_out[k] = torch.cat((uc[k], c[k]), 0) 96 | else: 97 | assert c[k] == uc[k] 98 | c_out[k] = c[k] 99 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 100 | 101 | 102 | class TrianglePredictionGuider(LinearPredictionGuider): 103 | def __init__( 104 | self, 105 | max_scale: float, 106 | num_frames: int, 107 | min_scale: float = 1.0, 108 | period: Optional[float] = 1.0, 109 | period_fusing: Literal["mean", "multiply", "max"] = "max", 110 | additional_cond_keys: Optional[Union[List[str], str]] = None, 111 | ): 112 | super().__init__(max_scale, num_frames, min_scale, additional_cond_keys) 113 | values = torch.linspace(0, 1, num_frames) 114 | # Constructs a triangle wave 115 | if isinstance(period, float): 116 | period = [period] 117 | 118 | scales = [] 119 | for p in period: 120 | scales.append(self.triangle_wave(values, p)) 121 | 122 | if period_fusing == "mean": 123 | scale = sum(scales) / len(period) 124 | elif period_fusing == "multiply": 125 | scale = torch.prod(torch.stack(scales), dim=0) 126 | elif period_fusing == "max": 127 | scale = torch.max(torch.stack(scales), dim=0).values 128 | self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0) 129 | 130 | def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor: 131 | return 2 * (values / period - torch.floor(values / period + 0.5)).abs() 132 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 7 | from ...modules.encoders.modules import GeneralConditioner 8 | from ...util import append_dims, instantiate_from_config 9 | from .denoiser import Denoiser 10 | 11 | 12 | class StandardDiffusionLoss(nn.Module): 13 | def __init__( 14 | self, 15 | sigma_sampler_config: dict, 16 | loss_weighting_config: dict, 17 | loss_type: str = "l2", 18 | offset_noise_level: float = 0.0, 19 | batch2model_keys: Optional[Union[str, List[str]]] = None, 20 | ): 
21 | super().__init__() 22 | 23 | assert loss_type in ["l2", "l1", "lpips"] 24 | 25 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 26 | self.loss_weighting = instantiate_from_config(loss_weighting_config) 27 | 28 | self.loss_type = loss_type 29 | self.offset_noise_level = offset_noise_level 30 | 31 | if loss_type == "lpips": 32 | self.lpips = LPIPS().eval() 33 | 34 | if not batch2model_keys: 35 | batch2model_keys = [] 36 | 37 | if isinstance(batch2model_keys, str): 38 | batch2model_keys = [batch2model_keys] 39 | 40 | self.batch2model_keys = set(batch2model_keys) 41 | 42 | def get_noised_input( 43 | self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor 44 | ) -> torch.Tensor: 45 | noised_input = input + noise * sigmas_bc 46 | return noised_input 47 | 48 | def forward( 49 | self, 50 | network: nn.Module, 51 | denoiser: Denoiser, 52 | conditioner: GeneralConditioner, 53 | input: torch.Tensor, 54 | batch: Dict, 55 | ) -> torch.Tensor: 56 | cond = conditioner(batch) 57 | return self._forward(network, denoiser, cond, input, batch) 58 | 59 | def _forward( 60 | self, 61 | network: nn.Module, 62 | denoiser: Denoiser, 63 | cond: Dict, 64 | input: torch.Tensor, 65 | batch: Dict, 66 | ) -> Tuple[torch.Tensor, Dict]: 67 | additional_model_inputs = { 68 | key: batch[key] for key in self.batch2model_keys.intersection(batch) 69 | } 70 | sigmas = self.sigma_sampler(input.shape[0]).to(input) 71 | 72 | noise = torch.randn_like(input) 73 | if self.offset_noise_level > 0.0: 74 | offset_shape = ( 75 | (input.shape[0], 1, input.shape[2]) 76 | if self.n_frames is not None 77 | else (input.shape[0], input.shape[1]) 78 | ) 79 | noise = noise + self.offset_noise_level * append_dims( 80 | torch.randn(offset_shape, device=input.device), 81 | input.ndim, 82 | ) 83 | sigmas_bc = append_dims(sigmas, input.ndim) 84 | noised_input = self.get_noised_input(sigmas_bc, noise, input) 85 | 86 | model_output = denoiser( 87 | network, noised_input, sigmas, cond, **additional_model_inputs 88 | ) 89 | w = append_dims(self.loss_weighting(sigmas), input.ndim) 90 | return self.get_loss(model_output, input, w) 91 | 92 | def get_loss(self, model_output, target, w): 93 | if self.loss_type == "l2": 94 | return torch.mean( 95 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 96 | ) 97 | elif self.loss_type == "l1": 98 | return torch.mean( 99 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 100 | ) 101 | elif self.loss_type == "lpips": 102 | loss = self.lpips(model_output, target).reshape(-1) 103 | return loss 104 | else: 105 | raise NotImplementedError(f"Unknown loss type {self.loss_type}") 106 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/loss_weighting.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class DiffusionLossWeighting(ABC): 7 | @abstractmethod 8 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 9 | pass 10 | 11 | 12 | class UnitWeighting(DiffusionLossWeighting): 13 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 14 | return torch.ones_like(sigma, device=sigma.device) 15 | 16 | 17 | class EDMWeighting(DiffusionLossWeighting): 18 | def __init__(self, sigma_data: float = 0.5): 19 | self.sigma_data = sigma_data 20 | 21 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 22 | return (sigma**2 + self.sigma_data**2) / (sigma * 
self.sigma_data) ** 2 23 | 24 | 25 | class VWeighting(EDMWeighting): 26 | def __init__(self): 27 | super().__init__(sigma_data=1.0) 28 | 29 | 30 | class EpsWeighting(DiffusionLossWeighting): 31 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 32 | return sigma**-2.0 33 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | 6 | 7 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 8 | if order - 1 > i: 9 | raise ValueError(f"Order {order} too high for step {i}") 10 | 11 | def fn(tau): 12 | prod = 1.0 13 | for k in range(order): 14 | if j == k: 15 | continue 16 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 17 | return prod 18 | 19 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 20 | 21 | 22 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 23 | if not eta: 24 | return sigma_to, 0.0 25 | sigma_up = torch.minimum( 26 | sigma_to, 27 | eta 28 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 29 | ) 30 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 31 | return sigma_down, sigma_up 32 | 33 | 34 | def to_d(x, sigma, denoised): 35 | return (x - denoised) / append_dims(sigma, x.ndim) 36 | 37 | 38 | def to_neg_log_sigma(sigma): 39 | return sigma.log().neg() 40 | 41 | 42 | def to_sigma(neg_log_sigma): 43 | return neg_log_sigma.neg().exp() 44 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...util import default, instantiate_from_config 4 | 5 | 6 | class EDMSampling: 7 | def __init__(self, p_mean=-1.2, p_std=1.2): 8 | self.p_mean = p_mean 9 | self.p_std = p_std 10 | 11 | def __call__(self, n_samples, rand=None): 12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 13 | return log_sigma.exp() 14 | 15 | 16 | class DiscreteSampling: 17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): 18 | self.num_idx = num_idx 19 | self.sigmas = instantiate_from_config(discretization_config)( 20 | num_idx, do_append_zero=do_append_zero, flip=flip 21 | ) 22 | 23 | def idx_to_sigma(self, idx): 24 | return self.sigmas[idx] 25 | 26 | def __call__(self, n_samples, rand=None): 27 | idx = default( 28 | rand, 29 | torch.randint(0, self.num_idx, (n_samples,)), 30 | ) 31 | return self.idx_to_sigma(idx) 32 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) 14 | and compile_model 15 | else lambda x: x 16 | ) 17 | self.diffusion_model = compile(diffusion_model) 18 | 19 | def forward(self, *args, **kwargs): 20 | return 
self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward( 25 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 26 | ) -> torch.Tensor: 27 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 28 | return self.diffusion_model( 29 | x, 30 | timesteps=t, 31 | context=c.get("crossattn", None), 32 | y=c.get("vector", None), 33 | **kwargs, 34 | ) 35 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. 
Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 
84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /third_party/image_generator/sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/LaRa/c32ae7704faee3423634ba226eefc4dda644f8f5/third_party/image_generator/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /third_party/image_generator/tests/inference/test_inference.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from PIL import Image 3 | import pytest 4 | from pytest import fixture 5 | import torch 6 | from typing import Tuple 7 | 8 | from sgm.inference.api import ( 9 | model_specs, 10 | SamplingParams, 11 | SamplingPipeline, 12 | Sampler, 13 | ModelArchitecture, 14 | ) 15 | import sgm.inference.helpers as helpers 16 | 17 | 18 | @pytest.mark.inference 19 | class TestInference: 20 | @fixture(scope="class", params=model_specs.keys()) 21 | def pipeline(self, request) -> SamplingPipeline: 22 | pipeline = SamplingPipeline(request.param) 23 | yield pipeline 24 | del pipeline 25 | torch.cuda.empty_cache() 26 | 27 | @fixture( 28 | scope="class", 29 | params=[ 30 | [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER], 31 | [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER], 32 | ], 33 | ids=["SDXL_V1", "SDXL_V0_9"], 34 | ) 35 | def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]: 36 | base_pipeline = SamplingPipeline(request.param[0]) 37 | refiner_pipeline = SamplingPipeline(request.param[1]) 38 | yield base_pipeline, refiner_pipeline 39 | del base_pipeline 40 | del refiner_pipeline 41 | torch.cuda.empty_cache() 42 | 43 | def create_init_image(self, h, w): 44 | image_array = numpy.random.rand(h, w, 3) * 255 45 | image = Image.fromarray(image_array.astype("uint8")).convert("RGB") 46 | return helpers.get_input_image_tensor(image) 47 | 48 | @pytest.mark.parametrize("sampler_enum", Sampler) 49 | def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum): 50 | output = pipeline.text_to_image( 51 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 52 | prompt="A professional photograph of an astronaut riding a pig", 53 | negative_prompt="", 54 | samples=1, 55 | ) 56 | 57 | assert output is not None 58 | 59 | @pytest.mark.parametrize("sampler_enum", Sampler) 60 | def test_img2img(self, pipeline: SamplingPipeline, sampler_enum): 61 | output = pipeline.image_to_image( 62 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 63 | image=self.create_init_image(pipeline.specs.height, pipeline.specs.width), 64 | prompt="A professional photograph of an astronaut riding a pig", 65 | negative_prompt="", 66 | samples=1, 67 | ) 68 | assert output is not None 69 | 70 | @pytest.mark.parametrize("sampler_enum", Sampler) 71 | @pytest.mark.parametrize( 72 | "use_init_image", [True, False], ids=["img2img", "txt2img"] 73 | ) 74 | def test_sdxl_with_refiner( 75 | self, 76 | sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline], 77 | sampler_enum, 78 | use_init_image, 79 | ): 80 | base_pipeline, refiner_pipeline = sdxl_pipelines 81 | if use_init_image: 82 | output = base_pipeline.image_to_image( 83 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 84 | image=self.create_init_image( 85 | 
base_pipeline.specs.height, base_pipeline.specs.width 86 | ), 87 | prompt="A professional photograph of an astronaut riding a pig", 88 | negative_prompt="", 89 | samples=1, 90 | return_latents=True, 91 | ) 92 | else: 93 | output = base_pipeline.text_to_image( 94 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 95 | prompt="A professional photograph of an astronaut riding a pig", 96 | negative_prompt="", 97 | samples=1, 98 | return_latents=True, 99 | ) 100 | 101 | assert isinstance(output, (tuple, list)) 102 | samples, samples_z = output 103 | assert samples is not None 104 | assert samples_z is not None 105 | refiner_pipeline.refiner( 106 | params=SamplingParams(sampler=sampler_enum.value, steps=10), 107 | image=samples_z, 108 | prompt="A professional photograph of an astronaut riding a pig", 109 | negative_prompt="", 110 | samples=1, 111 | ) 112 | -------------------------------------------------------------------------------- /tools/camera.py: -------------------------------------------------------------------------------- 1 | import math, torch 2 | from dataLoader.utils import build_rays, fov_to_ixt 3 | 4 | def getProjectionMatrix(znear, zfar, fovX, fovY): 5 | 6 | tanHalfFovY = math.tan((fovY / 2)) 7 | tanHalfFovX = math.tan((fovX / 2)) 8 | 9 | P = torch.zeros(4, 4) 10 | 11 | z_sign = 1.0 12 | 13 | P[0, 0] = 1 / tanHalfFovX 14 | P[1, 1] = 1 / tanHalfFovY 15 | P[3, 2] = z_sign 16 | P[2, 2] = z_sign * zfar / (zfar - znear) 17 | P[2, 3] = -(zfar * znear) / (zfar - znear) 18 | return P 19 | 20 | 21 | class MiniCam: 22 | def __init__(self, c2w, width, height, fovy, fovx, znear, zfar): 23 | # c2w (pose) should be in NeRF convention. 24 | 25 | self.image_width = width 26 | self.image_height = height 27 | self.FoVy = fovy 28 | self.FoVx = fovx 29 | self.znear = znear 30 | self.zfar = zfar 31 | 32 | 33 | w2c = torch.inverse(c2w) 34 | 35 | # rectify... 
36 | # w2c[1:3, :3] *= -1 37 | # w2c[:3, 3] *= -1 38 | 39 | self.view_world_transform = c2w 40 | self.world_view_transform = w2c.transpose(0, 1) 41 | self.projection_matrix = getProjectionMatrix( 42 | znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy 43 | ).transpose(0, 1) 44 | 45 | self.full_proj_transform = (self.world_view_transform @ self.projection_matrix).to(torch.float32) 46 | self.camera_center = -c2w[:3, 3] 47 | 48 | def to_device(self, device): 49 | self.world_view_transform = self.world_view_transform.to(device) 50 | self.projection_matrix = self.projection_matrix.to(device) 51 | self.camera_center = self.camera_center.to(device) 52 | self.full_proj_transform = self.full_proj_transform.to(device) 53 | 54 | def get_rays(self): 55 | ixt = fov_to_ixt(torch.tensor((self.FoVx,self.FoVy)), torch.tensor((self.image_width,self.image_height))) 56 | rays = build_rays(self.view_world_transform.cpu().numpy()[None], ixt[None], self.image_height, self.image_width) 57 | return torch.from_numpy(rays) -------------------------------------------------------------------------------- /tools/depth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def abs_error(depth_pred, depth_gt, mask): 4 | depth_pred, depth_gt = depth_pred[mask], depth_gt[mask] 5 | 6 | err = depth_pred - depth_gt 7 | return np.abs(err) if type(depth_pred) is np.ndarray else err.abs() 8 | 9 | def acc_threshold(depth_pred, depth_gt, mask, threshold): 10 | """ 11 | computes the percentage of pixels whose depth error is less than @threshold 12 | """ 13 | errors = abs_error(depth_pred, depth_gt, mask) 14 | acc_mask = errors < threshold 15 | return acc_mask.astype('float') if type(depth_pred) is np.ndarray else acc_mask.float() -------------------------------------------------------------------------------- /tools/download_dataset.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | import os 3 | import shutil 4 | import argparse 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | def download_folder(repo_id, folder, local_dir, files, repo_type="dataset"):# model, dataset, or space. 8 | 9 | def download_file(file): 10 | cache_file_path = hf_hub_download( 11 | repo_id=repo_id, 12 | filename=file, 13 | subfolder=folder, 14 | # repo_type=repo_type, 15 | cache_dir=f'{local_dir}/{folder}/_temp', 16 | ) 17 | 18 | target_path = f'{local_dir}/{folder}/{file}' 19 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 20 | os.system(f'mv {os.path.realpath(cache_file_path)} {target_path}') 21 | 22 | with ThreadPoolExecutor() as executor: 23 | futures = [] 24 | for file in files: 25 | futures.append(executor.submit(download_file, file)) 26 | for future in futures: 27 | future.result() 28 | 29 | 30 | # Example usage 31 | repo_id = "apchen/LaRa" # Replace with your repository ID 32 | folder_path = "dataset" # Replace with the path to the folder in the repository 33 | local_dir = "." 
# Replace with your local destination directory 34 | 35 | gso_list = ['GSO.zip'] 36 | co3d_list = ['Co3D/co3d_hydrant.h5','Co3D/co3d_teddybear.h5'] 37 | gobjaverse_list = [f'gobjaverse/gobjaverse_part_{i+1:02d}.h5' for i in range(32)] + ['gobjaverse/gobjaverse.h5'] 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser(description="download files.") 41 | parser.add_argument("dtype", type=str, default="gso", help="one of [gso,co3d,objaverse,all]") 42 | 43 | args = parser.parse_args() 44 | 45 | if "gso" == args.dtype: 46 | # download_folder(repo_id, folder_path, local_dir,gso_list) 47 | os.system(f'unzip {local_dir}/{folder_path}/{gso_list[0]} -d {local_dir}/{folder_path}') 48 | os.system(f'rm {local_dir}/{folder_path}/{gso_list[0]}') 49 | elif "co3d" == args.dtype: 50 | download_folder(repo_id, folder_path, local_dir,co3d_list) 51 | elif "objaverse" == args.dtype: 52 | download_folder(repo_id, folder_path, local_dir, gobjaverse_list) 53 | elif "all" == args.dtype: 54 | download_folder(repo_id, folder_path, local_dir,gso_list+co3d_list+gobjaverse_list) 55 | os.system(f'unzip {local_dir}/{folder_path}/{gso_list[0]} -d {local_dir}/{folder_path}') 56 | os.system(f'rm {local_dir}/{folder_path}/{gso_list[0]}') 57 | 58 | # shutil.rmtree(f'{local_dir}/{folder_path}/_temp') 59 | -------------------------------------------------------------------------------- /tools/download_objaverse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | import os, sys, json 4 | from multiprocessing import Pool 5 | 6 | def download_url(item): 7 | global save_dir 8 | oss_full_dir = os.path.join("https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/objaverse_tar", item+".tar") 9 | os.system("wget -P {} {}".format(os.path.join(save_dir, item.split("/")[0]), oss_full_dir)) 10 | 11 | def get_all_folders(root): 12 | all_folders = [] 13 | categrey = os.listdir(root) 14 | for item in categrey: 15 | if not os.path.isdir(f'{root}/{item}'): 16 | continue 17 | folders = os.listdir(f'{root}/{item}') 18 | all_folders += [f'{root}/{item}/{folder}' for folder in folders] 19 | return all_folders 20 | 21 | def folder_to_json(exist_files): 22 | files = [] 23 | for item in exist_files: 24 | split = item.split('/')[-2:] 25 | files.append(f'{split[0]}/{split[1][:-4]}') 26 | return files 27 | 28 | def filterout_existing(json, exist_files): 29 | for item in exist_files: 30 | json.remove(item) 31 | return json 32 | 33 | if __name__=="__main__": 34 | # download_gobjaverse_280k index file 35 | # wget https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/gobjaverse_280k.json 36 | assert len(sys.argv) == 4, "eg: python download_objaverse.py ./data /path/to/json_file 10" 37 | save_dir = str(sys.argv[1]) 38 | json_file = str(sys.argv[2]) 39 | n_threads = int(sys.argv[3]) 40 | 41 | data = json.load(open(json_file))[:100] 42 | 43 | exist_files = get_all_folders(save_dir) 44 | exist_files = folder_to_json(exist_files) 45 | 46 | print(len(data)) 47 | data = filterout_existing(data, exist_files) 48 | print(len(data)) 49 | 50 | p = Pool(n_threads) 51 | p.map(download_url, data) 52 | -------------------------------------------------------------------------------- /tools/hdf5_split_merge.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import argparse 3 | import os 4 | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor 5 | 6 | def 
split_hdf5_file(input_file, output_prefix, num_splits): 7 | with h5py.File(input_file, 'r') as f: 8 | keys = sorted(list(f.keys())) 9 | chunk_size = len(keys) // num_splits 10 | 11 | def write_chunk(i, keys_chunk): 12 | output_file = f"{output_prefix}_part_{i+1}.h5" 13 | with h5py.File(output_file, 'w') as out_f: 14 | for key in keys_chunk: 15 | f.copy(key, out_f) 16 | 17 | with ThreadPoolExecutor() as executor: 18 | futures = [] 19 | for i in range(num_splits): 20 | keys_chunk = keys[i*chunk_size:] if i == num_splits - 1 else keys[i*chunk_size: (i+1)*chunk_size] # the last chunk keeps any remainder keys 21 | futures.append(executor.submit(write_chunk, i, keys_chunk)) 22 | for future in futures: 23 | future.result() 24 | 25 | print(f"Split into {num_splits} files with prefix '{output_prefix}'.") 26 | 27 | def merge_hdf5_files(output_file, input_files): 28 | with h5py.File(output_file, 'w') as out_f: 29 | def copy_data(input_file): 30 | with h5py.File(input_file, 'r') as in_f: 31 | for key in in_f.keys(): 32 | in_f.copy(key, out_f) 33 | 34 | with ThreadPoolExecutor() as executor: 35 | futures = [executor.submit(copy_data, input_file) for input_file in input_files] 36 | for future in futures: 37 | future.result() 38 | 39 | print(f"Merged files into '{output_file}'.") 40 | 41 | def get_absolute_paths(directory, prefix): 42 | return [os.path.join(directory, f) for f in os.listdir(directory) if f.startswith(prefix)] 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser(description="Split and merge HDF5 files.") 47 | 48 | subparsers = parser.add_subparsers(dest="command", required=True) 49 | 50 | split_parser = subparsers.add_parser("split", help="Split an HDF5 file into multiple files.") 51 | split_parser.add_argument("input_file", type=str, help="Input HDF5 file to split.") 52 | split_parser.add_argument("output_prefix", type=str, help="Output prefix for split files.") 53 | split_parser.add_argument("num_splits", type=int, help="Number of splits.") 54 | 55 | merge_parser = subparsers.add_parser("merge", help="Merge multiple HDF5 files into one file.") 56 | merge_parser.add_argument("output_file", type=str, help="Output HDF5 file to create.") 57 | merge_parser.add_argument("file_prefix", type=str, help="Path prefix (directory + filename prefix) of the HDF5 files to merge.") 58 | 59 | args = parser.parse_args() 60 | 61 | if args.command == "split": 62 | split_hdf5_file(args.input_file, args.output_prefix, args.num_splits) 63 | elif args.command == "merge": 64 | input_files = get_absolute_paths(os.path.dirname(args.file_prefix) or ".", os.path.basename(args.file_prefix)) 65 | merge_hdf5_files(args.output_file, input_files) 66 | -------------------------------------------------------------------------------- /tools/meshRender.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import mitsuba as mi 3 | from tqdm import tqdm 4 | mi.set_variant('cuda_ad_rgb', 'llvm_ad_rgb') 5 | sys.path.append(os.path.join(os.path.dirname(__file__), "lib")) 6 | 7 | import numpy as np 8 | 9 | def render_mesh(cams, mesh_path, spp = 512, white_bg=True): 10 | 11 | image_width = cams[0].image_width 12 | image_height = cams[0].image_height 13 | 14 | mesh_type = os.path.splitext(mesh_path)[1][1:] 15 | sdf_scene = mi.load_file("configs/render/scene.xml", resx=image_width, resy=image_height, mesh_path=mesh_path, mesh_type=mesh_type, 16 | integrator_file="configs/render/integrator_path.xml", update_scene=False, spp=spp, max_depth=8) 17 | 18 | imgs = [] 19 | pbar = tqdm(total=len(cams), desc='Files', position=0) 20 | b2c = np.array([[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32) 21 | for cam in 
cams: 22 | c2w, fov = cam.view_world_transform.numpy(), cam.FoVx 23 | fov = np.degrees(fov) 24 | image_width = cam.image_width 25 | image_height = cam.image_height 26 | 27 | to_world = c2w @ b2c 28 | to_world_transform = mi.ScalarTransform4f(to_world.tolist()) 29 | 30 | sensor = mi.load_dict({ 31 | 'type': 'perspective', 32 | 'fov': fov, 'sampler': {'type': 'independent'}, 33 | 'film': {'type': 'hdrfilm', 'width': image_width, 'height': image_height, 34 | 'pixel_filter': {'type': 'gaussian'}, 'pixel_format': 'rgba'}, 35 | 'to_world': to_world_transform 36 | }) 37 | 38 | img = mi.render(sdf_scene, sensor=sensor, spp=spp) 39 | img = mi.Bitmap(img).convert(mi.Bitmap.PixelFormat.RGBA, mi.Struct.Type.UInt8, srgb_gamma=True) 40 | # img.write(f'123.png') 41 | img = np.array(img, copy=False) 42 | if white_bg: 43 | img = img.astype(np.float32)/255 44 | img = img[...,:3]*img[...,3:] + (1.0-img[...,3:])*np.array([0.722,0.376,0.161]) 45 | img = np.round(img*255).astype('uint8') 46 | 47 | imgs.append(img) 48 | pbar.update(1) 49 | pbar.set_description("Mesh extraction Done. Rendering *_mesh.mp4: ") 50 | 51 | return np.stack(imgs) -------------------------------------------------------------------------------- /train_lightning.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | n_thread = 2 4 | os.environ["MKL_NUM_THREADS"] = f"{n_thread}" 5 | os.environ["NUMEXPR_NUM_THREADS"] = f"{n_thread}" 6 | os.environ["OMP_NUM_THREADS"] = f"4" 7 | os.environ["VECLIB_MAXIMUM_THREADS"] = f"{n_thread}" 8 | os.environ["OPENBLAS_NUM_THREADS"] = f"{n_thread}" 9 | 10 | 11 | import torch 12 | from dataLoader import dataset_dict 13 | from omegaconf import OmegaConf 14 | 15 | from lightning.system import system 16 | from torch.utils.data import DataLoader 17 | import pytorch_lightning as L 18 | 19 | from datetime import datetime 20 | 21 | 22 | from pytorch_lightning.loggers import TensorBoardLogger 23 | from pytorch_lightning.loggers import WandbLogger 24 | from pytorch_lightning.callbacks import ModelCheckpoint 25 | from pytorch_lightning.strategies import DDPStrategy 26 | 27 | def main(cfg): 28 | 29 | torch.set_float32_matmul_precision('medium') 30 | torch.autograd.set_detect_anomaly(True) 31 | print("Using PyTorch {} and Lightning {}".format(torch.__version__, L.__version__)) 32 | 33 | # data loader 34 | train_dataset = dataset_dict[cfg.train_dataset.dataset_name] 35 | train_loader = DataLoader(train_dataset(cfg.train_dataset), 36 | batch_size= cfg.train.batch_size, 37 | num_workers= 8, 38 | shuffle=True, 39 | pin_memory=False) 40 | val_dataset = dataset_dict[cfg.test_dataset.dataset_name] 41 | val_loader = DataLoader(val_dataset(cfg.test_dataset), 42 | batch_size=cfg.test.batch_size, 43 | num_workers=2, 44 | shuffle=True, 45 | pin_memory=False) 46 | 47 | # build logger 48 | project_name = cfg.exp_name.split("/")[0] 49 | exp_name = cfg.exp_name.split("/")[1] 50 | 51 | if cfg.logger.name == "tensorboard": 52 | logger = TensorBoardLogger(save_dir=cfg.logger.dir, name=exp_name) 53 | elif cfg.logger.name == "wandb": 54 | os.environ["WANDB__SERVICE_WAIT"] = "300" 55 | logger = WandbLogger(name=exp_name,project=project_name, save_dir=cfg.logger.dir, entity="large-reconstruction-model") 56 | 57 | # Set up ModelCheckpoint callback 58 | checkpoint_callback = ModelCheckpoint( 59 | dirpath=cfg.logger.dir, # Path where checkpoints will be saved 60 | filename='{epoch}', # Filename for the checkpoints 61 | # save_top_k=1, # Set to -1 to save all checkpoints 62 | every_n_epochs=5, # 
Save a checkpoint every K epochs 63 | save_on_train_epoch_end=True, # Ensure it saves at the end of an epoch, not the beginning 64 | ) 65 | 66 | my_system = system(cfg) 67 | 68 | trainer = L.Trainer(devices=cfg.gpu_id, 69 | num_nodes=1, 70 | max_epochs=cfg.train.n_epoch, 71 | accelerator='gpu', 72 | strategy=DDPStrategy(find_unused_parameters=True), 73 | accumulate_grad_batches=2, 74 | logger=logger, 75 | gradient_clip_val=0.5, 76 | precision="bf16-mixed", 77 | callbacks=[checkpoint_callback], 78 | check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, 79 | limit_val_batches=cfg.train.limit_val_batches, # Run on only 10% of the validation data 80 | limit_train_batches=cfg.train.limit_train_batches, 81 | ) 82 | 83 | 84 | t0 = datetime.now() 85 | trainer.fit( 86 | my_system, 87 | train_dataloaders=train_loader, 88 | val_dataloaders=val_loader, 89 | ckpt_path=cfg.model.ckpt_path 90 | ) 91 | 92 | dt = datetime.now() - t0 93 | print('Training took {}'.format(dt)) 94 | 95 | 96 | if __name__ == '__main__': 97 | 98 | base_conf = OmegaConf.load('configs/base.yaml') 99 | 100 | cli_conf = OmegaConf.from_cli() 101 | cfg = OmegaConf.merge(base_conf, cli_conf) 102 | 103 | main(cfg) --------------------------------------------------------------------------------
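Note on configuration overrides: train_lightning.py builds its config by merging configs/base.yaml with dot-list overrides parsed from the command line via OmegaConf.from_cli(), so any key defined in base.yaml can be overridden at launch, e.g. python train_lightning.py train.batch_size=2 exp_name=LaRa/ablation (train.batch_size, train.n_epoch and exp_name are read by the script above; the full set of keys and their defaults lives in base.yaml, and exp_name is expected to contain a "/" separating project and experiment name). The following minimal sketch illustrates the same merge behaviour with placeholder keys and values rather than the real base.yaml contents:

from omegaconf import OmegaConf

# Placeholder config standing in for configs/base.yaml (keys and values are illustrative only).
base_conf = OmegaConf.create({"exp_name": "LaRa/baseline", "train": {"batch_size": 4, "n_epoch": 30}})

# Equivalent to passing "train.batch_size=2 exp_name=LaRa/ablation" on the command line,
# which OmegaConf.from_cli() would parse from sys.argv.
cli_conf = OmegaConf.from_dotlist(["train.batch_size=2", "exp_name=LaRa/ablation"])

# Later configs take precedence, so CLI overrides win over base.yaml defaults.
cfg = OmegaConf.merge(base_conf, cli_conf)
assert cfg.train.batch_size == 2 and cfg.exp_name == "LaRa/ablation"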