├── .gitignore ├── LICENSE ├── README.md ├── acce.yaml ├── app.py ├── configs ├── nf7_v3_SNR_rd_size_stroke.yaml ├── nf7_v3_SNR_rd_size_stroke_train.yaml ├── specs_objaverse_total.json ├── stage2-v2-snr.yaml └── stage2-v2-snr_train.yaml ├── examples ├── 3D卡通狗.webp ├── astronaut.webp ├── bulldog.webp ├── ghost-eating-burger.webp ├── kunkun.webp ├── 万圣南瓜.webp ├── 人物骑马.webp ├── 初音未来玩偶.webp ├── 卡通恐龙.webp ├── 卡通手枪截图.webp ├── 卡通猫.webp ├── 卡通蘑菇套装.webp ├── 可爱玄策.webp ├── 大头泡泡马特.webp ├── 彩色蘑菇.webp ├── 彩色蘑菇2.webp ├── 恐龙套装.webp ├── 手办.webp ├── 机械狗裁切.webp ├── 林克.webp ├── 植物1.webp ├── 武器-剑.webp ├── 毛线衣.webp ├── 海龟.webp ├── 猫人.webp ├── 猫头鹰.webp ├── 玩具兔.webp ├── 玩具熊.webp ├── 玩具猪.webp ├── 玫瑰.webp ├── 皮卡丘.webp ├── 皮鞋.webp ├── 石头.webp ├── 石头哆啦A梦.webp ├── 红玩具猪.webp ├── 翅膀道具.webp ├── 茶壶.webp ├── 草系精灵.webp ├── 蓝色小怪物.webp ├── 蓝色泡泡马特.webp ├── 蓝色猫.webp ├── 赛博朋克-男.webp ├── 路灯.webp └── 运动系手办.webp ├── imagedream ├── __init__.py ├── camera_utils.py ├── configs │ ├── sd_v2_base_ipmv.yaml │ ├── sd_v2_base_ipmv_ch8.yaml │ ├── sd_v2_base_ipmv_chin8.yaml │ ├── sd_v2_base_ipmv_chin8_zero_snr.yaml │ ├── sd_v2_base_ipmv_local.yaml │ └── sd_v2_base_ipmv_zero_SNR.yaml ├── ldm │ ├── __init__.py │ ├── interface.py │ ├── models │ │ ├── __init__.py │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ └── ddim.py │ ├── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── adaptors.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ └── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ └── util.py └── model_zoo.py ├── inference.py ├── launch_train.sh ├── libs ├── base_utils.py ├── data.py └── sample.py ├── mesh.py ├── model ├── __init__.py ├── archs │ ├── __init__.py │ ├── decoders │ │ ├── __init__.py │ │ └── shape_texture_net.py │ ├── mlp_head.py │ └── unet.py └── crm │ └── model.py ├── pipelines.py ├── requirements.txt ├── run.py ├── train.py ├── train_examples ├── 0011662ee0fc4b4481bfd28314d154c1 │ ├── 000.png │ ├── 001.png │ ├── 002.png │ ├── 003.png │ ├── 004.png │ ├── 005.png │ ├── xyz_new_000.png │ ├── xyz_new_001.png │ ├── xyz_new_002.png │ ├── xyz_new_003.png │ ├── xyz_new_004.png │ └── xyz_new_005.png └── caption.csv ├── train_stage2.py └── util ├── __init__.py ├── flexicubes.py ├── flexicubes_geometry.py ├── renderer.py ├── tables.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | out/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 TSAIL group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Reconstruction Model 2 | 3 | Official implementation for *CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model*. 4 | 5 | **CRM is a feed-forward model which can generate 3D textured mesh in 10 seconds.** 6 | 7 | ## [Project Page](https://ml.cs.tsinghua.edu.cn/~zhengyi/CRM/) | [Arxiv](https://arxiv.org/abs/2403.05034) | [HF-Demo](https://huggingface.co/spaces/Zhengyi/CRM) | [Weights](https://huggingface.co/Zhengyi/CRM) 8 | 9 | https://github.com/thu-ml/CRM/assets/40787266/8b325bc0-aa74-4c26-92e8-a8f0c1079382 10 | 11 | ## Try CRM 🍻 12 | * Try CRM at [Huggingface Demo](https://huggingface.co/spaces/Zhengyi/CRM). 13 | * Try CRM at [Replicate Demo](https://replicate.com/camenduru/crm). Thanks [@camenduru](https://github.com/camenduru)! 
14 | 15 | ## Install 16 | 17 | ### Step 1 - Base 18 | 19 | Install the packages one by one; we use **python 3.9** 20 | 21 | ```bash 22 | pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117 23 | pip install torch-scatter==2.1.1 -f https://data.pyg.org/whl/torch-1.13.1+cu117.html 24 | pip install kaolin==0.14.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.13.1_cu117.html 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | Besides, you need to install xformers manually according to the official [doc](https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers) (**not needed if you install via conda**), e.g. 29 | 30 | ```bash 31 | pip install ninja 32 | pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers 33 | ``` 34 | 35 | ### Step 2 - Nvdiffrast 36 | 37 | Install nvdiffrast according to the official [doc](https://nvlabs.github.io/nvdiffrast/#installation), e.g. 38 | 39 | ```bash 40 | pip install git+https://github.com/NVlabs/nvdiffrast 41 | ``` 42 | 43 | 44 | 45 | ## Inference 46 | 47 | We suggest using gradio for visualized inference. 48 | 49 | ``` 50 | gradio app.py 51 | ``` 52 | 53 | ![image](https://github.com/thu-ml/CRM/assets/40787266/4354d22a-a641-4531-8408-c761ead8b1a2) 54 | 55 | For inference from the command line, simply run 56 | ```bash 57 | CUDA_VISIBLE_DEVICES="0" python run.py --inputdir "examples/kunkun.webp" 58 | ``` 59 | It will output the preprocessed image, the generated six-view images and CCMs, and a 3D model in OBJ format. 60 | 61 | **Tips:** (1) If the result is unsatisfactory, please check whether the input image is correctly pre-processed onto a grey background. Otherwise the results will be unpredictable. 62 | (2) Different from the [Huggingface Demo](https://huggingface.co/spaces/Zhengyi/CRM), this official implementation uses UV texture instead of vertex color. It produces better texture than the online demo but takes longer to generate owing to the UV texturing. 63 | 64 | ## Train 65 | We provide training scripts for multiview generation, along with their data requirements. 66 | To launch a simple single-instance overfitting run for multiview generation: 67 | ```shell 68 | accelerate launch $accelerate_args train.py --config configs/nf7_v3_SNR_rd_size_stroke_train.yaml \ 69 | config.batch_size=1 \ 70 | config.eval_interval=100 71 | ``` 72 | To launch a simple single-instance overfitting run for CCM generation: 73 | ```shell 74 | accelerate launch $accelerate_args train_stage2.py --config configs/stage2-v2-snr_train.yaml \ 75 | config.batch_size=1 \ 76 | config.eval_interval=100 77 | ``` 78 | 79 | ### Data preparation 80 | To specify the data directories, modify the following params in configs/xxxx.yaml: 81 | ```yaml 82 | base_dir: 83 | xyz_base: 84 | caption_csv: 85 | ``` 86 | The file tree of the base directories should look like the following: 87 | ```shell 88 | base_dir 89 | ├── uid1 90 | │ ├── 000.png 91 | │ ├── 001.png 92 | │ ├── 002.png 93 | │ ├── 003.png 94 | │ ├── 004.png 95 | │ ├── 005.png 96 | ├── uid2 97 | .... 98 | 99 | xyz_base 100 | ├── uid1 101 | │ ├── xyz_new_000.png 102 | │ ├── xyz_new_001.png 103 | │ ├── xyz_new_002.png 104 | │ ├── xyz_new_003.png 105 | │ ├── xyz_new_004.png 106 | │ └── xyz_new_005.png 107 | ├── uid2 108 | .... 109 | ``` 110 | The `train_examples` dir shows a minimal example of the training data and the `caption.csv` file. 111 | 112 | 113 | 114 | ## Todo List 115 | - [x] Release inference code. 116 | - [x] Release pretrained models.
117 | - [ ] Optimize inference code to fit on low-memory GPUs. 118 | - [x] Upload training code. 119 | 120 | ## Acknowledgement 121 | - [ImageDream](https://github.com/bytedance/ImageDream) 122 | - [nvdiffrast](https://github.com/NVlabs/nvdiffrast) 123 | - [kiuikit](https://github.com/ashawkey/kiuikit) 124 | - [GET3D](https://github.com/nv-tlabs/GET3D) 125 | 126 | ## Citation 127 | 128 | ``` 129 | @article{wang2024crm, 130 | title={CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model}, 131 | author={Zhengyi Wang and Yikai Wang and Yifei Chen and Chendong Xiang and Shuo Chen and Dajiang Yu and Chongxuan Li and Hang Su and Jun Zhu}, 132 | journal={arXiv preprint arXiv:2403.05034}, 133 | year={2024} 134 | } 135 | ``` 136 | -------------------------------------------------------------------------------- /acce.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: false 7 | zero_stage: 2 8 | distributed_type: DEEPSPEED 9 | mixed_precision: fp16 10 | num_processes: 8 -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Not ready to use yet 2 | import argparse 3 | import numpy as np 4 | import gradio as gr 5 | from omegaconf import OmegaConf 6 | import torch 7 | from PIL import Image 8 | import PIL 9 | from pipelines import TwoStagePipeline 10 | from huggingface_hub import hf_hub_download 11 | import os 12 | import rembg 13 | from typing import Any 14 | import json 15 | import os 16 | import json 17 | import argparse 18 | 19 | from model import CRM 20 | from inference import generate3d 21 | 22 | pipeline = None 23 | rembg_session = rembg.new_session() 24 | 25 | 26 | def expand_to_square(image, bg_color=(0, 0, 0, 0)): 27 | # expand image to 1:1 28 | width, height = image.size 29 | if width == height: 30 | return image 31 | new_size = (max(width, height), max(width, height)) 32 | new_image = Image.new("RGBA", new_size, bg_color) 33 | paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) 34 | new_image.paste(image, paste_position) 35 | return new_image 36 | 37 | def check_input_image(input_image): 38 | if input_image is None: 39 | raise gr.Error("No image uploaded!") 40 | 41 | 42 | def remove_background( 43 | image: PIL.Image.Image, 44 | rembg_session = None, 45 | force: bool = False, 46 | **rembg_kwargs, 47 | ) -> PIL.Image.Image: 48 | do_remove = True 49 | if image.mode == "RGBA" and image.getextrema()[3][0] < 255: 50 | # explain why we currently do not remove the background 51 | print("alpha channel not empty, skip background removal, using alpha channel as mask") 52 | background = Image.new("RGBA", image.size, (0, 0, 0, 0)) 53 | image = Image.alpha_composite(background, image) 54 | do_remove = False 55 | do_remove = do_remove or force 56 | if do_remove: 57 | image = rembg.remove(image, session=rembg_session, **rembg_kwargs) 58 | return image 59 | 60 | def do_resize_content(original_image: Image, scale_rate): 61 | # resize the image content while retaining the original image size 62 | if scale_rate != 1: 63 | # Calculate the new size after rescaling 64 | new_size = tuple(int(dim * scale_rate) for dim in original_image.size) 65 | # Resize the image while maintaining the aspect ratio 66 | resized_image = original_image.resize(new_size) 67 | #
Create a new image with the original size and a transparent background 68 | padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) 69 | paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) 70 | padded_image.paste(resized_image, paste_position) 71 | return padded_image 72 | else: 73 | return original_image 74 | 75 | def add_background(image, bg_color=(255, 255, 255)): 76 | # given an RGBA image, alpha channel is used as mask to add background color 77 | background = Image.new("RGBA", image.size, bg_color) 78 | return Image.alpha_composite(background, image) 79 | 80 | 81 | def preprocess_image(image, background_choice, foreground_ratio, background_color): 82 | """ 83 | input image is a PIL image in RGBA mode, return an RGB image 84 | """ 85 | print(background_choice) 86 | if background_choice == "Alpha as mask": 87 | background = Image.new("RGBA", image.size, (0, 0, 0, 0)) 88 | image = Image.alpha_composite(background, image) 89 | else: 90 | image = remove_background(image, rembg_session, force=True) 91 | image = do_resize_content(image, foreground_ratio) 92 | image = expand_to_square(image) 93 | image = add_background(image, background_color) 94 | return image.convert("RGB") 95 | 96 | 97 | def gen_image(input_image, seed, scale, step): 98 | global pipeline, model, args 99 | pipeline.set_seed(seed) 100 | rt_dict = pipeline(input_image, scale=scale, step=step) 101 | stage1_images = rt_dict["stage1_images"] 102 | stage2_images = rt_dict["stage2_images"] 103 | np_imgs = np.concatenate(stage1_images, 1) 104 | np_xyzs = np.concatenate(stage2_images, 1) 105 | 106 | glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, args.device) 107 | return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path, obj_path 108 | 109 | 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument( 112 | "--stage1_config", 113 | type=str, 114 | default="configs/nf7_v3_SNR_rd_size_stroke.yaml", 115 | help="config for stage1", 116 | ) 117 | parser.add_argument( 118 | "--stage2_config", 119 | type=str, 120 | default="configs/stage2-v2-snr.yaml", 121 | help="config for stage2", 122 | ) 123 | 124 | parser.add_argument("--device", type=str, default="cuda") 125 | args = parser.parse_args() 126 | 127 | crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth") 128 | specs = json.load(open("configs/specs_objaverse_total.json")) 129 | model = CRM(specs).to(args.device) 130 | model.load_state_dict(torch.load(crm_path, map_location = args.device), strict=False) 131 | 132 | stage1_config = OmegaConf.load(args.stage1_config).config 133 | stage2_config = OmegaConf.load(args.stage2_config).config 134 | stage2_sampler_config = stage2_config.sampler 135 | stage1_sampler_config = stage1_config.sampler 136 | 137 | stage1_model_config = stage1_config.models 138 | stage2_model_config = stage2_config.models 139 | 140 | xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth") 141 | pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth") 142 | stage1_model_config.resume = pixel_path 143 | stage2_model_config.resume = xyz_path 144 | 145 | pipeline = TwoStagePipeline( 146 | stage1_model_config, 147 | stage2_model_config, 148 | stage1_sampler_config, 149 | stage2_sampler_config, 150 | device=args.device, 151 | dtype=torch.float16 152 | ) 153 | 154 | with gr.Blocks() as demo: 155 | gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model") 156 | with gr.Row(): 157
| with gr.Column(): 158 | with gr.Row(): 159 | image_input = gr.Image( 160 | label="Image input", 161 | image_mode="RGBA", 162 | sources="upload", 163 | type="pil", 164 | ) 165 | processed_image = gr.Image(label="Processed Image", interactive=False, type="pil", image_mode="RGB") 166 | with gr.Row(): 167 | with gr.Column(): 168 | with gr.Row(): 169 | background_choice = gr.Radio([ 170 | "Alpha as mask", 171 | "Auto Remove background" 172 | ], value="Auto Remove background", 173 | label="background choice") 174 | # do_remove_background = gr.Checkbox(label=, value=True) 175 | # force_remove = gr.Checkbox(label=, value=False) 176 | background_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False) 177 | foreground_ratio = gr.Slider( 178 | label="Foreground Ratio", 179 | minimum=0.5, 180 | maximum=1.0, 181 | value=1.0, 182 | step=0.05, 183 | ) 184 | 185 | with gr.Column(): 186 | seed = gr.Number(value=1234, label="seed", precision=0) 187 | guidance_scale = gr.Number(value=5.5, minimum=3, maximum=10, label="guidance_scale") 188 | step = gr.Number(value=50, minimum=30, maximum=100, label="sample steps", precision=0) 189 | text_button = gr.Button("Generate 3D shape") 190 | gr.Examples( 191 | examples=[os.path.join("examples", i) for i in os.listdir("examples")], 192 | inputs=[image_input], 193 | ) 194 | with gr.Column(): 195 | image_output = gr.Image(interactive=False, label="Output RGB image") 196 | xyz_output = gr.Image(interactive=False, label="Output CCM image") 197 | 198 | output_model = gr.Model3D( 199 | label="Output GLB", 200 | interactive=False, 201 | ) 202 | gr.Markdown("Note: The GLB model shown here has darker lighting and enlarged UV seams. Download it for correct results.") 203 | output_obj = gr.File(interactive=False, label="Output OBJ") 204 | 205 | inputs = [ 206 | processed_image, 207 | seed, 208 | guidance_scale, 209 | step, 210 | ] 211 | outputs = [ 212 | image_output, 213 | xyz_output, 214 | output_model, 215 | output_obj, 216 | ] 217 | 218 | 219 | text_button.click(fn=check_input_image, inputs=[image_input]).success( 220 | fn=preprocess_image, 221 | inputs=[image_input, background_choice, foreground_ratio, background_color], 222 | outputs=[processed_image], 223 | ).success( 224 | fn=gen_image, 225 | inputs=inputs, 226 | outputs=outputs, 227 | ) 228 | demo.queue().launch() 229 | -------------------------------------------------------------------------------- /configs/nf7_v3_SNR_rd_size_stroke.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | # others 3 | seed: 1234 4 | num_frames: 7 5 | mode: pixel 6 | offset_noise: true 7 | # model related 8 | models: 9 | config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml 10 | resume: models/pixel.pth 11 | # sampler related 12 | sampler: 13 | target: libs.sample.ImageDreamDiffusion 14 | params: 15 | mode: pixel 16 | num_frames: 7 17 | camera_views: [1, 2, 3, 4, 5, 0, 0] 18 | ref_position: 6 19 | random_background: false 20 | offset_noise: true 21 | resize_rate: 1.0 -------------------------------------------------------------------------------- /configs/nf7_v3_SNR_rd_size_stroke_train.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | # others 3 | seed: 1234 4 | num_frames: 7 5 | mode: pixel 6 | offset_noise: true 7 | # model related 8 | models: 9 | config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml 10 | resume: release_models/sd-v2.1-base-4view-ipmv.pt 11 | # sampler related 12 | sampler: 13 | target:
libs.sample.ImageDreamDiffusion 14 | params: 15 | mode: pixel 16 | num_frames: 7 17 | camera_views: [1, 2, 3, 4, 5, 0, 0] 18 | ref_position: 6 19 | random_background: false 20 | offset_noise: true 21 | resize_rate: 1.0 22 | 23 | # config datasets 24 | train_data: 25 | target: libs.data.DataRelativeStroke 26 | params: 27 | base_dir: train_examples 28 | caption_csv: train_examples/caption.csv 29 | image_size: 256 30 | repeat: 1 31 | camera_views: [1, 2, 3, 4, 5, 0, 0] 32 | ref_indexs: [0, 1, 3, 4, 5, 2] 33 | ref_position: 6 34 | split: train 35 | num_frames: 7 36 | random_background: true 37 | resize_rate: 0.95 38 | stroke_p: 0.5 39 | eval_size: 100 40 | resize_range: 41 | - 0.5 42 | - 1.0 43 | eval_data: 44 | target: libs.data.DataRelativeStroke 45 | params: 46 | base_dir: train_examples 47 | caption_csv: train_examples/caption.csv 48 | image_size: 256 49 | repeat: 1 50 | camera_views: [1, 2, 3, 4, 5, 0, 0] # camera views are relative views 51 | ref_indexs: [0, 1, 3, 4, 5, 2] 52 | ref_position: 6 53 | split: eval 54 | num_frames: 7 55 | random_background: true 56 | resize_rate: 0.95 57 | stroke_p: 0.5 58 | eval_size: 100 59 | resize_range: 60 | - 0.5 61 | - 1.0 62 | 63 | in_the_wild_images: 64 | target: libs.data.InTheWildImages 65 | params: 66 | base_dirs: 67 | - examples 68 | 69 | # optimizer related 70 | optimizer: 71 | lr: 5e-5 72 | gradient_accumulation_steps: 12 73 | 74 | # wandb related parameters 75 | project: CRM 76 | wandb_run_name: CRM-pixel 77 | wandb_mode: offline 78 | 79 | 80 | # training hyperparmeters 81 | batch_size: 16 82 | dataloader: 83 | num_workers: 10 84 | shuffle: true 85 | drop_last: true 86 | 87 | save_interval: 600000 88 | log_interval: 5000 89 | eval_interval: 300000 90 | max_step: 10000000 91 | -------------------------------------------------------------------------------- /configs/specs_objaverse_total.json: -------------------------------------------------------------------------------- 1 | { 2 | "Input": { 3 | "img_num": 16, 4 | "class": "all", 5 | "camera_angle_num": 8, 6 | "tet_grid_size": 80, 7 | "validate_num": 16, 8 | "scale": 0.95, 9 | "radius": 3, 10 | "resolution": [256, 256] 11 | }, 12 | 13 | "Pretrain": { 14 | "mode": null, 15 | "sdf_threshold": 0.1, 16 | "sdf_scale": 10, 17 | "batch_infer": false, 18 | "lr": 1e-4, 19 | "radius": 0.5 20 | }, 21 | 22 | "Train": { 23 | "mode": "rnd", 24 | "num_epochs": 500, 25 | "grad_acc": 1, 26 | "warm_up": 0, 27 | "decay": 0.000, 28 | "learning_rate": { 29 | "init": 1e-4, 30 | "sdf_decay": 1, 31 | "rgb_decay": 1 32 | }, 33 | "batch_size": 4, 34 | "eva_iter": 80, 35 | "eva_all_epoch": 10, 36 | "tex_sup_mode": "blender", 37 | "exp_uv_mesh": false, 38 | "doub": false, 39 | "random_bg": false, 40 | "shift": 0, 41 | "aug_shift": 0, 42 | "geo_type": "flex" 43 | }, 44 | 45 | "ArchSpecs": { 46 | "unet_type": "diffusers", 47 | "use_3D_aware": false, 48 | "fea_concat": false, 49 | "mlp_bias": true 50 | }, 51 | 52 | "DecoderSpecs": { 53 | "c_dim": 32, 54 | "plane_resolution": 256 55 | } 56 | } 57 | 58 | -------------------------------------------------------------------------------- /configs/stage2-v2-snr.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | # others 3 | seed: 1234 4 | num_frames: 6 5 | mode: pixel 6 | offset_noise: true 7 | gd_type: xyz 8 | # model related 9 | models: 10 | config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml 11 | resume: models/xyz.pth 12 | 13 | # eval related 14 | sampler: 15 | target: libs.sample.ImageDreamDiffusionStage2 16 | 
params: 17 | mode: pixel 18 | num_frames: 6 19 | camera_views: [1, 2, 3, 4, 5, 0] 20 | ref_position: null 21 | random_background: false 22 | offset_noise: true 23 | resize_rate: 1.0 24 | 25 | 26 | -------------------------------------------------------------------------------- /configs/stage2-v2-snr_train.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | # others 3 | seed: 1234 4 | num_frames: 6 5 | mode: pixel 6 | offset_noise: true 7 | gd_type: xyz 8 | # model related 9 | models: 10 | config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml 11 | resume: release_models/ImageDream/sd-v2.1-base-4view-ipmv.pt 12 | resume_unet: null 13 | 14 | # eval related 15 | sampler: 16 | target: libs.sample.ImageDreamDiffusionStage2 17 | params: 18 | mode: pixel 19 | num_frames: 6 20 | camera_views: [1, 2, 3, 4, 5, 0] 21 | ref_position: null 22 | random_background: false 23 | offset_noise: true 24 | resize_rate: 1.0 25 | 26 | # config datasets 27 | train_data: 28 | target: libs.data.DataHQCRelative 29 | params: 30 | xyz_base: train_examples 31 | base_dir: train_examples 32 | caption_csv: train_examples/caption.csv 33 | image_size: 256 34 | repeat: 1 35 | camera_views: [1, 2, 3, 4, 5, 0] 36 | ref_indexs: [0, 1, 3, 4] 37 | ref_position: null 38 | split: train 39 | num_frames: 6 40 | random_background: true 41 | resize_rate: 0.95 42 | eval_data: 43 | target: libs.data.DataHQCRelative 44 | params: 45 | xyz_base: train_examples 46 | base_dir: train_examples 47 | caption_csv: train_examples/caption.csv 48 | image_size: 256 49 | repeat: 1 50 | camera_views: [1, 2, 3, 4, 5, 0] # when pixel mode, last image will be coverd by ref image 51 | ref_indexs: [0, 1, 3, 4] 52 | ref_position: null 53 | split: eval 54 | num_frames: 6 55 | random_background: true 56 | resize_rate: 0.95 57 | 58 | # optimizer related 59 | optimizer: 60 | lr: 5e-5 61 | gradient_accumulation_steps: 12 62 | 63 | # wandb related parameters 64 | project: CRM 65 | wandb_run_name: CRM-xyz 66 | wandb_mode: offline 67 | 68 | 69 | # training hyperparmeters 70 | batch_size: 16 71 | dataloader: 72 | num_workers: 10 73 | shuffle: true 74 | drop_last: true 75 | 76 | save_interval: 400000 77 | log_interval: 5000 78 | eval_interval: 50000 79 | max_step: 100000000 80 | -------------------------------------------------------------------------------- /examples/3D卡通狗.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/3D卡通狗.webp -------------------------------------------------------------------------------- /examples/astronaut.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/astronaut.webp -------------------------------------------------------------------------------- /examples/bulldog.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/bulldog.webp -------------------------------------------------------------------------------- /examples/ghost-eating-burger.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/ghost-eating-burger.webp 
-------------------------------------------------------------------------------- /examples/kunkun.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/kunkun.webp -------------------------------------------------------------------------------- /examples/万圣南瓜.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/万圣南瓜.webp -------------------------------------------------------------------------------- /examples/人物骑马.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/人物骑马.webp -------------------------------------------------------------------------------- /examples/初音未来玩偶.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/初音未来玩偶.webp -------------------------------------------------------------------------------- /examples/卡通恐龙.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/卡通恐龙.webp -------------------------------------------------------------------------------- /examples/卡通手枪截图.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/卡通手枪截图.webp -------------------------------------------------------------------------------- /examples/卡通猫.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/卡通猫.webp -------------------------------------------------------------------------------- /examples/卡通蘑菇套装.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/卡通蘑菇套装.webp -------------------------------------------------------------------------------- /examples/可爱玄策.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/可爱玄策.webp -------------------------------------------------------------------------------- /examples/大头泡泡马特.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/大头泡泡马特.webp -------------------------------------------------------------------------------- /examples/彩色蘑菇.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/彩色蘑菇.webp -------------------------------------------------------------------------------- /examples/彩色蘑菇2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/彩色蘑菇2.webp 
-------------------------------------------------------------------------------- /examples/恐龙套装.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/恐龙套装.webp -------------------------------------------------------------------------------- /examples/手办.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/手办.webp -------------------------------------------------------------------------------- /examples/机械狗裁切.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/机械狗裁切.webp -------------------------------------------------------------------------------- /examples/林克.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/林克.webp -------------------------------------------------------------------------------- /examples/植物1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/植物1.webp -------------------------------------------------------------------------------- /examples/武器-剑.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/武器-剑.webp -------------------------------------------------------------------------------- /examples/毛线衣.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/毛线衣.webp -------------------------------------------------------------------------------- /examples/海龟.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/海龟.webp -------------------------------------------------------------------------------- /examples/猫人.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/猫人.webp -------------------------------------------------------------------------------- /examples/猫头鹰.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/猫头鹰.webp -------------------------------------------------------------------------------- /examples/玩具兔.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/玩具兔.webp -------------------------------------------------------------------------------- /examples/玩具熊.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/玩具熊.webp -------------------------------------------------------------------------------- /examples/玩具猪.webp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/玩具猪.webp -------------------------------------------------------------------------------- /examples/玫瑰.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/玫瑰.webp -------------------------------------------------------------------------------- /examples/皮卡丘.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/皮卡丘.webp -------------------------------------------------------------------------------- /examples/皮鞋.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/皮鞋.webp -------------------------------------------------------------------------------- /examples/石头.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/石头.webp -------------------------------------------------------------------------------- /examples/石头哆啦A梦.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/石头哆啦A梦.webp -------------------------------------------------------------------------------- /examples/红玩具猪.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/红玩具猪.webp -------------------------------------------------------------------------------- /examples/翅膀道具.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/翅膀道具.webp -------------------------------------------------------------------------------- /examples/茶壶.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/茶壶.webp -------------------------------------------------------------------------------- /examples/草系精灵.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/草系精灵.webp -------------------------------------------------------------------------------- /examples/蓝色小怪物.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/蓝色小怪物.webp -------------------------------------------------------------------------------- /examples/蓝色泡泡马特.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/蓝色泡泡马特.webp -------------------------------------------------------------------------------- /examples/蓝色猫.webp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/蓝色猫.webp -------------------------------------------------------------------------------- /examples/赛博朋克-男.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/赛博朋克-男.webp -------------------------------------------------------------------------------- /examples/路灯.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/路灯.webp -------------------------------------------------------------------------------- /examples/运动系手办.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/examples/运动系手办.webp -------------------------------------------------------------------------------- /imagedream/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_zoo import build_model 2 | -------------------------------------------------------------------------------- /imagedream/camera_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def create_camera_to_world_matrix(elevation, azimuth): 6 | elevation = np.radians(elevation) 7 | azimuth = np.radians(azimuth) 8 | # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere 9 | x = np.cos(elevation) * np.sin(azimuth) 10 | y = np.sin(elevation) 11 | z = np.cos(elevation) * np.cos(azimuth) 12 | 13 | # Calculate camera position, target, and up vectors 14 | camera_pos = np.array([x, y, z]) 15 | target = np.array([0, 0, 0]) 16 | up = np.array([0, 1, 0]) 17 | 18 | # Construct view matrix 19 | forward = target - camera_pos 20 | forward /= np.linalg.norm(forward) 21 | right = np.cross(forward, up) 22 | right /= np.linalg.norm(right) 23 | new_up = np.cross(right, forward) 24 | new_up /= np.linalg.norm(new_up) 25 | cam2world = np.eye(4) 26 | cam2world[:3, :3] = np.array([right, new_up, -forward]).T 27 | cam2world[:3, 3] = camera_pos 28 | return cam2world 29 | 30 | 31 | def convert_opengl_to_blender(camera_matrix): 32 | if isinstance(camera_matrix, np.ndarray): 33 | # Construct transformation matrix to convert from OpenGL space to Blender space 34 | flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) 35 | camera_matrix_blender = np.dot(flip_yz, camera_matrix) 36 | else: 37 | # Construct transformation matrix to convert from OpenGL space to Blender space 38 | flip_yz = torch.tensor( 39 | [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]] 40 | ) 41 | if camera_matrix.ndim == 3: 42 | flip_yz = flip_yz.unsqueeze(0) 43 | camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix) 44 | return camera_matrix_blender 45 | 46 | 47 | def normalize_camera(camera_matrix): 48 | """normalize the camera location onto a unit-sphere""" 49 | if isinstance(camera_matrix, np.ndarray): 50 | camera_matrix = camera_matrix.reshape(-1, 4, 4) 51 | translation = camera_matrix[:, :3, 3] 52 | translation = translation / ( 53 | np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8 54 | ) 55 | camera_matrix[:, :3, 3] = translation 56 | else: 57 | camera_matrix = camera_matrix.reshape(-1, 4, 4) 58 | translation = 
camera_matrix[:, :3, 3] 59 | translation = translation / ( 60 | torch.norm(translation, dim=1, keepdim=True) + 1e-8 61 | ) 62 | camera_matrix[:, :3, 3] = translation 63 | return camera_matrix.reshape(-1, 16) 64 | 65 | 66 | def get_camera( 67 | num_frames, 68 | elevation=15, 69 | azimuth_start=0, 70 | azimuth_span=360, 71 | blender_coord=True, 72 | extra_view=False, 73 | ): 74 | angle_gap = azimuth_span / num_frames 75 | cameras = [] 76 | for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): 77 | camera_matrix = create_camera_to_world_matrix(elevation, azimuth) 78 | if blender_coord: 79 | camera_matrix = convert_opengl_to_blender(camera_matrix) 80 | cameras.append(camera_matrix.flatten()) 81 | 82 | if extra_view: 83 | dim = len(cameras[0]) 84 | cameras.append(np.zeros(dim)) 85 | return torch.tensor(np.stack(cameras, 0)).float() 86 | 87 | 88 | def get_camera_for_index(data_index): 89 | """ 90 | Following our current data format, with view 000 facing the viewer: 91 | 000 is the front view, ev: 0, azimuth: 0 92 | 001 is the left view, ev: 0, azimuth: -90 93 | 002 is the bottom view, ev: -90, azimuth: 0 94 | 003 is the back view, ev: 0, azimuth: 180 95 | 004 is the right view, ev: 0, azimuth: 90 96 | 005 is the top view, ev: 90, azimuth: 0 97 | """ 98 | params = [(0, 0), (0, -90), (-90, 0), (0, 180), (0, 90), (90, 0)] 99 | return get_camera(1, *params[data_index]) -------------------------------------------------------------------------------- /imagedream/configs/sd_v2_base_ipmv.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: imagedream.ldm.interface.LatentDiffusionInterface 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | parameterization: "eps" 9 | 10 | unet_config: 11 | target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel 12 | params: 13 | image_size: 32 # unused 14 | in_channels: 4 15 | out_channels: 4 16 | model_channels: 320 17 | attention_resolutions: [ 4, 2, 1 ] 18 | num_res_blocks: 2 19 | channel_mult: [ 1, 2, 4, 4 ] 20 | num_head_channels: 64 # need to fix for flash-attn 21 | use_spatial_transformer: True 22 | use_linear_in_transformer: True 23 | transformer_depth: 1 24 | context_dim: 1024 25 | use_checkpoint: False 26 | legacy: False 27 | camera_dim: 16 28 | with_ip: True 29 | ip_dim: 16 # ip token length 30 | ip_mode: "local_resample" 31 | 32 | vae_config: 33 | target: imagedream.ldm.models.autoencoder.AutoencoderKL 34 | params: 35 | embed_dim: 4 36 | monitor: val/rec_loss 37 | ddconfig: 38 | #attn_type: "vanilla-xformers" 39 | double_z: true 40 | z_channels: 4 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | 56 | clip_config: 57 | target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 58 | params: 59 | freeze: True 60 | layer: "penultimate" 61 | ip_mode: "local_resample" 62 | -------------------------------------------------------------------------------- /imagedream/configs/sd_v2_base_ipmv_ch8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: imagedream.ldm.interface.LatentDiffusionInterface 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | parameterization: "eps" 9 | 10 | unet_config: 11 | target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel 12 | params: 13 | image_size: 32 # unused
14 | in_channels: 8 15 | out_channels: 8 16 | model_channels: 320 17 | attention_resolutions: [ 4, 2, 1 ] 18 | num_res_blocks: 2 19 | channel_mult: [ 1, 2, 4, 4 ] 20 | num_head_channels: 64 # need to fix for flash-attn 21 | use_spatial_transformer: True 22 | use_linear_in_transformer: True 23 | transformer_depth: 1 24 | context_dim: 1024 25 | use_checkpoint: False 26 | legacy: False 27 | camera_dim: 16 28 | with_ip: True 29 | ip_dim: 16 # ip token length 30 | ip_mode: "local_resample" 31 | 32 | vae_config: 33 | target: imagedream.ldm.models.autoencoder.AutoencoderKL 34 | params: 35 | embed_dim: 4 36 | monitor: val/rec_loss 37 | ddconfig: 38 | #attn_type: "vanilla-xformers" 39 | double_z: true 40 | z_channels: 4 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | 56 | clip_config: 57 | target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 58 | params: 59 | freeze: True 60 | layer: "penultimate" 61 | ip_mode: "local_resample" 62 | -------------------------------------------------------------------------------- /imagedream/configs/sd_v2_base_ipmv_chin8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: imagedream.ldm.interface.LatentDiffusionInterface 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | parameterization: "eps" 9 | 10 | unet_config: 11 | target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2 12 | params: 13 | image_size: 32 # unused 14 | in_channels: 8 15 | out_channels: 4 16 | model_channels: 320 17 | attention_resolutions: [ 4, 2, 1 ] 18 | num_res_blocks: 2 19 | channel_mult: [ 1, 2, 4, 4 ] 20 | num_head_channels: 64 # need to fix for flash-attn 21 | use_spatial_transformer: True 22 | use_linear_in_transformer: True 23 | transformer_depth: 1 24 | context_dim: 1024 25 | use_checkpoint: False 26 | legacy: False 27 | camera_dim: 16 28 | with_ip: True 29 | ip_dim: 16 # ip token length 30 | ip_mode: "local_resample" 31 | 32 | vae_config: 33 | target: imagedream.ldm.models.autoencoder.AutoencoderKL 34 | params: 35 | embed_dim: 4 36 | monitor: val/rec_loss 37 | ddconfig: 38 | #attn_type: "vanilla-xformers" 39 | double_z: true 40 | z_channels: 4 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | 56 | clip_config: 57 | target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 58 | params: 59 | freeze: True 60 | layer: "penultimate" 61 | ip_mode: "local_resample" 62 | -------------------------------------------------------------------------------- /imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: imagedream.ldm.interface.LatentDiffusionInterface 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | parameterization: "eps" 9 | zero_snr: true 10 | 11 | unet_config: 12 | target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2 13 | params: 14 | image_size: 32 # unused 15 | in_channels: 8 16 | out_channels: 4 17 | model_channels: 320 18 | 
attention_resolutions: [ 4, 2, 1 ] 19 | num_res_blocks: 2 20 | channel_mult: [ 1, 2, 4, 4 ] 21 | num_head_channels: 64 # need to fix for flash-attn 22 | use_spatial_transformer: True 23 | use_linear_in_transformer: True 24 | transformer_depth: 1 25 | context_dim: 1024 26 | use_checkpoint: False 27 | legacy: False 28 | camera_dim: 16 29 | with_ip: True 30 | ip_dim: 16 # ip token length 31 | ip_mode: "local_resample" 32 | 33 | vae_config: 34 | target: imagedream.ldm.models.autoencoder.AutoencoderKL 35 | params: 36 | embed_dim: 4 37 | monitor: val/rec_loss 38 | ddconfig: 39 | #attn_type: "vanilla-xformers" 40 | double_z: true 41 | z_channels: 4 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | 57 | clip_config: 58 | target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 59 | params: 60 | freeze: True 61 | layer: "penultimate" 62 | ip_mode: "local_resample" 63 | -------------------------------------------------------------------------------- /imagedream/configs/sd_v2_base_ipmv_local.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: imagedream.ldm.interface.LatentDiffusionInterface 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | parameterization: "eps" 9 | 10 | unet_config: 11 | target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel 12 | params: 13 | image_size: 32 # unused 14 | in_channels: 4 15 | out_channels: 4 16 | model_channels: 320 17 | attention_resolutions: [ 4, 2, 1 ] 18 | num_res_blocks: 2 19 | channel_mult: [ 1, 2, 4, 4 ] 20 | num_head_channels: 64 # need to fix for flash-attn 21 | use_spatial_transformer: True 22 | use_linear_in_transformer: True 23 | transformer_depth: 1 24 | context_dim: 1024 25 | use_checkpoint: False 26 | legacy: False 27 | camera_dim: 16 28 | with_ip: True 29 | ip_dim: 16 # ip token length 30 | ip_mode: "local_resample" 31 | ip_weight: 1.0 # adjust for similarity to image 32 | 33 | vae_config: 34 | target: imagedream.ldm.models.autoencoder.AutoencoderKL 35 | params: 36 | embed_dim: 4 37 | monitor: val/rec_loss 38 | ddconfig: 39 | #attn_type: "vanilla-xformers" 40 | double_z: true 41 | z_channels: 4 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | 57 | clip_config: 58 | target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 59 | params: 60 | freeze: True 61 | layer: "penultimate" 62 | ip_mode: "local_resample" 63 | -------------------------------------------------------------------------------- /imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: imagedream.ldm.interface.LatentDiffusionInterface 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | parameterization: "eps" 9 | zero_snr: true 10 | 11 | unet_config: 12 | target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel 13 | params: 14 | image_size: 32 # unused 15 | in_channels: 4 16 | out_channels: 4 17 | model_channels: 320 18 | attention_resolutions: [ 4, 2, 1 ] 19 | 
num_res_blocks: 2 20 | channel_mult: [ 1, 2, 4, 4 ] 21 | num_head_channels: 64 # need to fix for flash-attn 22 | use_spatial_transformer: True 23 | use_linear_in_transformer: True 24 | transformer_depth: 1 25 | context_dim: 1024 26 | use_checkpoint: False 27 | legacy: False 28 | camera_dim: 16 29 | with_ip: True 30 | ip_dim: 16 # ip token length 31 | ip_mode: "local_resample" 32 | 33 | vae_config: 34 | target: imagedream.ldm.models.autoencoder.AutoencoderKL 35 | params: 36 | embed_dim: 4 37 | monitor: val/rec_loss 38 | ddconfig: 39 | #attn_type: "vanilla-xformers" 40 | double_z: true 41 | z_channels: 4 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | 57 | clip_config: 58 | target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 59 | params: 60 | freeze: True 61 | layer: "penultimate" 62 | ip_mode: "local_resample" 63 | -------------------------------------------------------------------------------- /imagedream/ldm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/imagedream/ldm/__init__.py -------------------------------------------------------------------------------- /imagedream/ldm/interface.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .modules.diffusionmodules.util import ( 9 | make_beta_schedule, 10 | extract_into_tensor, 11 | enforce_zero_terminal_snr, 12 | noise_like, 13 | ) 14 | from .util import exists, default, instantiate_from_config 15 | from .modules.distributions.distributions import DiagonalGaussianDistribution 16 | 17 | 18 | class DiffusionWrapper(nn.Module): 19 | def __init__(self, diffusion_model): 20 | super().__init__() 21 | self.diffusion_model = diffusion_model 22 | 23 | def forward(self, *args, **kwargs): 24 | return self.diffusion_model(*args, **kwargs) 25 | 26 | 27 | class LatentDiffusionInterface(nn.Module): 28 | """a simple interface class for LDM inference""" 29 | 30 | def __init__( 31 | self, 32 | unet_config, 33 | clip_config, 34 | vae_config, 35 | parameterization="eps", 36 | scale_factor=0.18215, 37 | beta_schedule="linear", 38 | timesteps=1000, 39 | linear_start=0.00085, 40 | linear_end=0.0120, 41 | cosine_s=8e-3, 42 | given_betas=None, 43 | zero_snr=False, 44 | *args, 45 | **kwargs, 46 | ): 47 | super().__init__() 48 | 49 | unet = instantiate_from_config(unet_config) 50 | self.model = DiffusionWrapper(unet) 51 | self.clip_model = instantiate_from_config(clip_config) 52 | self.vae_model = instantiate_from_config(vae_config) 53 | 54 | self.parameterization = parameterization 55 | self.scale_factor = scale_factor 56 | self.register_schedule( 57 | given_betas=given_betas, 58 | beta_schedule=beta_schedule, 59 | timesteps=timesteps, 60 | linear_start=linear_start, 61 | linear_end=linear_end, 62 | cosine_s=cosine_s, 63 | zero_snr=zero_snr 64 | ) 65 | 66 | def register_schedule( 67 | self, 68 | given_betas=None, 69 | beta_schedule="linear", 70 | timesteps=1000, 71 | linear_start=1e-4, 72 | linear_end=2e-2, 73 | cosine_s=8e-3, 74 | zero_snr=False 75 | ): 76 | if exists(given_betas): 77 | betas = given_betas 78 | else: 79 | betas = make_beta_schedule( 80 | 
beta_schedule, 81 | timesteps, 82 | linear_start=linear_start, 83 | linear_end=linear_end, 84 | cosine_s=cosine_s, 85 | ) 86 | if zero_snr: 87 | print("--- using zero snr---") 88 | betas = enforce_zero_terminal_snr(betas).numpy() 89 | alphas = 1.0 - betas 90 | alphas_cumprod = np.cumprod(alphas, axis=0) 91 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 92 | 93 | (timesteps,) = betas.shape 94 | self.num_timesteps = int(timesteps) 95 | self.linear_start = linear_start 96 | self.linear_end = linear_end 97 | assert ( 98 | alphas_cumprod.shape[0] == self.num_timesteps 99 | ), "alphas have to be defined for each timestep" 100 | 101 | to_torch = partial(torch.tensor, dtype=torch.float32) 102 | 103 | self.register_buffer("betas", to_torch(betas)) 104 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 105 | self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) 106 | 107 | # calculations for diffusion q(x_t | x_{t-1}) and others 108 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) 109 | self.register_buffer( 110 | "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) 111 | ) 112 | self.register_buffer( 113 | "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) 114 | ) 115 | eps = 1e-8 # adding small epsilon value to avoid devide by zero error 116 | self.register_buffer( 117 | "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps))) 118 | ) 119 | self.register_buffer( 120 | "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps) - 1)) 121 | ) 122 | 123 | # calculations for posterior q(x_{t-1} | x_t, x_0) 124 | self.v_posterior = 0 125 | posterior_variance = (1 - self.v_posterior) * betas * ( 126 | 1.0 - alphas_cumprod_prev 127 | ) / (1.0 - alphas_cumprod) + self.v_posterior * betas 128 | # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) 129 | self.register_buffer("posterior_variance", to_torch(posterior_variance)) 130 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 131 | self.register_buffer( 132 | "posterior_log_variance_clipped", 133 | to_torch(np.log(np.maximum(posterior_variance, 1e-20))), 134 | ) 135 | self.register_buffer( 136 | "posterior_mean_coef1", 137 | to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), 138 | ) 139 | self.register_buffer( 140 | "posterior_mean_coef2", 141 | to_torch( 142 | (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) 143 | ), 144 | ) 145 | 146 | def q_sample(self, x_start, t, noise=None): 147 | noise = default(noise, lambda: torch.randn_like(x_start)) 148 | return ( 149 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 150 | + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 151 | * noise 152 | ) 153 | 154 | def get_v(self, x, noise, t): 155 | return ( 156 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise 157 | - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x 158 | ) 159 | 160 | def predict_start_from_noise(self, x_t, t, noise): 161 | return ( 162 | extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 163 | - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 164 | * noise 165 | ) 166 | 167 | def predict_start_from_z_and_v(self, x_t, t, v): 168 | return ( 169 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t 170 | - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v 171 | ) 172 | 173 | def predict_eps_from_z_and_v(self, x_t, t, v): 174 | return ( 175 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v 176 | + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) 177 | * x_t 178 | ) 179 | 180 | def apply_model(self, x_noisy, t, cond, **kwargs): 181 | assert isinstance(cond, dict), "cond has to be a dictionary" 182 | return self.model(x_noisy, t, **cond, **kwargs) 183 | 184 | def get_learned_conditioning(self, prompts: List[str]): 185 | return self.clip_model(prompts) 186 | 187 | def get_learned_image_conditioning(self, images): 188 | return self.clip_model.forward_image(images) 189 | 190 | def get_first_stage_encoding(self, encoder_posterior): 191 | if isinstance(encoder_posterior, DiagonalGaussianDistribution): 192 | z = encoder_posterior.sample() 193 | elif isinstance(encoder_posterior, torch.Tensor): 194 | z = encoder_posterior 195 | else: 196 | raise NotImplementedError( 197 | f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" 198 | ) 199 | return self.scale_factor * z 200 | 201 | def encode_first_stage(self, x): 202 | return self.vae_model.encode(x) 203 | 204 | def decode_first_stage(self, z): 205 | z = 1.0 / self.scale_factor * z 206 | return self.vae_model.decode(z) 207 | -------------------------------------------------------------------------------- /imagedream/ldm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/imagedream/ldm/models/__init__.py -------------------------------------------------------------------------------- /imagedream/ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn.functional as F 3 | from contextlib import contextmanager 4 | 5 | from ..modules.diffusionmodules.model import Encoder, Decoder 6 | from ..modules.distributions.distributions import DiagonalGaussianDistribution 7 | 8 | from ..util import instantiate_from_config 9 | from ..modules.ema import LitEma 10 | 11 | 12 | class AutoencoderKL(torch.nn.Module): 13 | def __init__( 14 | self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False, 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels) == int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0.0 < ema_decay < 1.0 46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = 
self.loss( 116 | inputs, 117 | reconstructions, 118 | posterior, 119 | optimizer_idx, 120 | self.global_step, 121 | last_layer=self.get_last_layer(), 122 | split="train", 123 | ) 124 | self.log( 125 | "aeloss", 126 | aeloss, 127 | prog_bar=True, 128 | logger=True, 129 | on_step=True, 130 | on_epoch=True, 131 | ) 132 | self.log_dict( 133 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False 134 | ) 135 | return aeloss 136 | 137 | if optimizer_idx == 1: 138 | # train the discriminator 139 | discloss, log_dict_disc = self.loss( 140 | inputs, 141 | reconstructions, 142 | posterior, 143 | optimizer_idx, 144 | self.global_step, 145 | last_layer=self.get_last_layer(), 146 | split="train", 147 | ) 148 | 149 | self.log( 150 | "discloss", 151 | discloss, 152 | prog_bar=True, 153 | logger=True, 154 | on_step=True, 155 | on_epoch=True, 156 | ) 157 | self.log_dict( 158 | log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False 159 | ) 160 | return discloss 161 | 162 | def validation_step(self, batch, batch_idx): 163 | log_dict = self._validation_step(batch, batch_idx) 164 | with self.ema_scope(): 165 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 166 | return log_dict 167 | 168 | def _validation_step(self, batch, batch_idx, postfix=""): 169 | inputs = self.get_input(batch, self.image_key) 170 | reconstructions, posterior = self(inputs) 171 | aeloss, log_dict_ae = self.loss( 172 | inputs, 173 | reconstructions, 174 | posterior, 175 | 0, 176 | self.global_step, 177 | last_layer=self.get_last_layer(), 178 | split="val" + postfix, 179 | ) 180 | 181 | discloss, log_dict_disc = self.loss( 182 | inputs, 183 | reconstructions, 184 | posterior, 185 | 1, 186 | self.global_step, 187 | last_layer=self.get_last_layer(), 188 | split="val" + postfix, 189 | ) 190 | 191 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 192 | self.log_dict(log_dict_ae) 193 | self.log_dict(log_dict_disc) 194 | return self.log_dict 195 | 196 | def configure_optimizers(self): 197 | lr = self.learning_rate 198 | ae_params_list = ( 199 | list(self.encoder.parameters()) 200 | + list(self.decoder.parameters()) 201 | + list(self.quant_conv.parameters()) 202 | + list(self.post_quant_conv.parameters()) 203 | ) 204 | if self.learn_logvar: 205 | print(f"{self.__class__.__name__}: Learning logvar") 206 | ae_params_list.append(self.loss.logvar) 207 | opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) 208 | opt_disc = torch.optim.Adam( 209 | self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) 210 | ) 211 | return [opt_ae, opt_disc], [] 212 | 213 | def get_last_layer(self): 214 | return self.decoder.conv_out.weight 215 | 216 | @torch.no_grad() 217 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 218 | log = dict() 219 | x = self.get_input(batch, self.image_key) 220 | x = x.to(self.device) 221 | if not only_inputs: 222 | xrec, posterior = self(x) 223 | if x.shape[1] > 3: 224 | # colorize with random projection 225 | assert xrec.shape[1] > 3 226 | x = self.to_rgb(x) 227 | xrec = self.to_rgb(xrec) 228 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 229 | log["reconstructions"] = xrec 230 | if log_ema or self.use_ema: 231 | with self.ema_scope(): 232 | xrec_ema, posterior_ema = self(x) 233 | if x.shape[1] > 3: 234 | # colorize with random projection 235 | assert xrec_ema.shape[1] > 3 236 | xrec_ema = self.to_rgb(xrec_ema) 237 | log["samples_ema"] = self.decode( 238 | 
torch.randn_like(posterior_ema.sample()) 239 | ) 240 | log["reconstructions_ema"] = xrec_ema 241 | log["inputs"] = x 242 | return log 243 | 244 | def to_rgb(self, x): 245 | assert self.image_key == "segmentation" 246 | if not hasattr(self, "colorize"): 247 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 248 | x = F.conv2d(x, weight=self.colorize) 249 | x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 250 | return x 251 | 252 | 253 | class IdentityFirstStage(torch.nn.Module): 254 | def __init__(self, *args, vq_interface=False, **kwargs): 255 | self.vq_interface = vq_interface 256 | super().__init__() 257 | 258 | def encode(self, x, *args, **kwargs): 259 | return x 260 | 261 | def decode(self, x, *args, **kwargs): 262 | return x 263 | 264 | def quantize(self, x, *args, **kwargs): 265 | if self.vq_interface: 266 | return x, None, [None, None, None] 267 | return x 268 | 269 | def forward(self, x, *args, **kwargs): 270 | return x 271 | -------------------------------------------------------------------------------- /imagedream/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/imagedream/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /imagedream/ldm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/imagedream/ldm/modules/__init__.py -------------------------------------------------------------------------------- /imagedream/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/imagedream/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /imagedream/ldm/modules/diffusionmodules/adaptors.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | # FFN 9 | def FeedForward(dim, mult=4): 10 | inner_dim = int(dim * mult) 11 | return nn.Sequential( 12 | nn.LayerNorm(dim), 13 | nn.Linear(dim, inner_dim, bias=False), 14 | nn.GELU(), 15 | nn.Linear(inner_dim, dim, bias=False), 16 | ) 17 | 18 | 19 | def reshape_tensor(x, heads): 20 | bs, length, width = x.shape 21 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 22 | x = x.view(bs, length, heads, -1) 23 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 24 | x = x.transpose(1, 2) 25 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 26 | x = x.reshape(bs, heads, length, -1) 27 | return x 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.dim_head = dim_head 35 | self.heads = heads 36 | inner_dim = dim_head * heads 37 | 38 | self.norm1 = nn.LayerNorm(dim) 39 | self.norm2 = nn.LayerNorm(dim) 40 | 41 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 42 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 43 | self.to_out = nn.Linear(inner_dim, dim, 
bias=False) 44 | 45 | 46 | def forward(self, x, latents): 47 | """ 48 | Args: 49 | x (torch.Tensor): image features 50 | shape (b, n1, D) 51 | latent (torch.Tensor): latent features 52 | shape (b, n2, D) 53 | """ 54 | x = self.norm1(x) 55 | latents = self.norm2(latents) 56 | 57 | b, l, _ = latents.shape 58 | 59 | q = self.to_q(latents) 60 | kv_input = torch.cat((x, latents), dim=-2) 61 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 62 | 63 | q = reshape_tensor(q, self.heads) 64 | k = reshape_tensor(k, self.heads) 65 | v = reshape_tensor(v, self.heads) 66 | 67 | # attention 68 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 69 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 70 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 71 | out = weight @ v 72 | 73 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 74 | 75 | return self.to_out(out) 76 | 77 | 78 | class ImageProjModel(torch.nn.Module): 79 | """Projection Model""" 80 | def __init__(self, 81 | cross_attention_dim=1024, 82 | clip_embeddings_dim=1024, 83 | clip_extra_context_tokens=4): 84 | super().__init__() 85 | self.cross_attention_dim = cross_attention_dim 86 | self.clip_extra_context_tokens = clip_extra_context_tokens 87 | 88 | # from 1024 -> 4 * 1024 89 | self.proj = torch.nn.Linear( 90 | clip_embeddings_dim, 91 | self.clip_extra_context_tokens * cross_attention_dim) 92 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 93 | 94 | def forward(self, image_embeds): 95 | embeds = image_embeds 96 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 97 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 98 | return clip_extra_context_tokens 99 | 100 | 101 | class SimpleReSampler(nn.Module): 102 | def __init__(self, embedding_dim=1280, output_dim=1024): 103 | super().__init__() 104 | self.proj_out = nn.Linear(embedding_dim, output_dim) 105 | self.norm_out = nn.LayerNorm(output_dim) 106 | 107 | def forward(self, latents): 108 | """ 109 | latents: B 256 N 110 | """ 111 | latents = self.proj_out(latents) 112 | return self.norm_out(latents) 113 | 114 | 115 | class Resampler(nn.Module): 116 | def __init__( 117 | self, 118 | dim=1024, 119 | depth=8, 120 | dim_head=64, 121 | heads=16, 122 | num_queries=8, 123 | embedding_dim=768, 124 | output_dim=1024, 125 | ff_mult=4, 126 | ): 127 | super().__init__() 128 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 129 | self.proj_in = nn.Linear(embedding_dim, dim) 130 | self.proj_out = nn.Linear(dim, output_dim) 131 | self.norm_out = nn.LayerNorm(output_dim) 132 | 133 | self.layers = nn.ModuleList([]) 134 | for _ in range(depth): 135 | self.layers.append( 136 | nn.ModuleList( 137 | [ 138 | PerceiverAttention(dim=dim, 139 | dim_head=dim_head, 140 | heads=heads), 141 | FeedForward(dim=dim, mult=ff_mult), 142 | ] 143 | ) 144 | ) 145 | 146 | def forward(self, x): 147 | latents = self.latents.repeat(x.size(0), 1, 1) 148 | x = self.proj_in(x) 149 | for attn, ff in self.layers: 150 | latents = attn(x, latents) + latents 151 | latents = ff(latents) + latents 152 | 153 | latents = self.proj_out(latents) 154 | return self.norm_out(latents) 155 | 156 | 157 | if __name__ == '__main__': 158 | resampler = Resampler(embedding_dim=1280) 159 | resampler = SimpleReSampler(embedding_dim=1280) 160 | tensor = torch.rand(4, 257, 1280) 161 | embed = resampler(tensor) 162 | # embed = (tensor) 163 | print(embed.shape) 164 | 
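A minimal usage sketch (not part of the repository) of the Resampler defined above: it compresses variable-length CLIP ViT-H/14 patch tokens of shape [B, 257, 1280] into a fixed number of image-prompt tokens. The num_queries=16 and output_dim=1024 values are assumptions chosen to line up with the ip_dim: 16 and context_dim: 1024 entries in the configs; everything else uses the class defaults, and the Resampler class is assumed to be in scope (e.g. imported from this file).

import torch

resampler = Resampler(
    dim=1024,            # width of the perceiver blocks (class default)
    depth=8,
    dim_head=64,
    heads=16,
    num_queries=16,      # assumed to match ip_dim in the UNet config
    embedding_dim=1280,  # CLIP ViT-H/14 patch-token width
    output_dim=1024,     # assumed to match the UNet context_dim
)
clip_patch_tokens = torch.rand(2, 257, 1280)   # dummy [B, 1 cls + 256 patches, 1280] features
ip_tokens = resampler(clip_patch_tokens)       # learned queries attend to the patch tokens
print(ip_tokens.shape)                         # torch.Size([2, 16, 1024])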
-------------------------------------------------------------------------------- /imagedream/ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | import importlib 18 | 19 | 20 | def instantiate_from_config(config): 21 | if not "target" in config: 22 | if config == "__is_first_stage__": 23 | return None 24 | elif config == "__is_unconditional__": 25 | return None 26 | raise KeyError("Expected key `target` to instantiate.") 27 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 28 | 29 | 30 | def get_obj_from_str(string, reload=False): 31 | module, cls = string.rsplit(".", 1) 32 | if reload: 33 | module_imp = importlib.import_module(module) 34 | importlib.reload(module_imp) 35 | return getattr(importlib.import_module(module, package=None), cls) 36 | 37 | 38 | def make_beta_schedule( 39 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 40 | ): 41 | if schedule == "linear": 42 | betas = ( 43 | torch.linspace( 44 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 45 | ) 46 | ** 2 47 | ) 48 | 49 | elif schedule == "cosine": 50 | timesteps = ( 51 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 52 | ) 53 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 54 | alphas = torch.cos(alphas).pow(2) 55 | alphas = alphas / alphas[0] 56 | betas = 1 - alphas[1:] / alphas[:-1] 57 | betas = np.clip(betas, a_min=0, a_max=0.999) 58 | 59 | elif schedule == "sqrt_linear": 60 | betas = torch.linspace( 61 | linear_start, linear_end, n_timestep, dtype=torch.float64 62 | ) 63 | elif schedule == "sqrt": 64 | betas = ( 65 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 66 | ** 0.5 67 | ) 68 | else: 69 | raise ValueError(f"schedule '{schedule}' unknown.") 70 | return betas.numpy() 71 | 72 | def enforce_zero_terminal_snr(betas): 73 | betas = torch.tensor(betas) if not isinstance(betas, torch.Tensor) else betas 74 | # Convert betas to alphas_bar_sqrt 75 | alphas =1 - betas 76 | alphas_bar = alphas.cumprod(0) 77 | alphas_bar_sqrt = alphas_bar.sqrt() 78 | # Store old values. 79 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() 80 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() 81 | # Shift so last timestep is zero. 82 | alphas_bar_sqrt -= alphas_bar_sqrt_T 83 | # Scale so first timestep is back to old value. 
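# (annotation, not in the original source) After the shift above, alphas_bar_sqrt runs from
# (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) down to exactly 0; the rescale below multiplies by
# alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T), which restores the first entry
# to its original value while keeping the terminal entry at 0 (zero terminal SNR, as proposed
# in "Common Diffusion Noise Schedules and Sample Steps are Flawed").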
84 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 85 | # Convert alphas_bar_sqrt to betas 86 | alphas_bar = alphas_bar_sqrt ** 2 87 | alphas = alphas_bar[1:] / alphas_bar[:-1] 88 | alphas = torch.cat ([alphas_bar[0:1], alphas]) 89 | betas = 1 - alphas 90 | return betas 91 | 92 | 93 | def make_ddim_timesteps( 94 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True 95 | ): 96 | if ddim_discr_method == "uniform": 97 | c = num_ddpm_timesteps // num_ddim_timesteps 98 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 99 | elif ddim_discr_method == "quad": 100 | ddim_timesteps = ( 101 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 102 | ).astype(int) 103 | else: 104 | raise NotImplementedError( 105 | f'There is no ddim discretization method called "{ddim_discr_method}"' 106 | ) 107 | 108 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 109 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 110 | steps_out = ddim_timesteps + 1 111 | if verbose: 112 | print(f"Selected timesteps for ddim sampler: {steps_out}") 113 | return steps_out 114 | 115 | 116 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 117 | # select alphas for computing the variance schedule 118 | alphas = alphacums[ddim_timesteps] 119 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 120 | 121 | # according the the formula provided in https://arxiv.org/abs/2010.02502 122 | sigmas = eta * np.sqrt( 123 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) 124 | ) 125 | if verbose: 126 | print( 127 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" 128 | ) 129 | print( 130 | f"For the chosen value of eta, which is {eta}, " 131 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}" 132 | ) 133 | return sigmas, alphas, alphas_prev 134 | 135 | 136 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 137 | """ 138 | Create a beta schedule that discretizes the given alpha_t_bar function, 139 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 140 | :param num_diffusion_timesteps: the number of betas to produce. 141 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 142 | produces the cumulative product of (1-beta) up to that 143 | part of the diffusion process. 144 | :param max_beta: the maximum beta to use; use values lower than 1 to 145 | prevent singularities. 146 | """ 147 | betas = [] 148 | for i in range(num_diffusion_timesteps): 149 | t1 = i / num_diffusion_timesteps 150 | t2 = (i + 1) / num_diffusion_timesteps 151 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 152 | return np.array(betas) 153 | 154 | 155 | def extract_into_tensor(a, t, x_shape): 156 | b, *_ = t.shape 157 | out = a.gather(-1, t) 158 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 159 | 160 | 161 | def checkpoint(func, inputs, params, flag): 162 | """ 163 | Evaluate a function without caching intermediate activations, allowing for 164 | reduced memory at the expense of extra compute in the backward pass. 165 | :param func: the function to evaluate. 166 | :param inputs: the argument sequence to pass to `func`. 167 | :param params: a sequence of parameters `func` depends on but does not 168 | explicitly take as arguments. 169 | :param flag: if False, disable gradient checkpointing. 
170 | """ 171 | if flag: 172 | args = tuple(inputs) + tuple(params) 173 | return CheckpointFunction.apply(func, len(inputs), *args) 174 | else: 175 | return func(*inputs) 176 | 177 | 178 | class CheckpointFunction(torch.autograd.Function): 179 | @staticmethod 180 | def forward(ctx, run_function, length, *args): 181 | ctx.run_function = run_function 182 | ctx.input_tensors = list(args[:length]) 183 | ctx.input_params = list(args[length:]) 184 | 185 | with torch.no_grad(): 186 | output_tensors = ctx.run_function(*ctx.input_tensors) 187 | return output_tensors 188 | 189 | @staticmethod 190 | def backward(ctx, *output_grads): 191 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 192 | with torch.enable_grad(): 193 | # Fixes a bug where the first op in run_function modifies the 194 | # Tensor storage in place, which is not allowed for detach()'d 195 | # Tensors. 196 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 197 | output_tensors = ctx.run_function(*shallow_copies) 198 | input_grads = torch.autograd.grad( 199 | output_tensors, 200 | ctx.input_tensors + ctx.input_params, 201 | output_grads, 202 | allow_unused=True, 203 | ) 204 | del ctx.input_tensors 205 | del ctx.input_params 206 | del output_tensors 207 | return (None, None) + input_grads 208 | 209 | 210 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 211 | """ 212 | Create sinusoidal timestep embeddings. 213 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 214 | These may be fractional. 215 | :param dim: the dimension of the output. 216 | :param max_period: controls the minimum frequency of the embeddings. 217 | :return: an [N x dim] Tensor of positional embeddings. 218 | """ 219 | if not repeat_only: 220 | half = dim // 2 221 | freqs = torch.exp( 222 | -math.log(max_period) 223 | * torch.arange(start=0, end=half, dtype=torch.float32) 224 | / half 225 | ).to(device=timesteps.device) 226 | args = timesteps[:, None].float() * freqs[None] 227 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 228 | if dim % 2: 229 | embedding = torch.cat( 230 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 231 | ) 232 | else: 233 | embedding = repeat(timesteps, "b -> b d", d=dim) 234 | # import pdb; pdb.set_trace() 235 | return embedding 236 | 237 | 238 | def zero_module(module): 239 | """ 240 | Zero out the parameters of a module and return it. 241 | """ 242 | for p in module.parameters(): 243 | p.detach().zero_() 244 | return module 245 | 246 | 247 | def scale_module(module, scale): 248 | """ 249 | Scale the parameters of a module and return it. 250 | """ 251 | for p in module.parameters(): 252 | p.detach().mul_(scale) 253 | return module 254 | 255 | 256 | def mean_flat(tensor): 257 | """ 258 | Take the mean over all non-batch dimensions. 259 | """ 260 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 261 | 262 | 263 | def normalization(channels): 264 | """ 265 | Make a standard normalization layer. 266 | :param channels: number of input channels. 267 | :return: an nn.Module for normalization. 268 | """ 269 | return GroupNorm32(32, channels) 270 | 271 | 272 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
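# (annotation, not in the original source) SiLU below is the manual x * torch.sigmoid(x)
# fallback referred to above, and GroupNorm32 runs the normalization in float32 before casting
# back to the input dtype, which keeps group norm numerically stable under fp16/mixed precision.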
273 | class SiLU(nn.Module): 274 | def forward(self, x): 275 | return x * torch.sigmoid(x) 276 | 277 | 278 | class GroupNorm32(nn.GroupNorm): 279 | def forward(self, x): 280 | return super().forward(x.float()).type(x.dtype) 281 | 282 | 283 | def conv_nd(dims, *args, **kwargs): 284 | """ 285 | Create a 1D, 2D, or 3D convolution module. 286 | """ 287 | if dims == 1: 288 | return nn.Conv1d(*args, **kwargs) 289 | elif dims == 2: 290 | return nn.Conv2d(*args, **kwargs) 291 | elif dims == 3: 292 | return nn.Conv3d(*args, **kwargs) 293 | raise ValueError(f"unsupported dimensions: {dims}") 294 | 295 | 296 | def linear(*args, **kwargs): 297 | """ 298 | Create a linear module. 299 | """ 300 | return nn.Linear(*args, **kwargs) 301 | 302 | 303 | def avg_pool_nd(dims, *args, **kwargs): 304 | """ 305 | Create a 1D, 2D, or 3D average pooling module. 306 | """ 307 | if dims == 1: 308 | return nn.AvgPool1d(*args, **kwargs) 309 | elif dims == 2: 310 | return nn.AvgPool2d(*args, **kwargs) 311 | elif dims == 3: 312 | return nn.AvgPool3d(*args, **kwargs) 313 | raise ValueError(f"unsupported dimensions: {dims}") 314 | 315 | 316 | class HybridConditioner(nn.Module): 317 | def __init__(self, c_concat_config, c_crossattn_config): 318 | super().__init__() 319 | self.concat_conditioner = instantiate_from_config(c_concat_config) 320 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 321 | 322 | def forward(self, c_concat, c_crossattn): 323 | c_concat = self.concat_conditioner(c_concat) 324 | c_crossattn = self.crossattn_conditioner(c_crossattn) 325 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} 326 | 327 | 328 | def noise_like(shape, device, repeat=False): 329 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( 330 | shape[0], *((1,) * (len(shape) - 1)) 331 | ) 332 | noise = lambda: torch.randn(shape, device=device) 333 | return repeat_noise() if repeat else noise() 334 | 335 | 336 | # dummy replace 337 | def convert_module_to_f16(l): 338 | """ 339 | Convert primitive modules to float16. 340 | """ 341 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 342 | l.weight.data = l.weight.data.half() 343 | if l.bias is not None: 344 | l.bias.data = l.bias.data.half() 345 | 346 | def convert_module_to_f32(l): 347 | """ 348 | Convert primitive modules to float32, undoing convert_module_to_f16(). 
349 | """ 350 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 351 | l.weight.data = l.weight.data.float() 352 | if l.bias is not None: 353 | l.bias.data = l.bias.data.float() 354 | -------------------------------------------------------------------------------- /imagedream/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/imagedream/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /imagedream/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
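# (annotation, not in the original source) The expression returned below is the elementwise
#   KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) )
#     = 0.5 * ( log(var2/var1) + var1/var2 + (mean1 - mean2)^2 / var2 - 1 ).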
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /imagedream/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 
84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /imagedream/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/imagedream/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /imagedream/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | import numpy as np 8 | import open_clip 9 | from PIL import Image 10 | from ...util import default, count_params 11 | 12 | 13 | class AbstractEncoder(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def encode(self, *args, **kwargs): 18 | raise NotImplementedError 19 | 20 | 21 | class IdentityEncoder(AbstractEncoder): 22 | def encode(self, x): 23 | return x 24 | 25 | 26 | class ClassEmbedder(nn.Module): 27 | def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): 28 | super().__init__() 29 | self.key = key 30 | self.embedding = nn.Embedding(n_classes, embed_dim) 31 | self.n_classes = n_classes 32 | self.ucg_rate = ucg_rate 33 | 34 | def forward(self, batch, key=None, disable_dropout=False): 35 | if key is None: 36 | key = self.key 37 | # this is for use in crossattn 38 | c = batch[key][:, None] 39 | if self.ucg_rate > 0.0 and not disable_dropout: 40 | mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 41 | c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) 42 | c = c.long() 43 | c = self.embedding(c) 44 | return c 45 | 46 | def get_unconditional_conditioning(self, bs, device="cuda"): 47 | uc_class = ( 48 | self.n_classes - 1 49 | ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 50 | uc = torch.ones((bs,), device=device) * uc_class 51 | uc = {self.key: uc} 52 | return uc 53 | 54 | 55 | def disabled_train(self, mode=True): 56 | """Overwrite model.train with this function to make sure train/eval mode 57 | does not change anymore.""" 58 | return self 59 | 60 | 61 | class FrozenT5Embedder(AbstractEncoder): 62 | """Uses the T5 transformer encoder for text""" 63 | 64 | def __init__( 65 | self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True 66 | ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 67 | super().__init__() 68 | self.tokenizer = T5Tokenizer.from_pretrained(version) 69 | self.transformer = T5EncoderModel.from_pretrained(version) 70 | self.device = device 71 | self.max_length = max_length # TODO: typical value? 
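# (annotation, not in the original source) 77 mirrors the fixed context length of the CLIP
# text encoders defined later in this file; T5 itself has no hard limit, so the cap is simply
# kept consistent with the CLIP-based embedders.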
72 | if freeze: 73 | self.freeze() 74 | 75 | def freeze(self): 76 | self.transformer = self.transformer.eval() 77 | # self.train = disabled_train 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, text): 82 | batch_encoding = self.tokenizer( 83 | text, 84 | truncation=True, 85 | max_length=self.max_length, 86 | return_length=True, 87 | return_overflowing_tokens=False, 88 | padding="max_length", 89 | return_tensors="pt", 90 | ) 91 | tokens = batch_encoding["input_ids"].to(self.device) 92 | outputs = self.transformer(input_ids=tokens) 93 | 94 | z = outputs.last_hidden_state 95 | return z 96 | 97 | def encode(self, text): 98 | return self(text) 99 | 100 | 101 | class FrozenCLIPEmbedder(AbstractEncoder): 102 | """Uses the CLIP transformer encoder for text (from huggingface)""" 103 | 104 | LAYERS = ["last", "pooled", "hidden"] 105 | 106 | def __init__( 107 | self, 108 | version="openai/clip-vit-large-patch14", 109 | device="cuda", 110 | max_length=77, 111 | freeze=True, 112 | layer="last", 113 | layer_idx=None, 114 | ): # clip-vit-base-patch32 115 | super().__init__() 116 | assert layer in self.LAYERS 117 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 118 | self.transformer = CLIPTextModel.from_pretrained(version) 119 | self.device = device 120 | self.max_length = max_length 121 | if freeze: 122 | self.freeze() 123 | self.layer = layer 124 | self.layer_idx = layer_idx 125 | if layer == "hidden": 126 | assert layer_idx is not None 127 | assert 0 <= abs(layer_idx) <= 12 128 | 129 | def freeze(self): 130 | self.transformer = self.transformer.eval() 131 | # self.train = disabled_train 132 | for param in self.parameters(): 133 | param.requires_grad = False 134 | 135 | def forward(self, text): 136 | batch_encoding = self.tokenizer( 137 | text, 138 | truncation=True, 139 | max_length=self.max_length, 140 | return_length=True, 141 | return_overflowing_tokens=False, 142 | padding="max_length", 143 | return_tensors="pt", 144 | ) 145 | tokens = batch_encoding["input_ids"].to(self.device) 146 | outputs = self.transformer( 147 | input_ids=tokens, output_hidden_states=self.layer == "hidden" 148 | ) 149 | if self.layer == "last": 150 | z = outputs.last_hidden_state 151 | elif self.layer == "pooled": 152 | z = outputs.pooler_output[:, None, :] 153 | else: 154 | z = outputs.hidden_states[self.layer_idx] 155 | return z 156 | 157 | def encode(self, text): 158 | return self(text) 159 | 160 | 161 | class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module): 162 | """ 163 | Uses the OpenCLIP transformer encoder for text 164 | """ 165 | 166 | LAYERS = [ 167 | # "pooled", 168 | "last", 169 | "penultimate", 170 | ] 171 | 172 | def __init__( 173 | self, 174 | arch="ViT-H-14", 175 | version="laion2b_s32b_b79k", 176 | device="cuda", 177 | max_length=77, 178 | freeze=True, 179 | layer="last", 180 | ip_mode=None 181 | ): 182 | """_summary_ 183 | 184 | Args: 185 | ip_mode (str, optional): what is the image promcessing mode. Defaults to None. 
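(annotation, not in the original source) Values handled by forward_image below: None keeps the encoder text-only (the visual tower is deleted), "global" returns the pooled, L2-normalized image embedding, and any mode containing "local" (e.g. "local_resample" in the configs) returns per-patch tokens taken from the penultimate visual transformer block.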
186 | 187 | """ 188 | super().__init__() 189 | assert layer in self.LAYERS 190 | model, _, preprocess = open_clip.create_model_and_transforms( 191 | arch, device=torch.device("cpu"), pretrained=version 192 | ) 193 | if ip_mode is None: 194 | del model.visual 195 | 196 | self.model = model 197 | self.preprocess = preprocess 198 | self.device = device 199 | self.max_length = max_length 200 | self.ip_mode = ip_mode 201 | if freeze: 202 | self.freeze() 203 | self.layer = layer 204 | if self.layer == "last": 205 | self.layer_idx = 0 206 | elif self.layer == "penultimate": 207 | self.layer_idx = 1 208 | else: 209 | raise NotImplementedError() 210 | 211 | def freeze(self): 212 | self.model = self.model.eval() 213 | for param in self.parameters(): 214 | param.requires_grad = False 215 | 216 | def forward(self, text): 217 | tokens = open_clip.tokenize(text) 218 | z = self.encode_with_transformer(tokens.to(self.device)) 219 | return z 220 | 221 | def forward_image(self, pil_image): 222 | if isinstance(pil_image, Image.Image): 223 | pil_image = [pil_image] 224 | if isinstance(pil_image, torch.Tensor): 225 | pil_image = pil_image.cpu().numpy() 226 | if isinstance(pil_image, np.ndarray): 227 | if pil_image.ndim == 3: 228 | pil_image = pil_image[None, :, :, :] 229 | pil_image = [Image.fromarray(x) for x in pil_image] 230 | 231 | images = [] 232 | for image in pil_image: 233 | images.append(self.preprocess(image).to(self.device)) 234 | 235 | image = torch.stack(images, 0) # to [b, 3, h, w] 236 | if self.ip_mode == "global": 237 | image_features = self.model.encode_image(image) 238 | image_features /= image_features.norm(dim=-1, keepdim=True) 239 | elif "local" in self.ip_mode: 240 | image_features = self.encode_image_with_transformer(image) 241 | 242 | return image_features # b, l 243 | 244 | def encode_image_with_transformer(self, x): 245 | visual = self.model.visual 246 | x = visual.conv1(x) # shape = [*, width, grid, grid] 247 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 248 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 249 | 250 | # class embeddings and positional embeddings 251 | x = torch.cat( 252 | [visual.class_embedding.to(x.dtype) + \ 253 | torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 254 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 255 | x = x + visual.positional_embedding.to(x.dtype) 256 | 257 | # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in 258 | # x = visual.patch_dropout(x) 259 | x = visual.ln_pre(x) 260 | 261 | x = x.permute(1, 0, 2) # NLD -> LND 262 | hidden = self.image_transformer_forward(x) 263 | x = hidden[-2].permute(1, 0, 2) # LND -> NLD 264 | return x 265 | 266 | def image_transformer_forward(self, x): 267 | encoder_states = () 268 | trans = self.model.visual.transformer 269 | for r in trans.resblocks: 270 | if trans.grad_checkpointing and not torch.jit.is_scripting(): 271 | # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 272 | x = checkpoint(r, x, None, None, None) 273 | else: 274 | x = r(x, attn_mask=None) 275 | encoder_states = encoder_states + (x, ) 276 | return encoder_states 277 | 278 | def encode_with_transformer(self, text): 279 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 280 | x = x + self.model.positional_embedding 281 | x = x.permute(1, 0, 2) # NLD -> LND 282 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 283 | x = x.permute(1, 0, 2) # LND -> NLD 284 | x = self.model.ln_final(x) 285 | return x 286 | 287 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): 288 | for i, r in enumerate(self.model.transformer.resblocks): 289 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 290 | break 291 | if ( 292 | self.model.transformer.grad_checkpointing 293 | and not torch.jit.is_scripting() 294 | ): 295 | x = checkpoint(r, x, attn_mask) 296 | else: 297 | x = r(x, attn_mask=attn_mask) 298 | return x 299 | 300 | def encode(self, text): 301 | return self(text) 302 | 303 | 304 | class FrozenCLIPT5Encoder(AbstractEncoder): 305 | def __init__( 306 | self, 307 | clip_version="openai/clip-vit-large-patch14", 308 | t5_version="google/t5-v1_1-xl", 309 | device="cuda", 310 | clip_max_length=77, 311 | t5_max_length=77, 312 | ): 313 | super().__init__() 314 | self.clip_encoder = FrozenCLIPEmbedder( 315 | clip_version, device, max_length=clip_max_length 316 | ) 317 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 318 | print( 319 | f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 320 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." 
321 | ) 322 | 323 | def encode(self, text): 324 | return self(text) 325 | 326 | def forward(self, text): 327 | clip_z = self.clip_encoder.encode(text) 328 | t5_z = self.t5_encoder.encode(text) 329 | return [clip_z, t5_z] 330 | -------------------------------------------------------------------------------- /imagedream/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import random 4 | import torch 5 | import numpy as np 6 | from collections import abc 7 | 8 | import multiprocessing as mp 9 | from threading import Thread 10 | from queue import Queue 11 | 12 | from inspect import isfunction 13 | from PIL import Image, ImageDraw, ImageFont 14 | 15 | 16 | def log_txt_as_img(wh, xc, size=10): 17 | # wh a tuple of (width, height) 18 | # xc a list of captions to plot 19 | b = len(xc) 20 | txts = list() 21 | for bi in range(b): 22 | txt = Image.new("RGB", wh, color="white") 23 | draw = ImageDraw.Draw(txt) 24 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 25 | nc = int(40 * (wh[0] / 256)) 26 | lines = "\n".join( 27 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 28 | ) 29 | 30 | try: 31 | draw.text((0, 0), lines, fill="black", font=font) 32 | except UnicodeEncodeError: 33 | print("Cant encode string for logging. Skipping.") 34 | 35 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 36 | txts.append(txt) 37 | txts = np.stack(txts) 38 | txts = torch.tensor(txts) 39 | return txts 40 | 41 | 42 | def ismap(x): 43 | if not isinstance(x, torch.Tensor): 44 | return False 45 | return (len(x.shape) == 4) and (x.shape[1] > 3) 46 | 47 | 48 | def isimage(x): 49 | if not isinstance(x, torch.Tensor): 50 | return False 51 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 52 | 53 | 54 | def exists(x): 55 | return x is not None 56 | 57 | 58 | def default(val, d): 59 | if exists(val): 60 | return val 61 | return d() if isfunction(d) else d 62 | 63 | 64 | def mean_flat(tensor): 65 | """ 66 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 67 | Take the mean over all non-batch dimensions. 
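(annotation, not in the original source) E.g. a [B, C, H, W] tensor is reduced to a length-B vector.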
68 | """ 69 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 70 | 71 | 72 | def count_params(model, verbose=False): 73 | total_params = sum(p.numel() for p in model.parameters()) 74 | if verbose: 75 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 76 | return total_params 77 | 78 | 79 | def instantiate_from_config(config): 80 | if not "target" in config: 81 | if config == "__is_first_stage__": 82 | return None 83 | elif config == "__is_unconditional__": 84 | return None 85 | raise KeyError("Expected key `target` to instantiate.") 86 | # import pdb; pdb.set_trace() 87 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 88 | 89 | 90 | def get_obj_from_str(string, reload=False): 91 | module, cls = string.rsplit(".", 1) 92 | # import pdb; pdb.set_trace() 93 | if reload: 94 | module_imp = importlib.import_module(module) 95 | importlib.reload(module_imp) 96 | return getattr(importlib.import_module(module, package=None), cls) 97 | 98 | 99 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 100 | # create dummy dataset instance 101 | 102 | # run prefetching 103 | if idx_to_fn: 104 | res = func(data, worker_id=idx) 105 | else: 106 | res = func(data) 107 | Q.put([idx, res]) 108 | Q.put("Done") 109 | 110 | 111 | def parallel_data_prefetch( 112 | func: callable, 113 | data, 114 | n_proc, 115 | target_data_type="ndarray", 116 | cpu_intensive=True, 117 | use_worker_id=False, 118 | ): 119 | # if target_data_type not in ["ndarray", "list"]: 120 | # raise ValueError( 121 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 122 | # ) 123 | if isinstance(data, np.ndarray) and target_data_type == "list": 124 | raise ValueError("list expected but function got ndarray.") 125 | elif isinstance(data, abc.Iterable): 126 | if isinstance(data, dict): 127 | print( 128 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 129 | ) 130 | data = list(data.values()) 131 | if target_data_type == "ndarray": 132 | data = np.asarray(data) 133 | else: 134 | data = list(data) 135 | else: 136 | raise TypeError( 137 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
138 | ) 139 | 140 | if cpu_intensive: 141 | Q = mp.Queue(1000) 142 | proc = mp.Process 143 | else: 144 | Q = Queue(1000) 145 | proc = Thread 146 | # spawn processes 147 | if target_data_type == "ndarray": 148 | arguments = [ 149 | [func, Q, part, i, use_worker_id] 150 | for i, part in enumerate(np.array_split(data, n_proc)) 151 | ] 152 | else: 153 | step = ( 154 | int(len(data) / n_proc + 1) 155 | if len(data) % n_proc != 0 156 | else int(len(data) / n_proc) 157 | ) 158 | arguments = [ 159 | [func, Q, part, i, use_worker_id] 160 | for i, part in enumerate( 161 | [data[i : i + step] for i in range(0, len(data), step)] 162 | ) 163 | ] 164 | processes = [] 165 | for i in range(n_proc): 166 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 167 | processes += [p] 168 | 169 | # start processes 170 | print(f"Start prefetching...") 171 | import time 172 | 173 | start = time.time() 174 | gather_res = [[] for _ in range(n_proc)] 175 | try: 176 | for p in processes: 177 | p.start() 178 | 179 | k = 0 180 | while k < n_proc: 181 | # get result 182 | res = Q.get() 183 | if res == "Done": 184 | k += 1 185 | else: 186 | gather_res[res[0]] = res[1] 187 | 188 | except Exception as e: 189 | print("Exception: ", e) 190 | for p in processes: 191 | p.terminate() 192 | 193 | raise e 194 | finally: 195 | for p in processes: 196 | p.join() 197 | print(f"Prefetching complete. [{time.time() - start} sec.]") 198 | 199 | if target_data_type == "ndarray": 200 | if not isinstance(gather_res[0], np.ndarray): 201 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 202 | 203 | # order outputs 204 | return np.concatenate(gather_res, axis=0) 205 | elif target_data_type == "list": 206 | out = [] 207 | for r in gather_res: 208 | out.extend(r) 209 | return out 210 | else: 211 | return gather_res 212 | 213 | def set_seed(seed=None): 214 | random.seed(seed) 215 | np.random.seed(seed) 216 | if seed is not None: 217 | torch.manual_seed(seed) 218 | torch.cuda.manual_seed_all(seed) 219 | 220 | def add_random_background(image, bg_color=None): 221 | bg_color = np.random.rand() * 255 if bg_color is None else bg_color 222 | image = np.array(image) 223 | rgb, alpha = image[..., :3], image[..., 3:] 224 | alpha = alpha.astype(np.float32) / 255.0 225 | image_new = rgb * alpha + bg_color * (1 - alpha) 226 | return Image.fromarray(image_new.astype(np.uint8)) -------------------------------------------------------------------------------- /imagedream/model_zoo.py: -------------------------------------------------------------------------------- 1 | """ Utiliy functions to load pre-trained models more easily """ 2 | import os 3 | import pkg_resources 4 | from omegaconf import OmegaConf 5 | 6 | import torch 7 | from huggingface_hub import hf_hub_download 8 | 9 | from imagedream.ldm.util import instantiate_from_config 10 | 11 | 12 | PRETRAINED_MODELS = { 13 | "sd-v2.1-base-4view-ipmv": { 14 | "config": "sd_v2_base_ipmv.yaml", 15 | "repo_id": "Peng-Wang/ImageDream", 16 | "filename": "sd-v2.1-base-4view-ipmv.pt", 17 | }, 18 | "sd-v2.1-base-4view-ipmv-local": { 19 | "config": "sd_v2_base_ipmv_local.yaml", 20 | "repo_id": "Peng-Wang/ImageDream", 21 | "filename": "sd-v2.1-base-4view-ipmv-local.pt", 22 | }, 23 | } 24 | 25 | 26 | def get_config_file(config_path): 27 | cfg_file = pkg_resources.resource_filename( 28 | "imagedream", os.path.join("configs", config_path) 29 | ) 30 | if not os.path.exists(cfg_file): 31 | raise RuntimeError(f"Config {config_path} not available!") 32 | return cfg_file 33 | 34 | 35 | def 
build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None): 36 | if (config_path is not None) and (ckpt_path is not None): 37 | config = OmegaConf.load(config_path) 38 | model = instantiate_from_config(config.model) 39 | model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) 40 | return model 41 | 42 | if not model_name in PRETRAINED_MODELS: 43 | raise RuntimeError( 44 | f"Model name {model_name} is not a pre-trained model. Available models are:\n- " 45 | + "\n- ".join(PRETRAINED_MODELS.keys()) 46 | ) 47 | model_info = PRETRAINED_MODELS[model_name] 48 | 49 | # Instiantiate the model 50 | print(f"Loading model from config: {model_info['config']}") 51 | config_file = get_config_file(model_info["config"]) 52 | config = OmegaConf.load(config_file) 53 | model = instantiate_from_config(config.model) 54 | 55 | # Load pre-trained checkpoint from huggingface 56 | if not ckpt_path: 57 | ckpt_path = hf_hub_download( 58 | repo_id=model_info["repo_id"], 59 | filename=model_info["filename"], 60 | cache_dir=cache_dir, 61 | ) 62 | print(f"Loading model from cache file: {ckpt_path}") 63 | model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) 64 | return model 65 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | import nvdiffrast.torch as dr 5 | from util.utils import get_tri 6 | import tempfile 7 | from mesh import Mesh 8 | import zipfile 9 | def generate3d(model, rgb, ccm, device): 10 | 11 | color_tri = torch.from_numpy(rgb)/255 12 | xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255 13 | color = color_tri.permute(2,0,1) 14 | xyz = xyz_tri.permute(2,0,1) 15 | 16 | 17 | def get_imgs(color): 18 | # color : [C, H, W*6] 19 | color_list = [] 20 | color_list.append(color[:,:,256*5:256*(1+5)]) 21 | for i in range(0,5): 22 | color_list.append(color[:,:,256*i:256*(1+i)]) 23 | return torch.stack(color_list, dim=0)# [6, C, H, W] 24 | 25 | triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C] 26 | 27 | color = get_imgs(color) 28 | xyz = get_imgs(xyz) 29 | 30 | color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0) 31 | xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0) 32 | 33 | triplane = torch.cat([color,xyz],dim=1).to(device) 34 | # 3D visualize 35 | model.eval() 36 | glctx = dr.RasterizeCudaContext() 37 | 38 | if model.denoising == True: 39 | tnew = 20 40 | tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device) 41 | noise_new = torch.randn_like(triplane) *0.5+0.5 42 | triplane = model.scheduler.add_noise(triplane, noise_new, tnew) 43 | start_time = time.time() 44 | with torch.no_grad(): 45 | triplane_feature2 = model.unet2(triplane,tnew) 46 | end_time = time.time() 47 | elapsed_time = end_time - start_time 48 | print(f"unet takes {elapsed_time}s") 49 | else: 50 | triplane_feature2 = model.unet2(triplane) 51 | 52 | 53 | with torch.no_grad(): 54 | data_config = { 55 | 'resolution': [1024, 1024], 56 | "triview_color": triplane_color.to(device), 57 | } 58 | 59 | verts, faces = model.decode(data_config, triplane_feature2) 60 | 61 | data_config['verts'] = verts[0] 62 | data_config['faces'] = faces 63 | 64 | 65 | from kiui.mesh_utils import clean_mesh 66 | verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), 
data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=False, remesh_size=0.005) 67 | data_config['verts'] = torch.from_numpy(verts).cuda().contiguous() 68 | data_config['faces'] = torch.from_numpy(faces).cuda().contiguous() 69 | 70 | start_time = time.time() 71 | with torch.no_grad(): 72 | mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name 73 | model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024,1024), tri_fea_2=triplane_feature2) 74 | 75 | mesh = Mesh.load(mesh_path_obj+".obj", bound=0.9, front_dir="+z") 76 | mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name 77 | mesh.write(mesh_path_glb+".glb") 78 | 79 | # mesh_obj2 = trimesh.load(mesh_path_glb+".glb", file_type='glb') 80 | # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name 81 | # mesh_obj2.export(mesh_path_obj2+".obj") 82 | 83 | with zipfile.ZipFile(mesh_path_obj+'.zip', 'w') as myzip: 84 | myzip.write(mesh_path_obj+'.obj', mesh_path_obj.split("/")[-1]+'.obj') 85 | myzip.write(mesh_path_obj+'.png', mesh_path_obj.split("/")[-1]+'.png') 86 | myzip.write(mesh_path_obj+'.mtl', mesh_path_obj.split("/")[-1]+'.mtl') 87 | 88 | end_time = time.time() 89 | elapsed_time = end_time - start_time 90 | print(f"uv takes {elapsed_time}s") 91 | return mesh_path_glb+".glb", mesh_path_obj+'.zip' 92 | -------------------------------------------------------------------------------- /launch_train.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | # set default values for the environment variables 4 | export OMP_NUM_THREADS=8 5 | if [ -z "$ADDR" ] 6 | then 7 | export ADDR=127.0.0.1 8 | fi 9 | 10 | if [ -z "$WORLD_SIZE" ] 11 | then 12 | export WORLD_SIZE=1 13 | fi 14 | 15 | if [ -z "$RANK" ] 16 | then 17 | export RANK=0 18 | fi 19 | 20 | if [ -z "$MASTER_PORT" ] 21 | then 22 | export MASTER_PORT=29501 23 | fi 24 | 25 | export WANDB_MODE=offline 26 | accelerate_args="--config_file acce.yaml --num_machines $WORLD_SIZE \ 27 | --machine_rank $RANK --num_processes 1 \ 28 | --main_process_port $MASTER_PORT \ 29 | --main_process_ip $ADDR" 30 | echo $accelerate_args 31 | 32 | # train stage 1 33 | accelerate launch $accelerate_args train.py --config configs/nf7_v3_SNR_rd_size_stroke_train.yaml \ 34 | config.batch_size=1 \ 35 | config.eval_interval=100 36 | 37 | 38 | # train stage 2 39 | # accelerate launch $accelerate_args train_stage2.py --config configs/stage2-v2-snr_train.yaml \ 40 | # config.batch_size=1 \ 41 | # config.eval_interval=100 -------------------------------------------------------------------------------- /libs/sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from imagedream.camera_utils import get_camera_for_index 4 | from imagedream.ldm.util import set_seed, add_random_background 5 | from libs.base_utils import do_resize_content 6 | from imagedream.ldm.models.diffusion.ddim import DDIMSampler 7 | from torchvision import transforms as T 8 | 9 | 10 | class ImageDreamDiffusion: 11 | def __init__( 12 | self, 13 | model, 14 | device, 15 | dtype, 16 | mode, 17 | num_frames, 18 | camera_views, 19 | ref_position, 20 | random_background=False, 21 | offset_noise=False, 22 | resize_rate=1, 23 | image_size=256, 24 | seed=1234, 25 | ) -> None: 26 | assert mode in ["pixel", "local"] 27 | size = image_size 28 | self.seed = seed 29 | batch_size = max(4, num_frames) 30 | 31 | neg_texts = "uniform low no 
texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear." 32 | uc = model.get_learned_conditioning([neg_texts]).to(device) 33 | sampler = DDIMSampler(model) 34 | 35 | # pre-compute camera matrices 36 | camera = [get_camera_for_index(i).squeeze() for i in camera_views] 37 | camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero 38 | camera = torch.stack(camera) 39 | camera = camera.repeat(batch_size // num_frames, 1).to(device) 40 | 41 | self.image_transform = T.Compose( 42 | [ 43 | T.Resize((size, size)), 44 | T.ToTensor(), 45 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 46 | ] 47 | ) 48 | self.dtype = dtype 49 | self.ref_position = ref_position 50 | self.mode = mode 51 | self.random_background = random_background 52 | self.resize_rate = resize_rate 53 | self.num_frames = num_frames 54 | self.size = size 55 | self.device = device 56 | self.batch_size = batch_size 57 | self.model = model 58 | self.sampler = sampler 59 | self.uc = uc 60 | self.camera = camera 61 | self.offset_noise = offset_noise 62 | 63 | @staticmethod 64 | def i2i( 65 | model, 66 | image_size, 67 | prompt, 68 | uc, 69 | sampler, 70 | ip=None, 71 | step=20, 72 | scale=5.0, 73 | batch_size=8, 74 | ddim_eta=0.0, 75 | dtype=torch.float32, 76 | device="cuda", 77 | camera=None, 78 | num_frames=4, 79 | pixel_control=False, 80 | transform=None, 81 | offset_noise=False, 82 | ): 83 | """ The function supports additional image prompt. 84 | Args: 85 | model (_type_): the image dream model 86 | image_size (_type_): size of diffusion output (standard 256) 87 | prompt (_type_): text prompt for the image (prompt in type str) 88 | uc (_type_): unconditional vector (tensor in shape [1, 77, 1024]) 89 | sampler (_type_): imagedream.ldm.models.diffusion.ddim.DDIMSampler 90 | ip (Image, optional): the image prompt. Defaults to None. 91 | step (int, optional): _description_. Defaults to 20. 92 | scale (float, optional): _description_. Defaults to 7.5. 93 | batch_size (int, optional): _description_. Defaults to 8. 94 | ddim_eta (float, optional): _description_. Defaults to 0.0. 95 | dtype (_type_, optional): _description_. Defaults to torch.float32. 96 | device (str, optional): _description_. Defaults to "cuda". 97 | camera (_type_, optional): camera info in tensor, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 98 | num_frames (int, optional): _num of frames (views) to generate 99 | pixel_control: whether to use pixel conditioning. 
Defaults to False, True when using pixel mode 100 | transform: Compose( 101 | Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn) 102 | ToTensor() 103 | Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 104 | ) 105 | """ 106 | ip_raw = ip 107 | if type(prompt) != list: 108 | prompt = [prompt] 109 | with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype): 110 | c = model.get_learned_conditioning(prompt).to( 111 | device 112 | ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05 113 | c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size 114 | uc_ = {"context": uc.repeat(batch_size, 1, 1)} 115 | 116 | if camera is not None: 117 | c_["camera"] = uc_["camera"] = ( 118 | camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 119 | ) 120 | c_["num_frames"] = uc_["num_frames"] = num_frames 121 | 122 | if ip is not None: 123 | ip_embed = model.get_learned_image_conditioning(ip).to( 124 | device 125 | ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12 126 | ip_ = ip_embed.repeat(batch_size, 1, 1) 127 | c_["ip"] = ip_ 128 | uc_["ip"] = torch.zeros_like(ip_) 129 | 130 | if pixel_control: 131 | assert camera is not None 132 | ip = transform(ip).to( 133 | device 134 | ) # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00 135 | ip_img = model.get_first_stage_encoding( 136 | model.encode_first_stage(ip[None, :, :, :]) 137 | ) # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55 138 | c_["ip_img"] = ip_img 139 | uc_["ip_img"] = torch.zeros_like(ip_img) 140 | 141 | shape = [4, image_size // 8, image_size // 8] # [4, 32, 32] 142 | if offset_noise: 143 | ref = transform(ip_raw).to(device) 144 | ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :])) 145 | ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True) 146 | time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device) 147 | x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps) 148 | 149 | samples_ddim, _ = ( 150 | sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43 151 | S=step, 152 | conditioning=c_, 153 | batch_size=batch_size, 154 | shape=shape, 155 | verbose=False, 156 | unconditional_guidance_scale=scale, 157 | unconditional_conditioning=uc_, 158 | eta=ddim_eta, 159 | x_T=x_T if offset_noise else None, 160 | ) 161 | ) 162 | 163 | x_sample = model.decode_first_stage(samples_ddim) 164 | x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) 165 | x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy() 166 | 167 | return list(x_sample.astype(np.uint8)) 168 | 169 | def diffuse(self, t, ip, n_test=2): 170 | set_seed(self.seed) 171 | ip = do_resize_content(ip, self.resize_rate) 172 | if self.random_background: 173 | ip = add_random_background(ip) 174 | 175 | images = [] 176 | for _ in range(n_test): 177 | img = self.i2i( 178 | self.model, 179 | self.size, 180 | t, 181 | self.uc, 182 | self.sampler, 183 | ip=ip, 184 | step=50, 185 | scale=5, 186 | batch_size=self.batch_size, 187 | ddim_eta=0.0, 188 | dtype=self.dtype, 189 | device=self.device, 190 | camera=self.camera, 191 | num_frames=self.num_frames, 192 | pixel_control=(self.mode == "pixel"), 193 | transform=self.image_transform, 194 | offset_noise=self.offset_noise, 195 | ) 196 | img = np.concatenate(img, 1) 197 | img 
= np.concatenate((img, ip.resize((self.size, self.size))), axis=1) 198 | images.append(img) 199 | set_seed() # unset random and numpy seed 200 | return images 201 | 202 | 203 | class ImageDreamDiffusionStage2: 204 | def __init__( 205 | self, 206 | model, 207 | device, 208 | dtype, 209 | num_frames, 210 | camera_views, 211 | ref_position, 212 | random_background=False, 213 | offset_noise=False, 214 | resize_rate=1, 215 | mode="pixel", 216 | image_size=256, 217 | seed=1234, 218 | ) -> None: 219 | assert mode in ["pixel", "local"] 220 | 221 | size = image_size 222 | self.seed = seed 223 | batch_size = max(4, num_frames) 224 | 225 | neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear." 226 | uc = model.get_learned_conditioning([neg_texts]).to(device) 227 | sampler = DDIMSampler(model) 228 | 229 | # pre-compute camera matrices 230 | camera = [get_camera_for_index(i).squeeze() for i in camera_views] 231 | if ref_position is not None: 232 | camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero 233 | camera = torch.stack(camera) 234 | camera = camera.repeat(batch_size // num_frames, 1).to(device) 235 | 236 | self.image_transform = T.Compose( 237 | [ 238 | T.Resize((size, size)), 239 | T.ToTensor(), 240 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 241 | ] 242 | ) 243 | 244 | self.dtype = dtype 245 | self.mode = mode 246 | self.ref_position = ref_position 247 | self.random_background = random_background 248 | self.resize_rate = resize_rate 249 | self.num_frames = num_frames 250 | self.size = size 251 | self.device = device 252 | self.batch_size = batch_size 253 | self.model = model 254 | self.sampler = sampler 255 | self.uc = uc 256 | self.camera = camera 257 | self.offset_noise = offset_noise 258 | 259 | @staticmethod 260 | def i2iStage2( 261 | model, 262 | image_size, 263 | prompt, 264 | uc, 265 | sampler, 266 | pixel_images, 267 | ip=None, 268 | step=20, 269 | scale=5.0, 270 | batch_size=8, 271 | ddim_eta=0.0, 272 | dtype=torch.float32, 273 | device="cuda", 274 | camera=None, 275 | num_frames=4, 276 | pixel_control=False, 277 | transform=None, 278 | offset_noise=False, 279 | ): 280 | ip_raw = ip 281 | if type(prompt) != list: 282 | prompt = [prompt] 283 | with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype): 284 | c = model.get_learned_conditioning(prompt).to( 285 | device 286 | ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05 287 | c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size 288 | uc_ = {"context": uc.repeat(batch_size, 1, 1)} 289 | 290 | if camera is not None: 291 | c_["camera"] = uc_["camera"] = ( 292 | camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 293 | ) 294 | c_["num_frames"] = uc_["num_frames"] = num_frames 295 | 296 | if ip is not None: 297 | ip_embed = model.get_learned_image_conditioning(ip).to( 298 | device 299 | ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12 300 | ip_ = ip_embed.repeat(batch_size, 1, 1) 301 | c_["ip"] = ip_ 302 | uc_["ip"] = torch.zeros_like(ip_) 303 | 304 | if pixel_control: 305 | assert camera is not None 306 | 307 | transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images]) 308 | latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images)) 309 | 310 | c_["pixel_images"] = latent_pixel_images 311 | uc_["pixel_images"] = 
torch.zeros_like(latent_pixel_images) 312 | 313 | shape = [4, image_size // 8, image_size // 8] # [4, 32, 32] 314 | if offset_noise: 315 | ref = transform(ip_raw).to(device) 316 | ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :])) 317 | ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True) 318 | time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device) 319 | x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps) 320 | 321 | samples_ddim, _ = ( 322 | sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43 323 | S=step, 324 | conditioning=c_, 325 | batch_size=batch_size, 326 | shape=shape, 327 | verbose=False, 328 | unconditional_guidance_scale=scale, 329 | unconditional_conditioning=uc_, 330 | eta=ddim_eta, 331 | x_T=x_T if offset_noise else None, 332 | ) 333 | ) 334 | x_sample = model.decode_first_stage(samples_ddim) 335 | x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) 336 | x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy() 337 | 338 | return list(x_sample.astype(np.uint8)) 339 | 340 | @torch.no_grad() 341 | def diffuse(self, t, ip, pixel_images, n_test=2): 342 | set_seed(self.seed) 343 | ip = do_resize_content(ip, self.resize_rate) 344 | pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images] 345 | 346 | if self.random_background: 347 | bg_color = np.random.rand() * 255 348 | ip = add_random_background(ip, bg_color) 349 | pixel_images = [add_random_background(i, bg_color) for i in pixel_images] 350 | 351 | images = [] 352 | for _ in range(n_test): 353 | img = self.i2iStage2( 354 | self.model, 355 | self.size, 356 | t, 357 | self.uc, 358 | self.sampler, 359 | pixel_images=pixel_images, 360 | ip=ip, 361 | step=50, 362 | scale=5, 363 | batch_size=self.batch_size, 364 | ddim_eta=0.0, 365 | dtype=self.dtype, 366 | device=self.device, 367 | camera=self.camera, 368 | num_frames=self.num_frames, 369 | pixel_control=(self.mode == "pixel"), 370 | transform=self.image_transform, 371 | offset_noise=self.offset_noise, 372 | ) 373 | img = np.concatenate(img, 1) 374 | img = np.concatenate( 375 | (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]), 376 | axis=1, 377 | ) 378 | images.append(img) 379 | set_seed() # unset random and numpy seed 380 | return images 381 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.crm.model import CRM -------------------------------------------------------------------------------- /model/archs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/model/archs/__init__.py -------------------------------------------------------------------------------- /model/archs/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/archs/decoders/shape_texture_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class TetTexNet(nn.Module): 7 | def __init__(self, plane_reso=64, padding=0.1, 
fea_concat=True): 8 | super().__init__() 9 | # self.c_dim = c_dim 10 | self.plane_reso = plane_reso 11 | self.padding = padding 12 | self.fea_concat = fea_concat 13 | 14 | def forward(self, rolled_out_feature, query): 15 | # rolled_out_feature: rolled-out triplane feature 16 | # query: queried xyz coordinates (should be scaled consistently to ptr cloud) 17 | 18 | plane_reso = self.plane_reso 19 | 20 | triplane_feature = dict() 21 | triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso] 22 | triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso] 23 | triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:] 24 | 25 | query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy') 26 | query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz') 27 | query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx') 28 | 29 | if self.fea_concat: 30 | query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1) 31 | else: 32 | query_feature = query_feature_xy + query_feature_yz + query_feature_zx 33 | 34 | output = query_feature.permute(0, 2, 1) 35 | 36 | return output 37 | 38 | # uses values from plane_feature and pixel locations from vgrid to interpolate feature 39 | def sample_plane_feature(self, query, plane_feature, plane): 40 | # CYF note: 41 | # for pretraining, query are uniformly sampled positions w.i. [-scale, scale] 42 | # for training, query are essentially tetrahedra grid vertices, which are 43 | # also within [-scale, scale] in the current version! 44 | # xy range [-scale, scale] 45 | if plane == 'xy': 46 | xy = query[:, :, [0, 1]] 47 | elif plane == 'yz': 48 | xy = query[:, :, [1, 2]] 49 | elif plane == 'zx': 50 | xy = query[:, :, [2, 0]] 51 | else: 52 | raise ValueError("Error! 
Invalid plane type!") 53 | 54 | xy = xy[:, :, None].float() 55 | # not seem necessary to rescale the grid, because from 56 | # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html, 57 | # it specifies sampling locations normalized by plane_feature's spatial dimension, 58 | # which is within [-scale, scale] as specified by encoder's calling of coordinate2index() 59 | vgrid = 1.0 * xy 60 | sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1) 61 | 62 | return sampled_feat 63 | -------------------------------------------------------------------------------- /model/archs/mlp_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class SdfMlp(nn.Module): 6 | def __init__(self, input_dim, hidden_dim=512, bias=True): 7 | super().__init__() 8 | self.input_dim = input_dim 9 | self.hidden_dim = hidden_dim 10 | 11 | self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias) 12 | self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias) 13 | self.fc3 = nn.Linear(hidden_dim, 4, bias=bias) 14 | 15 | 16 | def forward(self, input): 17 | x = F.relu(self.fc1(input)) 18 | x = F.relu(self.fc2(x)) 19 | out = self.fc3(x) 20 | return out 21 | 22 | 23 | class RgbMlp(nn.Module): 24 | def __init__(self, input_dim, hidden_dim=512, bias=True): 25 | super().__init__() 26 | self.input_dim = input_dim 27 | self.hidden_dim = hidden_dim 28 | 29 | self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias) 30 | self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias) 31 | self.fc3 = nn.Linear(hidden_dim, 3, bias=bias) 32 | 33 | def forward(self, input): 34 | x = F.relu(self.fc1(input)) 35 | x = F.relu(self.fc2(x)) 36 | out = self.fc3(x) 37 | 38 | return out 39 | 40 | -------------------------------------------------------------------------------- /model/archs/unet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Codes are from: 3 | https://github.com/jaxony/unet-pytorch/blob/master/model.py 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | from diffusers import UNet2DModel 9 | import einops 10 | class UNetPP(nn.Module): 11 | ''' 12 | Wrapper for UNet in diffusers 13 | ''' 14 | def __init__(self, in_channels): 15 | super(UNetPP, self).__init__() 16 | self.in_channels = in_channels 17 | self.unet = UNet2DModel( 18 | sample_size=[256, 256*3], 19 | in_channels=in_channels, 20 | out_channels=32, 21 | layers_per_block=2, 22 | block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4), 23 | down_block_types=( 24 | "DownBlock2D", 25 | "DownBlock2D", 26 | "DownBlock2D", 27 | "AttnDownBlock2D", 28 | "AttnDownBlock2D", 29 | "AttnDownBlock2D", 30 | "DownBlock2D", 31 | ), 32 | up_block_types=( 33 | "UpBlock2D", 34 | "AttnUpBlock2D", 35 | "AttnUpBlock2D", 36 | "AttnUpBlock2D", 37 | "UpBlock2D", 38 | "UpBlock2D", 39 | "UpBlock2D", 40 | ), 41 | ) 42 | 43 | self.unet.enable_xformers_memory_efficient_attention() 44 | if in_channels > 12: 45 | self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3])) 46 | 47 | def forward(self, x, t=256): 48 | learned_plane = self.learned_plane 49 | if x.shape[1] < self.in_channels: 50 | learned_plane = einops.repeat(learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device) 51 | x = torch.cat([x, learned_plane], dim = 1) 52 | return self.unet(x, t).sample 53 | 54 | 
-------------------------------------------------------------------------------- /model/crm/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | 8 | from pathlib import Path 9 | import cv2 10 | import trimesh 11 | import nvdiffrast.torch as dr 12 | 13 | from model.archs.decoders.shape_texture_net import TetTexNet 14 | from model.archs.unet import UNetPP 15 | from util.renderer import Renderer 16 | from model.archs.mlp_head import SdfMlp, RgbMlp 17 | import xatlas 18 | 19 | 20 | class Dummy: 21 | pass 22 | 23 | class CRM(nn.Module): 24 | def __init__(self, specs): 25 | super(CRM, self).__init__() 26 | 27 | self.specs = specs 28 | # configs 29 | input_specs = specs["Input"] 30 | self.input = Dummy() 31 | self.input.scale = input_specs['scale'] 32 | self.input.resolution = input_specs['resolution'] 33 | self.tet_grid_size = input_specs['tet_grid_size'] 34 | self.camera_angle_num = input_specs['camera_angle_num'] 35 | 36 | self.arch = Dummy() 37 | self.arch.fea_concat = specs["ArchSpecs"]["fea_concat"] 38 | self.arch.mlp_bias = specs["ArchSpecs"]["mlp_bias"] 39 | 40 | self.dec = Dummy() 41 | self.dec.c_dim = specs["DecoderSpecs"]["c_dim"] 42 | self.dec.plane_resolution = specs["DecoderSpecs"]["plane_resolution"] 43 | 44 | self.geo_type = specs["Train"].get("geo_type", "flex") # "dmtet" or "flex" 45 | 46 | self.unet2 = UNetPP(in_channels=self.dec.c_dim) 47 | 48 | mlp_chnl_s = 3 if self.arch.fea_concat else 1 # 3 for queried triplane feature concatenation 49 | self.decoder = TetTexNet(plane_reso=self.dec.plane_resolution, fea_concat=self.arch.fea_concat) 50 | 51 | if self.geo_type == "flex": 52 | self.weightMlp = nn.Sequential( 53 | nn.Linear(mlp_chnl_s * 32 * 8, 512), 54 | nn.SiLU(), 55 | nn.Linear(512, 21)) 56 | 57 | self.sdfMlp = SdfMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias) 58 | self.rgbMlp = RgbMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias) 59 | self.renderer = Renderer(tet_grid_size=self.tet_grid_size, camera_angle_num=self.camera_angle_num, 60 | scale=self.input.scale, geo_type = self.geo_type) 61 | 62 | 63 | self.spob = True if specs['Pretrain']['mode'] is None else False # whether to add sphere 64 | self.radius = specs['Pretrain']['radius'] # used when spob 65 | 66 | self.denoising = True 67 | from diffusers import DDIMScheduler 68 | self.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler") 69 | 70 | def decode(self, data, triplane_feature2): 71 | if self.geo_type == "flex": 72 | tet_verts = self.renderer.flexicubes.verts.unsqueeze(0) 73 | tet_indices = self.renderer.flexicubes.indices 74 | 75 | dec_verts = self.decoder(triplane_feature2, tet_verts) 76 | out = self.sdfMlp(dec_verts) 77 | 78 | weight = None 79 | if self.geo_type == "flex": 80 | grid_feat = torch.index_select(input=dec_verts, index=self.renderer.flexicubes.indices.reshape(-1),dim=1) 81 | grid_feat = grid_feat.reshape(dec_verts.shape[0], self.renderer.flexicubes.indices.shape[0], self.renderer.flexicubes.indices.shape[1] * dec_verts.shape[-1]) 82 | weight = self.weightMlp(grid_feat) 83 | weight = weight * 0.1 84 | 85 | pred_sdf, deformation = out[..., 0], out[..., 1:] 86 | if self.spob: 87 | pred_sdf = pred_sdf + self.radius - torch.sqrt((tet_verts**2).sum(-1)) 88 | 89 | _, verts, faces = self.renderer(data, pred_sdf, deformation, tet_verts, tet_indices, weight= weight) 90 | return verts[0].unsqueeze(0), 
faces[0].int() 91 | 92 | def export_mesh(self, data, out_dir, ind, device=None, tri_fea_2 = None): 93 | verts = data['verts'] 94 | faces = data['faces'] 95 | 96 | dec_verts = self.decoder(tri_fea_2, verts.unsqueeze(0)) 97 | colors = self.rgbMlp(dec_verts).squeeze().detach().cpu().numpy() 98 | # Expect predicted colors value range from [-1, 1] 99 | colors = (colors * 0.5 + 0.5).clip(0, 1) 100 | 101 | verts = verts.squeeze().cpu().numpy() 102 | faces = faces[..., [2, 1, 0]].squeeze().cpu().numpy() 103 | 104 | # export the final mesh 105 | with torch.no_grad(): 106 | mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault... 107 | mesh.export(out_dir / f'{ind}.obj') 108 | 109 | def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None): 110 | 111 | mesh_v = data['verts'].squeeze().cpu().numpy() 112 | mesh_pos_idx = data['faces'].squeeze().cpu().numpy() 113 | 114 | def interpolate(attr, rast, attr_idx, rast_db=None): 115 | return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, 116 | diff_attrs=None if rast_db is None else 'all') 117 | 118 | vmapping, indices, uvs = xatlas.parametrize(mesh_v, mesh_pos_idx) 119 | 120 | mesh_v = torch.tensor(mesh_v, dtype=torch.float32, device=device) 121 | mesh_pos_idx = torch.tensor(mesh_pos_idx, dtype=torch.int64, device=device) 122 | 123 | # Convert to tensors 124 | indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) 125 | 126 | uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) 127 | mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) 128 | # mesh_v_tex. ture 129 | uv_clip = uvs[None, ...] * 2.0 - 1.0 130 | 131 | # pad to four component coordinate 132 | uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) 133 | 134 | # rasterize 135 | rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), res) 136 | 137 | # Interpolate world space position 138 | gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) 139 | mask = rast[..., 3:4] > 0 140 | 141 | # return uvs, mesh_tex_idx, gb_pos, mask 142 | gb_pos_unsqz = gb_pos.view(-1, 3) 143 | mask_unsqz = mask.view(-1) 144 | tex_unsqz = torch.zeros_like(gb_pos_unsqz) + 1 145 | 146 | gb_mask_pos = gb_pos_unsqz[mask_unsqz] 147 | 148 | gb_mask_pos = gb_mask_pos[None, ] 149 | 150 | with torch.no_grad(): 151 | 152 | dec_verts = self.decoder(tri_fea_2, gb_mask_pos) 153 | colors = self.rgbMlp(dec_verts).squeeze() 154 | 155 | # Expect predicted colors value range from [-1, 1] 156 | lo, hi = (-1, 1) 157 | colors = (colors - lo) * (255 / (hi - lo)) 158 | colors = colors.clip(0, 255) 159 | 160 | tex_unsqz[mask_unsqz] = colors 161 | 162 | tex = tex_unsqz.view(res + (3,)) 163 | 164 | verts = mesh_v.squeeze().cpu().numpy() 165 | faces = mesh_pos_idx[..., [2, 1, 0]].squeeze().cpu().numpy() 166 | # faces = mesh_pos_idx 167 | # faces = faces.detach().cpu().numpy() 168 | # faces = faces[..., [2, 1, 0]] 169 | indices = indices[..., [2, 1, 0]] 170 | 171 | # xatlas.export(f"{out_dir}/{ind}.obj", verts[vmapping], indices, uvs) 172 | matname = f'{out_dir}.mtl' 173 | # matname = f'{out_dir}/{ind}.mtl' 174 | fid = open(matname, 'w') 175 | fid.write('newmtl material_0\n') 176 | fid.write('Kd 1 1 1\n') 177 | fid.write('Ka 1 1 1\n') 178 | # fid.write('Ks 0 0 0\n') 179 | fid.write('Ks 0.4 0.4 0.4\n') 180 | fid.write('Ns 10\n') 181 | fid.write('illum 2\n') 182 | fid.write(f'map_Kd 
{out_dir.split("/")[-1]}.png\n') 183 | fid.close() 184 | 185 | fid = open(f'{out_dir}.obj', 'w') 186 | # fid = open(f'{out_dir}/{ind}.obj', 'w') 187 | fid.write('mtllib %s.mtl\n' % out_dir.split("/")[-1]) 188 | 189 | for pidx, p in enumerate(verts): 190 | pp = p 191 | fid.write('v %f %f %f\n' % (pp[0], pp[2], - pp[1])) 192 | 193 | for pidx, p in enumerate(uvs): 194 | pp = p 195 | fid.write('vt %f %f\n' % (pp[0], 1 - pp[1])) 196 | 197 | fid.write('usemtl material_0\n') 198 | for i, f in enumerate(faces): 199 | f1 = f + 1 200 | f2 = indices[i] + 1 201 | fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) 202 | fid.close() 203 | 204 | img = np.asarray(tex.data.cpu().numpy(), dtype=np.float32) 205 | mask = np.sum(img.astype(float), axis=-1, keepdims=True) 206 | mask = (mask <= 3.0).astype(float) 207 | kernel = np.ones((3, 3), 'uint8') 208 | dilate_img = cv2.dilate(img, kernel, iterations=1) 209 | img = img * (1 - mask) + dilate_img * mask 210 | img = img.clip(0, 255).astype(np.uint8) 211 | 212 | cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]]) 213 | # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]]) 214 | -------------------------------------------------------------------------------- /pipelines.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from libs.base_utils import do_resize_content 3 | from imagedream.ldm.util import ( 4 | instantiate_from_config, 5 | get_obj_from_str, 6 | ) 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | import PIL 10 | import rembg 11 | class TwoStagePipeline(object): 12 | def __init__( 13 | self, 14 | stage1_model_config, 15 | stage2_model_config, 16 | stage1_sampler_config, 17 | stage2_sampler_config, 18 | device="cuda", 19 | dtype=torch.float16, 20 | resize_rate=1, 21 | ) -> None: 22 | """ 23 | only for two stage generate process. 
24 |         - the first stage is conditioned on a single pixel image and generates multi-view pixel images, based on the v2pp config
25 |         - the second stage is conditioned on the multi-view pixel images generated by the first stage and generates the final images, based on the stage2-test config
26 |         """
27 |         self.resize_rate = resize_rate
28 | 
29 |         self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
30 |         self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
31 |         self.stage1_model = self.stage1_model.to(device).to(dtype)
32 | 
33 |         self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
34 |         sd = torch.load(stage2_model_config.resume, map_location="cpu")
35 |         self.stage2_model.load_state_dict(sd, strict=False)
36 |         self.stage2_model = self.stage2_model.to(device).to(dtype)
37 | 
38 |         self.stage1_model.device = device
39 |         self.stage2_model.device = device
40 |         self.device = device
41 |         self.dtype = dtype
42 |         self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
43 |             self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
44 |         )
45 |         self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)(
46 |             self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params
47 |         )
48 | 
49 |     def stage1_sample(
50 |         self,
51 |         pixel_img,
52 |         prompt="3D assets",
53 |         neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
54 |         step=50,
55 |         scale=5,
56 |         ddim_eta=0.0,
57 |     ):
58 |         if type(pixel_img) == str:
59 |             pixel_img = Image.open(pixel_img)
60 | 
61 |         if isinstance(pixel_img, Image.Image):
62 |             if pixel_img.mode == "RGBA":
63 |                 background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
64 |                 pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
65 |             else:
66 |                 pixel_img = pixel_img.convert("RGB")
67 |         else:
68 |             raise TypeError(f"pixel_img must be a PIL.Image.Image or a path string, got {type(pixel_img)}")
69 |         uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
70 |         stage1_images = self.stage1_sampler.i2i(
71 |             self.stage1_sampler.model,
72 |             self.stage1_sampler.size,
73 |             prompt,
74 |             uc=uc,
75 |             sampler=self.stage1_sampler.sampler,
76 |             ip=pixel_img,
77 |             step=step,
78 |             scale=scale,
79 |             batch_size=self.stage1_sampler.batch_size,
80 |             ddim_eta=ddim_eta,
81 |             dtype=self.stage1_sampler.dtype,
82 |             device=self.stage1_sampler.device,
83 |             camera=self.stage1_sampler.camera,
84 |             num_frames=self.stage1_sampler.num_frames,
85 |             pixel_control=(self.stage1_sampler.mode == "pixel"),
86 |             transform=self.stage1_sampler.image_transform,
87 |             offset_noise=self.stage1_sampler.offset_noise,
88 |         )
89 | 
90 |         stage1_images = [Image.fromarray(img) for img in stage1_images]
91 |         stage1_images.pop(self.stage1_sampler.ref_position)
92 |         return stage1_images
93 | 
94 |     def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50):
95 |         if type(pixel_img) == str:
96 |             pixel_img = Image.open(pixel_img)
97 | 
98 |         if isinstance(pixel_img, Image.Image):
99 |             if pixel_img.mode == "RGBA":
100 |                 background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
101 |                 pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
102 |             else:
103 |                 pixel_img = pixel_img.convert("RGB")
104 |         else:
105 |             raise TypeError(f"pixel_img must be a PIL.Image.Image or a path string, got {type(pixel_img)}")
106 |         stage2_images = self.stage2_sampler.i2iStage2(
107 |             self.stage2_sampler.model,
108 |             self.stage2_sampler.size,
109 |             "3D assets",
110 |             self.stage2_sampler.uc,
111 |             self.stage2_sampler.sampler,
112 |             pixel_images=stage1_images,
113 |             ip=pixel_img,
114 |             step=step,
115 |             scale=scale,
116 |             batch_size=self.stage2_sampler.batch_size,
117 |             ddim_eta=0.0,
118 |             dtype=self.stage2_sampler.dtype,
119 |             device=self.stage2_sampler.device,
120 |             camera=self.stage2_sampler.camera,
121 |             num_frames=self.stage2_sampler.num_frames,
122 |             pixel_control=(self.stage2_sampler.mode == "pixel"),
123 |             transform=self.stage2_sampler.image_transform,
124 |             offset_noise=self.stage2_sampler.offset_noise,
125 |         )
126 |         stage2_images = [Image.fromarray(img) for img in stage2_images]
127 |         return stage2_images
128 | 
129 |     def set_seed(self, seed):
130 |         self.stage1_sampler.seed = seed
131 |         self.stage2_sampler.seed = seed
132 | 
133 |     def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
134 |         pixel_img = do_resize_content(pixel_img, self.resize_rate)
135 |         stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
136 |         stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step)
137 | 
138 |         return {
139 |             "ref_img": pixel_img,
140 |             "stage1_images": stage1_images,
141 |             "stage2_images": stage2_images,
142 |         }
143 | 
144 | rembg_session = rembg.new_session()
145 | 
146 | def expand_to_square(image, bg_color=(0, 0, 0, 0)):
147 |     # expand image to 1:1
148 |     width, height = image.size
149 |     if width == height:
150 |         return image
151 |     new_size = (max(width, height), max(width, height))
152 |     new_image = Image.new("RGBA", new_size, bg_color)
153 |     paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
154 |     new_image.paste(image, paste_position)
155 |     return new_image
156 | 
157 | def remove_background(
158 |     image: PIL.Image.Image,
159 |     rembg_session = None,
160 |     force: bool = False,
161 |     **rembg_kwargs,
162 | ) -> PIL.Image.Image:
163 |     do_remove = True
164 |     if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
165 |         # the alpha channel already provides a foreground mask, so skip background removal
166 |         print("alpha channel not empty, skip background removal, using alpha channel as mask")
167 |         background = Image.new("RGBA", image.size, (0, 0, 0, 0))
168 |         image = Image.alpha_composite(background, image)
169 |         do_remove = False
170 |     do_remove = do_remove or force
171 |     if do_remove:
172 |         image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
173 |     return image
174 | 
175 | def do_resize_content(original_image: Image, scale_rate):
176 |     # resize the image content while retaining the original image size
177 |     if scale_rate != 1:
178 |         # Calculate the new size after rescaling
179 |         new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
180 |         # Resize the image while maintaining the aspect ratio
181 |         resized_image = original_image.resize(new_size)
182 |         # Create a new image with the original size and black background
183 |         padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
184 |         paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
185 |         padded_image.paste(resized_image, paste_position)
186 |         return padded_image
187 |     else:
188 |         return original_image
189 | 
190 | def add_background(image, bg_color=(255, 255, 255)):
191 |     # given an RGBA image, the alpha channel is used as a mask to add the background color
192 |     background = Image.new("RGBA", image.size, bg_color)
193 |     return Image.alpha_composite(background, image)
194 | 
195 | 
196 | def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
197 |     """
198 |     input image is a PIL image in RGBA; returns an RGB image
199
| """ 200 | print(background_choice) 201 | if background_choice == "Alpha as mask": 202 | background = Image.new("RGBA", image.size, (0, 0, 0, 0)) 203 | image = Image.alpha_composite(background, image) 204 | else: 205 | image = remove_background(image, rembg_session, force_remove=True) 206 | image = do_resize_content(image, foreground_ratio) 207 | image = expand_to_square(image) 208 | image = add_background(image, backgroud_color) 209 | return image.convert("RGB") 210 | 211 | 212 | 213 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gradio 2 | huggingface-hub 3 | diffusers==0.24.0 4 | einops==0.7.0 5 | Pillow==10.1.0 6 | transformers==4.27.1 7 | open-clip-torch==2.7.0 8 | opencv-contrib-python-headless==4.9.0.80 9 | opencv-python-headless==4.9.0.80 10 | omegaconf 11 | rembg 12 | pygltflib 13 | kiui 14 | trimesh 15 | xatlas 16 | pymeshlab 17 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from libs.base_utils import do_resize_content 3 | from imagedream.ldm.util import ( 4 | instantiate_from_config, 5 | get_obj_from_str, 6 | ) 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | import numpy as np 10 | from inference import generate3d 11 | from huggingface_hub import hf_hub_download 12 | import json 13 | import argparse 14 | import shutil 15 | from model import CRM 16 | import PIL 17 | import rembg 18 | import os 19 | from pipelines import TwoStagePipeline 20 | 21 | rembg_session = rembg.new_session() 22 | 23 | def expand_to_square(image, bg_color=(0, 0, 0, 0)): 24 | # expand image to 1:1 25 | width, height = image.size 26 | if width == height: 27 | return image 28 | new_size = (max(width, height), max(width, height)) 29 | new_image = Image.new("RGBA", new_size, bg_color) 30 | paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) 31 | new_image.paste(image, paste_position) 32 | return new_image 33 | 34 | def remove_background( 35 | image: PIL.Image.Image, 36 | rembg_session = None, 37 | force: bool = False, 38 | **rembg_kwargs, 39 | ) -> PIL.Image.Image: 40 | do_remove = True 41 | if image.mode == "RGBA" and image.getextrema()[3][0] < 255: 42 | # explain why current do not rm bg 43 | print("alhpa channl not enpty, skip remove background, using alpha channel as mask") 44 | background = Image.new("RGBA", image.size, (0, 0, 0, 0)) 45 | image = Image.alpha_composite(background, image) 46 | do_remove = False 47 | do_remove = do_remove or force 48 | if do_remove: 49 | image = rembg.remove(image, session=rembg_session, **rembg_kwargs) 50 | return image 51 | 52 | def do_resize_content(original_image: Image, scale_rate): 53 | # resize image content wile retain the original image size 54 | if scale_rate != 1: 55 | # Calculate the new size after rescaling 56 | new_size = tuple(int(dim * scale_rate) for dim in original_image.size) 57 | # Resize the image while maintaining the aspect ratio 58 | resized_image = original_image.resize(new_size) 59 | # Create a new image with the original size and black background 60 | padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) 61 | paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) 62 | padded_image.paste(resized_image, paste_position) 63 | return padded_image 64 | 
else: 65 | return original_image 66 | 67 | def add_background(image, bg_color=(255, 255, 255)): 68 | # given an RGBA image, alpha channel is used as mask to add background color 69 | background = Image.new("RGBA", image.size, bg_color) 70 | return Image.alpha_composite(background, image) 71 | 72 | 73 | def preprocess_image(image, background_choice, foreground_ratio, backgroud_color): 74 | """ 75 | input image is a pil image in RGBA, return RGB image 76 | """ 77 | print(background_choice) 78 | if background_choice == "Alpha as mask": 79 | background = Image.new("RGBA", image.size, (0, 0, 0, 0)) 80 | image = Image.alpha_composite(background, image) 81 | else: 82 | image = remove_background(image, rembg_session, force_remove=True) 83 | image = do_resize_content(image, foreground_ratio) 84 | image = expand_to_square(image) 85 | image = add_background(image, backgroud_color) 86 | return image.convert("RGB") 87 | 88 | if __name__ == "__main__": 89 | 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument( 92 | "--inputdir", 93 | type=str, 94 | default="examples/kunkun.webp", 95 | help="dir for input image", 96 | ) 97 | parser.add_argument( 98 | "--scale", 99 | type=float, 100 | default=5.0, 101 | ) 102 | parser.add_argument( 103 | "--step", 104 | type=int, 105 | default=50, 106 | ) 107 | parser.add_argument( 108 | "--bg_choice", 109 | type=str, 110 | default="Auto Remove background", 111 | help="[Auto Remove background] or [Alpha as mask]", 112 | ) 113 | parser.add_argument( 114 | "--outdir", 115 | type=str, 116 | default="out/", 117 | ) 118 | args = parser.parse_args() 119 | 120 | 121 | img = Image.open(args.inputdir) 122 | img = preprocess_image(img, args.bg_choice, 1.0, (127, 127, 127)) 123 | os.makedirs(args.outdir, exist_ok=True) 124 | img.save(args.outdir+"preprocessed_image.png") 125 | 126 | crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth") 127 | specs = json.load(open("configs/specs_objaverse_total.json")) 128 | model = CRM(specs).to("cuda") 129 | model.load_state_dict(torch.load(crm_path, map_location = "cuda"), strict=False) 130 | 131 | stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config 132 | stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config 133 | stage2_sampler_config = stage2_config.sampler 134 | stage1_sampler_config = stage1_config.sampler 135 | 136 | stage1_model_config = stage1_config.models 137 | stage2_model_config = stage2_config.models 138 | 139 | xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth") 140 | pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth") 141 | stage1_model_config.resume = pixel_path 142 | stage2_model_config.resume = xyz_path 143 | 144 | pipeline = TwoStagePipeline( 145 | stage1_model_config, 146 | stage2_model_config, 147 | stage1_sampler_config, 148 | stage2_sampler_config, 149 | ) 150 | 151 | rt_dict = pipeline(img, scale=args.scale, step=args.step) 152 | stage1_images = rt_dict["stage1_images"] 153 | stage2_images = rt_dict["stage2_images"] 154 | np_imgs = np.concatenate(stage1_images, 1) 155 | np_xyzs = np.concatenate(stage2_images, 1) 156 | Image.fromarray(np_imgs).save(args.outdir+"pixel_images.png") 157 | Image.fromarray(np_xyzs).save(args.outdir+"xyz_images.png") 158 | 159 | glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, "cuda") 160 | shutil.copy(obj_path, args.outdir+"output3d.zip") -------------------------------------------------------------------------------- /train.py: 
--------------------------------------------------------------------------------
1 | """
2 | training script for imagedream
3 | - the config system is similar to the stable diffusion ldm code base (using omegaconf, yaml; target/params initialization, etc.)
4 | - the training code base is similar to the unidiffuser training code base, using accelerate
5 | 
6 | """
7 | 
8 | from omegaconf import OmegaConf
9 | import argparse
10 | from pathlib import Path
11 | from torch.utils.data import DataLoader
12 | import os.path as osp
13 | import numpy as np
14 | import os
15 | import torch
16 | from PIL import Image
17 | import numpy as np
18 | import wandb
19 | from libs.base_utils import get_data_generator, PrintContext
20 | from libs.base_utils import (
21 |     setup,
22 |     instantiate_from_config,
23 |     dct2str,
24 |     add_prefix,
25 |     get_obj_from_str,
26 | )
27 | from absl import logging
28 | from einops import rearrange
29 | from imagedream.camera_utils import get_camera
30 | from libs.sample import ImageDreamDiffusion
31 | from rich import print
32 | 
33 | 
34 | def train(config, unk):
35 |     # using pipeline to extract models
36 |     accelerator, device = setup(config, unk)
37 |     with PrintContext(f"{'access STAT':-^50}", accelerator.is_main_process):
38 |         print(accelerator.state)
39 |     dtype = {
40 |         "fp16": torch.float16,
41 |         "fp32": torch.float32,
42 |         "no": torch.float32,
43 |         "bf16": torch.bfloat16,
44 |     }[accelerator.state.mixed_precision]
45 | 
46 |     num_frames = config.num_frames
47 | 
48 |     ################## load models ##################
49 |     model_config = config.models.config
50 |     model_config = OmegaConf.load(model_config)
51 |     model = instantiate_from_config(model_config.model)
52 |     state_dict = torch.load(config.models.resume, map_location="cpu")
53 | 
54 |     print(model.load_state_dict(state_dict, strict=False))
55 |     print("loaded model from {}".format(config.models.resume))
56 | 
57 |     latest_step = 0
58 |     if config.get("resume", False):
59 |         print("resuming from the specified workdir")
60 |         ckpts = os.listdir(config.ckpt_root)
61 |         if len(ckpts) == 0:
62 |             print("no ckpt found")
63 |         else:
64 |             latest_ckpt = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))[-1]
65 |             latest_step = int(latest_ckpt.split("-")[-1])
66 |             print("loading ckpt from ", osp.join(config.ckpt_root, latest_ckpt))
67 |             unet_state_dict = torch.load(
68 |                 osp.join(config.ckpt_root, latest_ckpt), map_location="cpu"
69 |             )
70 |             print(model.model.load_state_dict(unet_state_dict, strict=False))
71 | 
72 |     elif config.models.get("resume_unet", None) is not None:
73 |         unet_state_dict = torch.load(config.models.resume_unet, map_location="cpu")
74 |         print(model.model.load_state_dict(unet_state_dict, strict=False))
75 |         print(f"______ load unet from {config.models.resume_unet} ______")
76 |     model.to(device)
77 |     model.device = device
78 |     model.clip_model.device = device
79 | 
80 |     ################# setup optimizer #################
81 |     from torch.optim import AdamW
82 |     from accelerate.utils import DummyOptim
83 | 
84 |     optimizer_cls = (
85 |         AdamW
86 |         if accelerator.state.deepspeed_plugin is None
87 |         or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
88 |         else DummyOptim
89 |     )
90 |     optimizer = optimizer_cls(model.model.parameters(), **config.optimizer)
91 | 
92 |     ################# prepare datasets #################
93 |     dataset = instantiate_from_config(config.train_data)
94 |     eval_dataset = instantiate_from_config(config.eval_data)
95 |     in_the_wild_images = (
96 |         instantiate_from_config(config.in_the_wild_images)
97 |
if config.get("in_the_wild_images", None) is not None 98 | else None 99 | ) 100 | 101 | dl_config = config.dataloader 102 | dataloader = DataLoader(dataset, **dl_config, batch_size=config.batch_size) 103 | 104 | ( 105 | model, 106 | optimizer, 107 | dataloader, 108 | ) = accelerator.prepare(model, optimizer, dataloader) 109 | 110 | generator = get_data_generator(dataloader, accelerator.is_main_process, "train") 111 | if config.get("sampler", None) is not None: 112 | sampler_cls = get_obj_from_str(config.sampler.target) 113 | sampler = sampler_cls(model, device, dtype, **config.sampler.params) 114 | else: 115 | sampler = ImageDreamDiffusion( 116 | model, 117 | mode=config.mode, 118 | num_frames=num_frames, 119 | device=device, 120 | dtype=dtype, 121 | camera_views=dataset.camera_views, 122 | offset_noise=config.get("offset_noise", False), 123 | ref_position=dataset.ref_position, 124 | random_background=dataset.random_background, 125 | resize_rate=dataset.resize_rate, 126 | ) 127 | 128 | ################# evaluation code ################# 129 | def evaluation(): 130 | return_ls = [] 131 | for i in range( 132 | accelerator.process_index, len(eval_dataset), accelerator.num_processes 133 | ): 134 | cond = eval_dataset[i]["cond"] 135 | 136 | images = sampler.diffuse("3D assets.", cond, n_test=2) 137 | images = np.concatenate(images, 0) 138 | images = [Image.fromarray(images)] 139 | return_ls.append(dict(images=images, ident=eval_dataset[i]["ident"])) 140 | return return_ls 141 | 142 | def evaluation2(): 143 | # eval for common used in the wild image 144 | return_ls = [] 145 | in_the_wild_images.init_item() 146 | for i in range( 147 | accelerator.process_index, 148 | len(in_the_wild_images), 149 | accelerator.num_processes, 150 | ): 151 | cond = in_the_wild_images[i]["cond"] 152 | images = sampler.diffuse("3D assets.", cond, n_test=2) 153 | images = np.concatenate(images, 0) 154 | images = [Image.fromarray(images)] 155 | return_ls.append(dict(images=images, ident=in_the_wild_images[i]["ident"])) 156 | return return_ls 157 | 158 | if latest_step == 0: 159 | global_step = 0 160 | total_step = 0 161 | log_step = 0 162 | eval_step = 0 163 | save_step = 0 164 | else: 165 | global_step = latest_step // config.total_batch_size 166 | total_step = latest_step 167 | log_step = latest_step + config.log_interval 168 | eval_step = latest_step + config.eval_interval 169 | save_step = latest_step + config.save_interval 170 | 171 | unet = model.model 172 | while True: 173 | item = next(generator) 174 | unet.train() 175 | bs = item["clip_cond"].shape[0] 176 | BS = bs * num_frames 177 | item["clip_cond"] = item["clip_cond"].to(device).to(dtype) 178 | item["vae_cond"] = item["vae_cond"].to(device).to(dtype) 179 | camera_input = item["cameras"].to(device) 180 | camera_input = camera_input.reshape((BS, camera_input.shape[-1])) 181 | 182 | gd_type = config.get("gd_type", "pixel") 183 | if gd_type == "pixel": 184 | item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype) 185 | gd = item["target_images_vae"] 186 | elif gd_type == "xyz": 187 | item["target_images_xyz_vae"] = ( 188 | item["target_images_xyz_vae"].to(device).to(dtype) 189 | ) 190 | gd = item["target_images_xyz_vae"] 191 | elif gd_type == "fusechannel": 192 | item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype) 193 | item["target_images_xyz_vae"] = ( 194 | item["target_images_xyz_vae"].to(device).to(dtype) 195 | ) 196 | gd = torch.cat( 197 | (item["target_images_vae"], item["target_images_xyz_vae"]), dim=0 198 | ) 199 
| else: 200 | raise NotImplementedError 201 | 202 | with torch.no_grad(), accelerator.autocast("cuda"): 203 | ip_embed = model.clip_model.encode_image_with_transformer(item["clip_cond"]) 204 | ip_ = ip_embed.repeat_interleave(num_frames, dim=0) 205 | 206 | ip_img = model.get_first_stage_encoding( 207 | model.encode_first_stage(item["vae_cond"]) 208 | ) 209 | 210 | gd = rearrange(gd, "B F C H W -> (B F) C H W") 211 | 212 | latent_target_images = model.get_first_stage_encoding( 213 | model.encode_first_stage(gd) 214 | ) 215 | 216 | if gd_type == "fusechannel": 217 | latent_target_images = rearrange( 218 | latent_target_images, "(B F) C H W -> B F C H W", B=bs * 2 219 | ) 220 | image_latent, xyz_latent = torch.chunk(latent_target_images, 2) 221 | fused_channel_latent = torch.cat((image_latent, xyz_latent), dim=-3) 222 | latent_target_images = rearrange( 223 | fused_channel_latent, "B F C H W -> (B F) C H W" 224 | ) 225 | 226 | if item.get("captions", None) is not None: 227 | caption_ls = np.array(item["caption"]).T.reshape((-1, BS)).squeeze() 228 | prompt_cond = model.get_learned_conditioning(caption_ls) 229 | elif item.get("caption", None) is not None: 230 | prompt_cond = model.get_learned_conditioning(item["caption"]) 231 | prompt_cond = prompt_cond.repeat_interleave(num_frames, dim=0) 232 | else: 233 | prompt_cond = model.get_learned_conditioning(["3D assets."]).repeat( 234 | BS, 1, 1 235 | ) 236 | condition = { 237 | "context": prompt_cond, 238 | "ip": ip_, 239 | "ip_img": ip_img, 240 | "camera": camera_input, 241 | } 242 | 243 | with torch.autocast("cuda"), accelerator.accumulate(model): 244 | time_steps = torch.randint(0, model.num_timesteps, (BS,), device=device) 245 | noise = torch.randn_like(latent_target_images, device=device) 246 | # noise_img, _ = torch.chunk(noise, 2, dim=1) 247 | # noise = torch.cat((noise_img, noise_img), dim=1) 248 | x_noisy = model.q_sample(latent_target_images, time_steps, noise) 249 | output = unet(x_noisy, time_steps, **condition, num_frames=num_frames) 250 | reshaped_pred = output.reshape(bs, num_frames, *output.shape[1:]).permute( 251 | 1, 0, 2, 3, 4 252 | ) 253 | reshaped_noise = noise.reshape(bs, num_frames, *noise.shape[1:]).permute( 254 | 1, 0, 2, 3, 4 255 | ) 256 | true_pred = reshaped_pred[: num_frames - 1] 257 | fake_pred = reshaped_pred[num_frames - 1 :] 258 | true_noise = reshaped_noise[: num_frames - 1] 259 | fake_noise = reshaped_noise[num_frames - 1 :] 260 | loss = ( 261 | torch.nn.functional.mse_loss(true_noise, true_pred) 262 | + torch.nn.functional.mse_loss(fake_noise, fake_pred) * 0 263 | ) 264 | 265 | accelerator.backward(loss) 266 | optimizer.step() 267 | optimizer.zero_grad() 268 | global_step += 1 269 | 270 | total_step = global_step * config.total_batch_size 271 | if total_step > log_step: 272 | metrics = dict( 273 | loss=accelerator.gather(loss.detach().mean()).mean().item(), 274 | scale=( 275 | accelerator.scaler.get_scale() 276 | if accelerator.scaler is not None 277 | else -1 278 | ), 279 | ) 280 | log_step += config.log_interval 281 | if accelerator.is_main_process: 282 | logging.info(dct2str(dict(step=total_step, **metrics))) 283 | wandb.log(add_prefix(metrics, "train"), step=total_step) 284 | 285 | if total_step > save_step and accelerator.is_main_process: 286 | logging.info("saving done") 287 | torch.save( 288 | unet.state_dict(), osp.join(config.ckpt_root, f"unet-{total_step}") 289 | ) 290 | save_step += config.save_interval 291 | logging.info("save done") 292 | 293 | if total_step > eval_step: 294 | 
logging.info("evaluating") 295 | unet.eval() 296 | return_ls = evaluation() 297 | cur_eval_base = osp.join(config.eval_root, f"{total_step:07d}") 298 | os.makedirs(cur_eval_base, exist_ok=True) 299 | for item in return_ls: 300 | for i, im in enumerate(item["images"]): 301 | im.save( 302 | osp.join( 303 | cur_eval_base, 304 | f"{item['ident']}-{i:03d}-{accelerator.process_index}-.png", 305 | ) 306 | ) 307 | 308 | return_ls2 = evaluation2() 309 | cur_eval_base = osp.join(config.eval_root2, f"{total_step:07d}") 310 | os.makedirs(cur_eval_base, exist_ok=True) 311 | for item in return_ls2: 312 | for i, im in enumerate(item["images"]): 313 | im.save( 314 | osp.join( 315 | cur_eval_base, 316 | f"{item['ident']}-{i:03d}-{accelerator.process_index}-inthewild.png", 317 | ) 318 | ) 319 | eval_step += config.eval_interval 320 | logging.info("evaluation done") 321 | 322 | accelerator.wait_for_everyone() 323 | if total_step > config.max_step: 324 | break 325 | 326 | 327 | if __name__ == "__main__": 328 | # load config from config path, then merge with cli args 329 | parser = argparse.ArgumentParser() 330 | parser.add_argument( 331 | "--config", type=str, default="configs/nf7_v3_SNR_rd_size_stroke.yaml" 332 | ) 333 | parser.add_argument( 334 | "--logdir", type=str, default="train_logs", help="the dir to put logs" 335 | ) 336 | parser.add_argument( 337 | "--resume_workdir", type=str, default=None, help="specify to do resume" 338 | ) 339 | args, unk = parser.parse_known_args() 340 | print(args, unk) 341 | config = OmegaConf.load(args.config) 342 | if args.resume_workdir is not None: 343 | assert osp.exists(args.resume_workdir), f"{args.resume_workdir} not exists" 344 | config.config.workdir = args.resume_workdir 345 | config.config.resume = True 346 | OmegaConf.set_struct(config, True) # prevent adding new keys 347 | cli_conf = OmegaConf.from_cli(unk) 348 | config = OmegaConf.merge(config, cli_conf) 349 | config = config.config 350 | OmegaConf.set_struct(config, False) 351 | config.logdir = args.logdir 352 | config.config_name = Path(args.config).stem 353 | 354 | train(config, unk) 355 | -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/000.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/001.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/002.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/003.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/003.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/004.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/005.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_000.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_001.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_002.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_003.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_004.png -------------------------------------------------------------------------------- /train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/train_examples/0011662ee0fc4b4481bfd28314d154c1/xyz_new_005.png -------------------------------------------------------------------------------- /train_examples/caption.csv: -------------------------------------------------------------------------------- 1 | id,caption 2 | 0011662ee0fc4b4481bfd28314d154c1,"A 3D model of a pink, black, and purple tower-like structure with psychedelic elements, resembling a hat, sculpture, and robot." 
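Note on the sample training item above: each id in caption.csv names a folder under train_examples/ containing six shaded views (000.png-005.png), six matching views that appear to be per-pixel xyz coordinate maps (xyz_new_000.png-xyz_new_005.png), and a text caption. The snippet below is a minimal, illustrative reader for this layout only; the function name load_example and the use of csv/PIL are assumptions made for illustration, and the repository's actual dataset class (libs/data.py) is not shown in this listing and may differ.

# Illustrative sketch only -- not the repository's own data loader.
import csv
from pathlib import Path
from PIL import Image

def load_example(root="train_examples", uuid="0011662ee0fc4b4481bfd28314d154c1"):
    root = Path(root)
    # caption.csv has an "id,caption" header; build an id -> caption lookup.
    with open(root / "caption.csv", newline="") as f:
        captions = {row["id"]: row["caption"] for row in csv.DictReader(f)}
    item_dir = root / uuid
    rgb_views = [Image.open(item_dir / f"{i:03d}.png") for i in range(6)]          # 000.png .. 005.png
    xyz_views = [Image.open(item_dir / f"xyz_new_{i:03d}.png") for i in range(6)]  # xyz_new_000.png .. xyz_new_005.png
    return {"caption": captions[uuid], "rgb": rgb_views, "xyz": xyz_views}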
-------------------------------------------------------------------------------- /train_stage2.py: -------------------------------------------------------------------------------- 1 | """ 2 | training script for imagedream 3 | - the config system is similar to the stable diffusion ldm code base (using omegaconf, yaml; target, params initialization, etc.) 4 | - the training code base is similar to the unidiffuser training code base, using accelerate 5 | 6 | concatenated channels as input; predicted xyz values mapped to pixels as the ground truth 7 | """ 8 | from omegaconf import OmegaConf 9 | import argparse 10 | import datetime 11 | from pathlib import Path 12 | from torch.utils.data import DataLoader 13 | import os.path as osp 14 | import numpy as np 15 | import os 16 | import torch 17 | import wandb 18 | from libs.base_utils import get_data_generator, PrintContext 19 | from libs.base_utils import setup, instantiate_from_config, dct2str, add_prefix, get_obj_from_str 20 | from absl import logging 21 | from einops import rearrange 22 | from libs.sample import ImageDreamDiffusion 23 | 24 | def train(config, unk): 25 | # using pipeline to extract models 26 | accelerator, device = setup(config, unk) 27 | with PrintContext(f"{'access STAT':-^50}", accelerator.is_main_process): 28 | print(accelerator.state) 29 | dtype = { 30 | "fp16": torch.float16, 31 | "fp32": torch.float32, 32 | "no": torch.float32, 33 | "bf16": torch.bfloat16, 34 | }[accelerator.state.mixed_precision] 35 | num_frames = config.num_frames 36 | 37 | 38 | ################## load models ################## 39 | model_config = config.models.config 40 | model_config = OmegaConf.load(model_config) 41 | model = instantiate_from_config(model_config.model) 42 | state_dict = torch.load(config.models.resume, map_location="cpu") 43 | 44 | 45 | model_in_conv_keys = ["model.diffusion_model.input_blocks.0.0.weight",] 46 | in_conv_keys = ["diffusion_model.input_blocks.0.0.weight"] 47 | 48 | 49 | def modify_keys(state_dict, in_keys, out_keys, cur_state_dict=None): 50 | print("this function is only for the fuse-channel model") 51 | for in_key in in_keys: 52 | p = state_dict[in_key] 53 | if cur_state_dict is not None: 54 | p_cur = cur_state_dict[in_key] 55 | print(p_cur.shape, p.shape) 56 | if p_cur.shape == p.shape: 57 | print(f"skip {in_key} because of same shape") 58 | continue 59 | state_dict[in_key] = torch.cat([p, torch.zeros_like(p)], dim=1) * 0.5 60 | for out_key in out_keys: 61 | p = state_dict[out_key] 62 | if cur_state_dict is not None: 63 | p_cur = cur_state_dict[out_key] 64 | print(p_cur.shape, p.shape) 65 | if p_cur.shape == p.shape: 66 | print(f"skip {out_key} because of same shape") 67 | continue 68 | state_dict[out_key] = torch.cat([p, torch.zeros_like(p)], dim=0) 69 | return state_dict 70 | 71 | def wipe_keys(state_dict, keys): 72 | for key in keys: 73 | state_dict.pop(key) 74 | return state_dict 75 | 76 | unet_config = model_config.model.params.unet_config 77 | is_normal_inout_channel = not (unet_config.params.in_channels != 4 or unet_config.params.out_channels != 4) 78 | 79 | if not is_normal_inout_channel: 80 | state_dict = modify_keys(state_dict, model_in_conv_keys, [], model.state_dict()) 81 | 82 | print(model.load_state_dict(state_dict, strict=False)) 83 | print("loaded model from {}".format(config.models.resume)) 84 | if config.models.get("resume_unet", None) is not None: 85 | unet_state_dict = torch.load(config.models.resume_unet, map_location="cpu") 86 | if not is_normal_inout_channel: 87 | unet_state_dict = modify_keys(unet_state_dict, in_conv_keys, 
[], model.model.state_dict()) 88 | print(model.model.load_state_dict(unet_state_dict, strict= False)) 89 | print(f"______ load unet from {config.models.resume_unet} ______") 90 | model.to(device) 91 | model.device = device 92 | model.clip_model.device = device 93 | 94 | 95 | ################# setup optimizer ################# 96 | from torch.optim import AdamW 97 | from accelerate.utils import DummyOptim 98 | optimizer_cls = ( 99 | AdamW 100 | if accelerator.state.deepspeed_plugin is None 101 | or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config 102 | else DummyOptim 103 | ) 104 | optimizer = optimizer_cls(model.model.parameters(), **config.optimizer) 105 | 106 | ################# prepare datasets ################# 107 | dataset = instantiate_from_config(config.train_data) 108 | eval_dataset = instantiate_from_config(config.eval_data) 109 | 110 | dl_config = config.dataloader 111 | dataloader = DataLoader(dataset, **dl_config, batch_size=config.batch_size) 112 | 113 | model, optimizer, dataloader, = accelerator.prepare(model, optimizer, dataloader) 114 | 115 | generator = get_data_generator(dataloader, accelerator.is_main_process, "train") 116 | if config.get("sampler", None) is not None: 117 | sampler_cls = get_obj_from_str(config.sampler.target) 118 | sampler = sampler_cls(model, device, dtype, **config.sampler.params) 119 | else: 120 | sampler = ImageDreamDiffusion(model, config.mode, num_frames, device, dtype, dataset.camera_views, 121 | offset_noise=config.get("offset_noise", False), 122 | ref_position=dataset.ref_position, 123 | random_background=dataset.random_background, 124 | resize_rate=dataset.resize_rate) 125 | 126 | ################# evaluation code ################# 127 | def evaluation(): 128 | from PIL import Image 129 | import numpy as np 130 | return_ls = [] 131 | for i in range(accelerator.process_index, len(eval_dataset), accelerator.num_processes): 132 | item = eval_dataset[i] 133 | cond = item['cond'] 134 | images = sampler.diffuse("3D assets.", cond, 135 | pixel_images=item["cond_raw_images"], 136 | n_test=2) 137 | images = np.concatenate(images, 0) 138 | images = [Image.fromarray(images)] 139 | return_ls.append(dict(images=images, ident=eval_dataset[i]['ident'])) 140 | return return_ls 141 | 142 | 143 | global_step = 0 144 | total_step = 0 145 | log_step = 0 146 | eval_step = 0 147 | save_step = config.save_interval 148 | 149 | unet = model.model 150 | while True: 151 | item = next(generator) 152 | unet.train() 153 | bs = item["clip_cond"].shape[0] 154 | BS = bs * num_frames 155 | item["clip_cond"] = item["clip_cond"].to(device).to(dtype) 156 | item["vae_cond"] = item["vae_cond"].to(device).to(dtype) 157 | camera_input = item["cameras"].to(device) 158 | camera_input = camera_input.reshape((BS, camera_input.shape[-1])) 159 | 160 | gd_type = config.get("gd_type", "pixel") 161 | if gd_type == "pixel": 162 | item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype) 163 | gd = item["target_images_vae"] 164 | elif gd_type == "xyz": 165 | item["target_images_xyz_vae"] = item["target_images_xyz_vae"].to(device).to(dtype) 166 | item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype) 167 | gd = item["target_images_xyz_vae"] 168 | elif gd_type == "fusechannel": 169 | item["target_images_vae"] = item["target_images_vae"].to(device).to(dtype) 170 | item["target_images_xyz_vae"] = item["target_images_xyz_vae"].to(device).to(dtype) 171 | gd = torch.cat((item["target_images_vae"], item["target_images_xyz_vae"]), dim=0) 172 | 
| else: 173 | raise NotImplementedError 174 | 175 | with torch.no_grad(), accelerator.autocast("cuda"): 176 | ip_embed = model.clip_model.encode_image_with_transformer(item["clip_cond"]) 177 | ip_ = ip_embed.repeat_interleave(num_frames, dim=0) 178 | 179 | ip_img = model.get_first_stage_encoding(model.encode_first_stage(item["vae_cond"])) 180 | 181 | gd = rearrange(gd, "B F C H W -> (B F) C H W") 182 | pixel_images = rearrange(item["target_images_vae"], "B F C H W -> (B F) C H W") 183 | latent_target_images = model.get_first_stage_encoding(model.encode_first_stage(gd)) 184 | pixel_images = model.get_first_stage_encoding(model.encode_first_stage(pixel_images)) 185 | 186 | if gd_type == "fusechannel": 187 | latent_target_images = rearrange(latent_target_images, "(B F) C H W -> B F C H W", B=bs * 2) 188 | image_latent, xyz_latent = torch.chunk(latent_target_images, 2) 189 | fused_channel_latent = torch.cat((image_latent, xyz_latent), dim=-3) 190 | latent_target_images = rearrange(fused_channel_latent, "B F C H W -> (B F) C H W") 191 | 192 | 193 | if item.get("captions", None) is not None: 194 | caption_ls = np.array(item["caption"]).T.reshape((-1, BS)).squeeze() 195 | prompt_cond = model.get_learned_conditioning(caption_ls) 196 | elif item.get("caption", None) is not None: 197 | prompt_cond = model.get_learned_conditioning(item["caption"]) 198 | prompt_cond = prompt_cond.repeat_interleave(num_frames, dim=0) 199 | else: 200 | prompt_cond = model.get_learned_conditioning(["3D assets."]).repeat(BS, 1, 1) 201 | condition = { 202 | "context": prompt_cond, 203 | "ip": ip_, 204 | # "ip_img": ip_img, 205 | "camera": camera_input, 206 | "pixel_images": pixel_images, 207 | } 208 | 209 | with torch.autocast("cuda"), accelerator.accumulate(model): 210 | time_steps = torch.randint(0, model.num_timesteps, (BS,), device=device) 211 | noise = torch.randn_like(latent_target_images, device=device) 212 | x_noisy = model.q_sample(latent_target_images, time_steps, noise) 213 | output = unet(x_noisy, time_steps, **condition, num_frames=num_frames) 214 | loss = torch.nn.functional.mse_loss(noise, output) 215 | 216 | accelerator.backward(loss) 217 | optimizer.step() 218 | optimizer.zero_grad() 219 | global_step += 1 220 | 221 | 222 | 223 | total_step = global_step * config.total_batch_size 224 | if total_step > log_step: 225 | metrics = dict( 226 | loss = accelerator.gather(loss.detach().mean()).mean().item(), 227 | scale = accelerator.scaler.get_scale() if accelerator.scaler is not None else -1 228 | ) 229 | log_step += config.log_interval 230 | if accelerator.is_main_process: 231 | logging.info(dct2str(dict(step=total_step, **metrics))) 232 | wandb.log(add_prefix(metrics, 'train'), step=total_step) 233 | 234 | if total_step > save_step and accelerator.is_main_process: 235 | logging.info("saving done") 236 | torch.save(unet.state_dict(), osp.join(config.ckpt_root, f"unet-{total_step}")) 237 | save_step += config.save_interval 238 | logging.info("save done") 239 | 240 | if total_step > eval_step: 241 | logging.info("evaluating") 242 | unet.eval() 243 | return_ls = evaluation() 244 | cur_eval_base = osp.join(config.eval_root, f"{total_step:07d}") 245 | os.makedirs(cur_eval_base, exist_ok=True) 246 | wandb_image_ls = [] 247 | for item in return_ls: 248 | for i, im in enumerate(item["images"]): 249 | im.save(osp.join(cur_eval_base, f"{item['ident']}-{i:03d}-{accelerator.process_index}-.png")) 250 | wandb_image_ls.append(wandb.Image(im, caption=f"{item['ident']}-{i:03d}-{accelerator.process_index}")) 251 | 252 | 
wandb.log({"eval_samples": wandb_image_ls}) 253 | eval_step += config.eval_interval 254 | logging.info("evaluation done") 255 | 256 | accelerator.wait_for_everyone() 257 | if total_step > config.max_step: 258 | break 259 | 260 | 261 | if __name__ == "__main__": 262 | # load config from config path, then merge with cli args 263 | parser = argparse.ArgumentParser() 264 | parser.add_argument( 265 | "--config", type=str, default="configs/nf7_v3_SNR_rd_size_stroke.yaml" 266 | ) 267 | parser.add_argument( 268 | "--logdir", type=str, default="train_logs", help="the dir to put logs" 269 | ) 270 | parser.add_argument( 271 | "--resume_workdir", type=str, default=None, help="specify to do resume" 272 | ) 273 | args, unk = parser.parse_known_args() 274 | print(args, unk) 275 | config = OmegaConf.load(args.config) 276 | if args.resume_workdir is not None: 277 | assert osp.exists(args.resume_workdir), f"{args.resume_workdir} not exists" 278 | config.config.workdir = args.resume_workdir 279 | config.config.resume = True 280 | OmegaConf.set_struct(config, True) # prevent adding new keys 281 | cli_conf = OmegaConf.from_cli(unk) 282 | config = OmegaConf.merge(config, cli_conf) 283 | config = config.config 284 | OmegaConf.set_struct(config, False) 285 | config.logdir = args.logdir 286 | config.config_name = Path(args.config).stem 287 | 288 | train(config, unk) 289 | 290 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/CRM/4964e36a593070a3045eb0f300935a771ff5c172/util/__init__.py -------------------------------------------------------------------------------- /util/flexicubes_geometry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | from util.flexicubes import FlexiCubes # replace later 11 | # from dmtet import sdf_reg_loss_batch 12 | import torch.nn.functional as F 13 | 14 | def get_center_boundary_index(grid_res, device): 15 | v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) 16 | v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True 17 | center_indices = torch.nonzero(v.reshape(-1)) 18 | 19 | v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False 20 | v[:2, ...] = True 21 | v[-2:, ...] = True 22 | v[:, :2, ...] = True 23 | v[:, -2:, ...] 
= True 24 | v[:, :, :2] = True 25 | v[:, :, -2:] = True 26 | boundary_indices = torch.nonzero(v.reshape(-1)) 27 | return center_indices, boundary_indices 28 | 29 | ############################################################################### 30 | # Geometry interface 31 | ############################################################################### 32 | class FlexiCubesGeometry(object): 33 | def __init__( 34 | self, grid_res=64, scale=2.0, device='cuda', renderer=None, 35 | render_type='neural_render', args=None): 36 | super(FlexiCubesGeometry, self).__init__() 37 | self.grid_res = grid_res 38 | self.device = device 39 | self.args = args 40 | self.fc = FlexiCubes(device, weight_scale=0.5) 41 | self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) 42 | if isinstance(scale, list): 43 | self.verts[:, 0] = self.verts[:, 0] * scale[0] 44 | self.verts[:, 1] = self.verts[:, 1] * scale[1] 45 | self.verts[:, 2] = self.verts[:, 2] * scale[1] 46 | else: 47 | self.verts = self.verts * scale 48 | 49 | all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) 50 | self.all_edges = torch.unique(all_edges, dim=0) 51 | 52 | # Parameters used for fix boundary sdf 53 | self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) 54 | self.renderer = renderer 55 | self.render_type = render_type 56 | 57 | def getAABB(self): 58 | return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values 59 | 60 | def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): 61 | if indices is None: 62 | indices = self.indices 63 | 64 | verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res, 65 | beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], 66 | gamma_f=weight_n[:, 20], training=is_training 67 | ) 68 | return verts, faces, v_reg_loss 69 | 70 | 71 | def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): 72 | return_value = dict() 73 | if self.render_type == 'neural_render': 74 | tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh( 75 | mesh_v_nx3.unsqueeze(dim=0), 76 | mesh_f_fx3.int(), 77 | camera_mv_bx4x4, 78 | mesh_v_nx3.unsqueeze(dim=0), 79 | resolution=resolution, 80 | device=self.device, 81 | hierarchical_mask=hierarchical_mask 82 | ) 83 | 84 | return_value['tex_pos'] = tex_pos 85 | return_value['mask'] = mask 86 | return_value['hard_mask'] = hard_mask 87 | return_value['rast'] = rast 88 | return_value['v_pos_clip'] = v_pos_clip 89 | return_value['mask_pyramid'] = mask_pyramid 90 | return_value['depth'] = depth 91 | else: 92 | raise NotImplementedError 93 | 94 | return return_value 95 | 96 | def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): 97 | # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 98 | v_list = [] 99 | f_list = [] 100 | n_batch = v_deformed_bxnx3.shape[0] 101 | all_render_output = [] 102 | for i_batch in range(n_batch): 103 | verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) 104 | v_list.append(verts_nx3) 105 | f_list.append(faces_fx3) 106 | render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) 107 | all_render_output.append(render_output) 108 | 109 | # Concatenate all render output 110 | return_keys = all_render_output[0].keys() 111 | return_value = dict() 112 | for k in return_keys: 
113 | value = [v[k] for v in all_render_output] 114 | return_value[k] = value 115 | # We can do concatenation outside of the render 116 | return return_value 117 | -------------------------------------------------------------------------------- /util/renderer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import nvdiffrast.torch as dr 5 | from util.flexicubes_geometry import FlexiCubesGeometry 6 | 7 | class Renderer(nn.Module): 8 | def __init__(self, tet_grid_size, camera_angle_num, scale, geo_type): 9 | super().__init__() 10 | 11 | self.tet_grid_size = tet_grid_size 12 | self.camera_angle_num = camera_angle_num 13 | self.scale = scale 14 | self.geo_type = geo_type 15 | self.glctx = dr.RasterizeCudaContext() 16 | 17 | if self.geo_type == "flex": 18 | self.flexicubes = FlexiCubesGeometry(grid_res = self.tet_grid_size) 19 | 20 | def forward(self, data, sdf, deform, verts, tets, training=False, weight = None): 21 | 22 | results = {} 23 | 24 | deform = torch.tanh(deform) / self.tet_grid_size * self.scale / 0.95 25 | if self.geo_type == "flex": 26 | deform = deform *0.5 27 | 28 | v_deformed = verts + deform 29 | 30 | verts_list = [] 31 | faces_list = [] 32 | reg_list = [] 33 | n_shape = verts.shape[0] 34 | for i in range(n_shape): 35 | verts_i, faces_i, reg_i = self.flexicubes.get_mesh(v_deformed[i], sdf[i].squeeze(dim=-1), 36 | with_uv=False, indices=tets, weight_n=weight[i], is_training=training) 37 | 38 | verts_list.append(verts_i) 39 | faces_list.append(faces_i) 40 | reg_list.append(reg_i) 41 | verts = verts_list 42 | faces = faces_list 43 | 44 | flexicubes_surface_reg = torch.cat(reg_list).mean() 45 | flexicubes_weight_reg = (weight ** 2).mean() 46 | results["flex_surf_loss"] = flexicubes_surface_reg 47 | results["flex_weight_loss"] = flexicubes_weight_reg 48 | 49 | return results, verts, faces -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | # Reworked so this matches gluPerspective / glm::perspective, using fovy 7 | def perspective(fovx=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): 8 | # y = np.tan(fovy / 2) 9 | x = np.tan(fovx / 2) 10 | return torch.tensor([[1/x, 0, 0, 0], 11 | [ 0, -aspect/x, 0, 0], 12 | [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], 13 | [ 0, 0, -1, 0]], dtype=torch.float32, device=device) 14 | 15 | 16 | def translate(x, y, z, device=None): 17 | return torch.tensor([[1, 0, 0, x], 18 | [0, 1, 0, y], 19 | [0, 0, 1, z], 20 | [0, 0, 0, 1]], dtype=torch.float32, device=device) 21 | 22 | 23 | def rotate_x(a, device=None): 24 | s, c = np.sin(a), np.cos(a) 25 | return torch.tensor([[1, 0, 0, 0], 26 | [0, c, -s, 0], 27 | [0, s, c, 0], 28 | [0, 0, 0, 1]], dtype=torch.float32, device=device) 29 | 30 | 31 | def rotate_y(a, device=None): 32 | s, c = np.sin(a), np.cos(a) 33 | return torch.tensor([[ c, 0, s, 0], 34 | [ 0, 1, 0, 0], 35 | [-s, 0, c, 0], 36 | [ 0, 0, 0, 1]], dtype=torch.float32, device=device) 37 | 38 | 39 | def rotate_z(a, device=None): 40 | s, c = np.sin(a), np.cos(a) 41 | return torch.tensor([[c, -s, 0, 0], 42 | [s, c, 0, 0], 43 | [0, 0, 1, 0], 44 | [0, 0, 0, 1]], dtype=torch.float32, device=device) 45 | 46 | @torch.no_grad() 47 | def batch_random_rotation_translation(b, t, device=None): 48 | m = np.random.normal(size=[b, 3, 3]) 49 | m[:, 1] = np.cross(m[:, 0], m[:, 2]) 50 | m[:, 2] = 
np.cross(m[:, 0], m[:, 1]) 51 | m = m / np.linalg.norm(m, axis=2, keepdims=True) 52 | m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant') 53 | m[:, 3, 3] = 1.0 54 | m[:, :3, 3] = np.random.uniform(-t, t, size=[b, 3]) 55 | return torch.tensor(m, dtype=torch.float32, device=device) 56 | 57 | @torch.no_grad() 58 | def random_rotation_translation(t, device=None): 59 | m = np.random.normal(size=[3, 3]) 60 | m[1] = np.cross(m[0], m[2]) 61 | m[2] = np.cross(m[0], m[1]) 62 | m = m / np.linalg.norm(m, axis=1, keepdims=True) 63 | m = np.pad(m, [[0, 1], [0, 1]], mode='constant') 64 | m[3, 3] = 1.0 65 | m[:3, 3] = np.random.uniform(-t, t, size=[3]) 66 | return torch.tensor(m, dtype=torch.float32, device=device) 67 | 68 | 69 | @torch.no_grad() 70 | def random_rotation(device=None): 71 | m = np.random.normal(size=[3, 3]) 72 | m[1] = np.cross(m[0], m[2]) 73 | m[2] = np.cross(m[0], m[1]) 74 | m = m / np.linalg.norm(m, axis=1, keepdims=True) 75 | m = np.pad(m, [[0, 1], [0, 1]], mode='constant') 76 | m[3, 3] = 1.0 77 | m[:3, 3] = np.array([0,0,0]).astype(np.float32) 78 | return torch.tensor(m, dtype=torch.float32, device=device) 79 | 80 | 81 | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 82 | return torch.sum(x*y, -1, keepdim=True) 83 | 84 | 85 | def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: 86 | return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN 87 | 88 | 89 | def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: 90 | return x / length(x, eps) 91 | 92 | 93 | def lr_schedule(iter, warmup_iter, scheduler_decay): 94 | if iter < warmup_iter: 95 | return iter / warmup_iter 96 | return max(0.0, 10 ** ( 97 | -(iter - warmup_iter) * scheduler_decay)) 98 | 99 | 100 | def trans_depth(depth): 101 | depth = depth[0].detach().cpu().numpy() 102 | valid = depth > 0 103 | depth[valid] -= depth[valid].min() 104 | depth[valid] = ((depth[valid] / depth[valid].max()) * 255) 105 | return depth.astype('uint8') 106 | 107 | 108 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): 109 | assert isinstance(input, torch.Tensor) 110 | if posinf is None: 111 | posinf = torch.finfo(input.dtype).max 112 | if neginf is None: 113 | neginf = torch.finfo(input.dtype).min 114 | assert nan == 0 115 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 116 | 117 | 118 | def load_item(filepath): 119 | with open(filepath, 'r') as f: 120 | items = [name.strip() for name in f.readlines()] 121 | return set(items) 122 | 123 | def load_prompt(filepath): 124 | uuid2prompt = {} 125 | with open(filepath, 'r') as f: 126 | for line in f.readlines(): 127 | list_line = line.split(',') 128 | uuid2prompt[list_line[0]] = ','.join(list_line[1:]).strip() 129 | return uuid2prompt 130 | 131 | def resize_and_center_image(image_tensor, scale=0.95, c = 0, shift = 0, rgb=False, aug_shift = 0): 132 | if scale == 1: 133 | return image_tensor 134 | B, C, H, W = image_tensor.shape 135 | new_H, new_W = int(H * scale), int(W * scale) 136 | resized_image = torch.nn.functional.interpolate(image_tensor, size=(new_H, new_W), mode='bilinear', align_corners=False).squeeze(0) 137 | background = torch.zeros_like(image_tensor) + c 138 | start_y, start_x = (H - new_H) // 2, (W - new_W) // 2 139 | if shift == 0: 140 | background[:, :, start_y:start_y + new_H, start_x:start_x + new_W] = resized_image 141 | else: 142 | for i in range(B): 143 | randx = random.randint(-shift, shift) 144 | randy = random.randint(-shift, shift) 145 
| if rgb == True: 146 | if i == 0 or i==2 or i==4: 147 | randx = 0 148 | randy = 0 149 | background[i, :, start_y+randy:start_y + new_H+randy, start_x+randx:start_x + new_W+randx] = resized_image[i] 150 | if aug_shift == 0: 151 | return background 152 | for i in range(B): 153 | for j in range(C): 154 | background[i, j, :, :] += (random.random() - 0.5)*2 * aug_shift / 255 155 | return background 156 | 157 | def get_tri(triview_color, dim = 1, blender=True, c = 0, scale=0.95, shift = 0, fix = False, rgb=False, aug_shift = 0): 158 | # triview_color: [6,C,H,W] 159 | # rgb is useful when shift is not 0 160 | triview_color = resize_and_center_image(triview_color, scale=scale, c = c, shift=shift,rgb=rgb, aug_shift = aug_shift) 161 | if blender is False: 162 | triview_color0 = torch.rot90(triview_color[0],k=2,dims=[1,2]) 163 | triview_color1 = torch.rot90(triview_color[4],k=1,dims=[1,2]).flip(2).flip(1) 164 | triview_color2 = torch.rot90(triview_color[5],k=1,dims=[1,2]).flip(2) 165 | triview_color3 = torch.rot90(triview_color[3],k=2,dims=[1,2]).flip(2) 166 | triview_color4 = torch.rot90(triview_color[1],k=3,dims=[1,2]).flip(1) 167 | triview_color5 = torch.rot90(triview_color[2],k=3,dims=[1,2]).flip(1).flip(2) 168 | else: 169 | triview_color0 = torch.rot90(triview_color[2],k=2,dims=[1,2]) 170 | triview_color1 = torch.rot90(triview_color[4],k=0,dims=[1,2]).flip(2).flip(1) 171 | triview_color2 = torch.rot90(torch.rot90(triview_color[0],k=3,dims=[1,2]).flip(2), k=2,dims=[1,2]) 172 | triview_color3 = torch.rot90(torch.rot90(triview_color[5],k=2,dims=[1,2]).flip(2), k=2,dims=[1,2]) 173 | triview_color4 = torch.rot90(triview_color[1],k=2,dims=[1,2]).flip(1).flip(1).flip(2) 174 | triview_color5 = torch.rot90(triview_color[3],k=1,dims=[1,2]).flip(1).flip(2) 175 | if fix == True: 176 | triview_color0[1] = triview_color0[1] * 0 177 | triview_color0[2] = triview_color0[2] * 0 178 | triview_color3[1] = triview_color3[1] * 0 179 | triview_color3[2] = triview_color3[2] * 0 180 | 181 | triview_color1[0] = triview_color1[0] * 0 182 | triview_color1[1] = triview_color1[1] * 0 183 | triview_color4[0] = triview_color4[0] * 0 184 | triview_color4[1] = triview_color4[1] * 0 185 | 186 | triview_color2[0] = triview_color2[0] * 0 187 | triview_color2[2] = triview_color2[2] * 0 188 | triview_color5[0] = triview_color5[0] * 0 189 | triview_color5[2] = triview_color5[2] * 0 190 | color_tensor1_gt = torch.cat((triview_color0, triview_color1, triview_color2), dim=2) 191 | color_tensor2_gt = torch.cat((triview_color3, triview_color4, triview_color5), dim=2) 192 | color_tensor_gt = torch.cat((color_tensor1_gt, color_tensor2_gt), dim = dim) 193 | return color_tensor_gt 194 | 195 | --------------------------------------------------------------------------------
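Note on util/utils.py's lr_schedule above: it returns a multiplicative learning-rate factor that warms up linearly to 1.0 over warmup_iter steps and then decays as 10 ** (-(iter - warmup_iter) * scheduler_decay). The sketch below shows one way to plug it into torch's LambdaLR; the stand-in parameters, base lr, warmup_iter=1000 and scheduler_decay=1e-5 are made-up example values, and the training scripts in this repository may wire their schedules differently.

# Illustrative sketch only: using lr_schedule as a LambdaLR multiplier.
import torch
from util.utils import lr_schedule

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in parameters
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: lr_schedule(step, warmup_iter=1000, scheduler_decay=1e-5),
)
for step in range(3000):
    # ... forward / backward would go here ...
    optimizer.step()
    scheduler.step()  # lr becomes roughly 1e-4 * lr_schedule(step, 1000, 1e-5)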